# # # delete "paramiko/auth_transport.py" # # add_file "paramiko/auth_handler.py" # content [c347c3c5c333cee0ae61a16deb1b44dcca43ec1d] # # add_file "paramiko/compress.py" # content [e4a132074bc6be534e3b02de956e2f33e5d0cf9d] # # add_file "paramiko/hostkeys.py" # content [6f0ab19d1e3fca58cad0427f7a3351f25638dd6d] # # patch "paramiko/__init__.py" # from [5ac689af375a52cd8667f6f2c00242d9e321b78f] # to [6b9cd0faf4fe8fa477eb5ed62da535302915f870] # # patch "paramiko/agent.py" # from [74990e45a83b4f796cacfaa3fcb061c99c6e0952] # to [44065bf0813f3dbdf87dc35ad9c7e17ebcf58b1c] # # patch "paramiko/ber.py" # from [b78e828bffc7f7dc34ff6528119b1478c64032d1] # to [b55cb95a4951718a0294c5c5000488049fd43635] # # patch "paramiko/channel.py" # from [0596f26cb50730f9de7c3ee14e36781cd41bf4a0] # to [c5f45a7e75ce23f71733a51b848126f964ca6460] # # patch "paramiko/common.py" # from [1518442b23a261749bc0eb223d7ebc2b54eaf3f1] # to [734aa55bfaabc6b0902db4d432ccfb0dc22666b3] # # patch "paramiko/dsskey.py" # from [92d4afa51af1c4a11fbdc52f1ac49516bd17a68f] # to [96eafc3818ff93ec9b49432ad77b490baabd76e3] # # patch "paramiko/file.py" # from [036024e929348a5bce74525520e0c48465ee79a4] # to [d2478f2bd6e9643a4a54ae8ee0573400d5e96e11] # # patch "paramiko/kex_gex.py" # from [f6e75b3a0fe6e7f9cb7f5ef1a1ee2fe85c72a8c8] # to [0e065c7cf5e983857f321004230c2992fb5b3f90] # # patch "paramiko/kex_group1.py" # from [e105620e8553e1920a2cad121992a11a1335cc68] # to [927b0fcbf8ab634cba6ff363bea219e888012df9] # # patch "paramiko/logging22.py" # from [d2073500a0e5fcebcd51a1abaf3d64363d7fc8bd] # to [a5f736263bc246fd839855952a9c7d1b6fbbccdc] # # patch "paramiko/message.py" # from [677e83f4c16648442cf78ae84464bdd60e528033] # to [5d9ed3d6fb92eb3e0968b9be7b114a7d7771060e] # # patch "paramiko/packet.py" # from [92dd579d5f50b35110c71e86311bbf172c6804f1] # to [370e998149a0518964582af9357936c51c84c685] # # patch "paramiko/pipe.py" # from [78054bd0d865e230464bd781abe2617c2f2c7136] # to [07a33dcad6057372ed4e9be5743b81938753f390] # # patch "paramiko/pkey.py" # from [1c7280154144125bf970c9fabf4eead4fdddf0e3] # to [e6feb29295376dbe3a2a58ef37f1d60be4f0eb99] # # patch "paramiko/primes.py" # from [aa1fc3675f3e154678e0c961dda8a3fee7bd9a82] # to [1b1bd9084fef5f567bad5a4de971e4a6e2be5d86] # # patch "paramiko/rsakey.py" # from [fdb464f41b215adbeee620603ccc1da76cbabaad] # to [747c8f53df69713d965fdf1e92ad3e75bc73c606] # # patch "paramiko/server.py" # from [684ff520b01acecaa50cba1e770539c889680db8] # to [6027d5d22d1f043130e894e82251f2cd0d4b64be] # # patch "paramiko/sftp.py" # from [6c58bfe7730954a3ddf680b5b463e8c0fff295c7] # to [f1b11aa12a263d400f803a759e3f77f4bdbe47d7] # # patch "paramiko/sftp_attr.py" # from [835dbd7baed2d3403bf555de6c2ea833a246c6f3] # to [b78aaf2565868de77fde605dbabecce8091e2230] # # patch "paramiko/sftp_client.py" # from [5bd2a4fe7cfc525df863b12e4dd585afd4d34bb9] # to [bcd91938416220067071abcec12e5b54c1278208] # # patch "paramiko/sftp_file.py" # from [b6da074f26d4d2e8f891d0eb823d8ac9e9d50121] # to [1670b714a0025b13eea2e9d40d58dc46bdc508fe] # # patch "paramiko/sftp_handle.py" # from [839b45b7b379f6e726ff0c1017c4b2f0ea285344] # to [f41290936813308293937538626ab33350e01f82] # # patch "paramiko/sftp_server.py" # from [fc13033c75ae152a5c3c5758ebb9ecc7f2b5edb8] # to [73185d628c5417fcc59c965125f1aa0426d7e471] # # patch "paramiko/sftp_si.py" # from [41df6e4b221b6e8a79aa43bc4e52a59e170e75ac] # to [fc14cccbee4d37e67230139e4e1d45dcb13ac358] # # patch "paramiko/ssh_exception.py" # from [4a74db486cf4bf8d5480cf850d9caa85efe8d79c] # to [c8757ab9e6c8ed88b69199279d747f64d31d2897] # # patch "paramiko/transport.py" # from [28f37067e15b678b878637b19ea550d116038df0] # to [eddb4e639b1ff6966ea2dea4d19102b102d8ebb2] # # patch "paramiko/util.py" # from [368b5dc43be94fdc356a45c5aab43065771e76a4] # to [4b7ebbb619b6ef6e46b4287d1b78d8a09b87ba99] # ============================================================ --- paramiko/auth_handler.py c347c3c5c333cee0ae61a16deb1b44dcca43ec1d +++ paramiko/auth_handler.py c347c3c5c333cee0ae61a16deb1b44dcca43ec1d @@ -0,0 +1,413 @@ +# Copyright (C) 2003-2006 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{AuthHandler} +""" + +import threading +import weakref + +# this helps freezing utils +import encodings.utf_8 + +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException, BadAuthenticationType, PartialAuthentication +from paramiko.server import InteractiveQuery + + +class AuthHandler (object): + """ + Internal class to handle the mechanics of authentication. + """ + + def __init__(self, transport): + self.transport = weakref.proxy(transport) + self.username = None + self.authenticated = False + self.auth_event = None + self.auth_method = '' + self.password = None + self.private_key = None + self.interactive_handler = None + self.submethods = None + # for server mode: + self.auth_username = None + self.auth_fail_count = 0 + + def is_authenticated(self): + return self.authenticated + + def get_username(self): + if self.transport.server_mode: + return self.auth_username + else: + return self.username + + def auth_none(self, username, event): + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'none' + self.username = username + self._request_auth() + finally: + self.transport.lock.release() + + def auth_publickey(self, username, key, event): + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'publickey' + self.username = username + self.private_key = key + self._request_auth() + finally: + self.transport.lock.release() + + def auth_password(self, username, password, event): + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'password' + self.username = username + self.password = password + self._request_auth() + finally: + self.transport.lock.release() + + def auth_interactive(self, username, handler, event, submethods=''): + """ + response_list = handler(title, instructions, prompt_list) + """ + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'keyboard-interactive' + self.username = username + self.interactive_handler = handler + self.submethods = submethods + self._request_auth() + finally: + self.transport.lock.release() + + def abort(self): + if self.auth_event is not None: + self.auth_event.set() + + + ### internals... + + + def _request_auth(self): + m = Message() + m.add_byte(chr(MSG_SERVICE_REQUEST)) + m.add_string('ssh-userauth') + self.transport._send_message(m) + + def _disconnect_service_not_available(self): + m = Message() + m.add_byte(chr(MSG_DISCONNECT)) + m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) + m.add_string('Service not available') + m.add_string('en') + self.transport._send_message(m) + self.transport.close() + + def _disconnect_no_more_auth(self): + m = Message() + m.add_byte(chr(MSG_DISCONNECT)) + m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) + m.add_string('No more auth methods available') + m.add_string('en') + self.transport._send_message(m) + self.transport.close() + + def _get_session_blob(self, key, service, username): + m = Message() + m.add_string(self.transport.session_id) + m.add_byte(chr(MSG_USERAUTH_REQUEST)) + m.add_string(username) + m.add_string(service) + m.add_string('publickey') + m.add_boolean(1) + m.add_string(key.get_name()) + m.add_string(str(key)) + return str(m) + + def wait_for_response(self, event): + while True: + event.wait(0.1) + if not self.transport.is_active(): + e = self.transport.get_exception() + if e is None: + e = SSHException('Authentication failed.') + raise e + if event.isSet(): + break + if not self.is_authenticated(): + e = self.transport.get_exception() + if e is None: + e = SSHException('Authentication failed.') + # this is horrible. python Exception isn't yet descended from + # object, so type(e) won't work. :( + if issubclass(e.__class__, PartialAuthentication): + return e.allowed_types + raise e + return [] + + def _parse_service_request(self, m): + service = m.get_string() + if self.transport.server_mode and (service == 'ssh-userauth'): + # accepted + m = Message() + m.add_byte(chr(MSG_SERVICE_ACCEPT)) + m.add_string(service) + self.transport._send_message(m) + return + # dunno this one + self._disconnect_service_not_available() + + def _parse_service_accept(self, m): + service = m.get_string() + if service == 'ssh-userauth': + self.transport._log(DEBUG, 'userauth is OK') + m = Message() + m.add_byte(chr(MSG_USERAUTH_REQUEST)) + m.add_string(self.username) + m.add_string('ssh-connection') + m.add_string(self.auth_method) + if self.auth_method == 'password': + m.add_boolean(False) + m.add_string(self.password.encode('UTF-8')) + elif self.auth_method == 'publickey': + m.add_boolean(True) + m.add_string(self.private_key.get_name()) + m.add_string(str(self.private_key)) + blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username) + sig = self.private_key.sign_ssh_data(self.transport.randpool, blob) + m.add_string(str(sig)) + elif self.auth_method == 'keyboard-interactive': + m.add_string('') + m.add_string(self.submethods) + elif self.auth_method == 'none': + pass + else: + raise SSHException('Unknown auth method "%s"' % self.auth_method) + self.transport._send_message(m) + else: + self.transport._log(DEBUG, 'Service request "%s" accepted (?)' % service) + + def _send_auth_result(self, username, method, result): + # okay, send result + m = Message() + if result == AUTH_SUCCESSFUL: + self.transport._log(INFO, 'Auth granted (%s).' % method) + m.add_byte(chr(MSG_USERAUTH_SUCCESS)) + self.authenticated = True + else: + self.transport._log(INFO, 'Auth rejected (%s).' % method) + m.add_byte(chr(MSG_USERAUTH_FAILURE)) + m.add_string(self.transport.server_object.get_allowed_auths(username)) + if result == AUTH_PARTIALLY_SUCCESSFUL: + m.add_boolean(1) + else: + m.add_boolean(0) + self.auth_fail_count += 1 + self.transport._send_message(m) + if self.auth_fail_count >= 10: + self._disconnect_no_more_auth() + if result == AUTH_SUCCESSFUL: + self.transport._auth_trigger() + + def _interactive_query(self, q): + # make interactive query instead of response + m = Message() + m.add_byte(chr(MSG_USERAUTH_INFO_REQUEST)) + m.add_string(q.name) + m.add_string(q.instructions) + m.add_string('') + m.add_int(len(q.prompts)) + for p in q.prompts: + m.add_string(p[0]) + m.add_boolean(p[1]) + self.transport._send_message(m) + + def _parse_userauth_request(self, m): + if not self.transport.server_mode: + # er, uh... what? + m = Message() + m.add_byte(chr(MSG_USERAUTH_FAILURE)) + m.add_string('none') + m.add_boolean(0) + self.transport._send_message(m) + return + if self.authenticated: + # ignore + return + username = m.get_string() + service = m.get_string() + method = m.get_string() + self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) + if service != 'ssh-connection': + self._disconnect_service_not_available() + return + if (self.auth_username is not None) and (self.auth_username != username): + self.transport._log(WARNING, 'Auth rejected because the client attempted to change username in mid-flight') + self._disconnect_no_more_auth() + return + self.auth_username = username + + if method == 'none': + result = self.transport.server_object.check_auth_none(username) + elif method == 'password': + changereq = m.get_boolean() + password = m.get_string().decode('UTF-8', 'replace') + if changereq: + # always treated as failure, since we don't support changing passwords, but collect + # the list of valid auth types from the callback anyway + self.transport._log(DEBUG, 'Auth request to change passwords (rejected)') + newpassword = m.get_string().decode('UTF-8', 'replace') + result = AUTH_FAILED + else: + result = self.transport.server_object.check_auth_password(username, password) + elif method == 'publickey': + sig_attached = m.get_boolean() + keytype = m.get_string() + keyblob = m.get_string() + try: + key = self.transport._key_info[keytype](Message(keyblob)) + except SSHException, e: + self.transport._log(INFO, 'Auth rejected: public key: %s' % str(e)) + key = None + except: + self.transport._log(INFO, 'Auth rejected: unsupported or mangled public key') + key = None + if key is None: + self._disconnect_no_more_auth() + return + # first check if this key is okay... if not, we can skip the verify + result = self.transport.server_object.check_auth_publickey(username, key) + if result != AUTH_FAILED: + # key is okay, verify it + if not sig_attached: + # client wants to know if this key is acceptable, before it + # signs anything... send special "ok" message + m = Message() + m.add_byte(chr(MSG_USERAUTH_PK_OK)) + m.add_string(keytype) + m.add_string(keyblob) + self.transport._send_message(m) + return + sig = Message(m.get_string()) + blob = self._get_session_blob(key, service, username) + if not key.verify_ssh_sig(blob, sig): + self.transport._log(INFO, 'Auth rejected: invalid signature') + result = AUTH_FAILED + elif method == 'keyboard-interactive': + lang = m.get_string() + submethods = m.get_string() + result = self.transport.server_object.check_auth_interactive(username, submethods) + if isinstance(result, InteractiveQuery): + # make interactive query instead of response + self._interactive_query(result) + return + else: + result = self.transport.server_object.check_auth_none(username) + # okay, send result + self._send_auth_result(username, method, result) + + def _parse_userauth_success(self, m): + self.transport._log(INFO, 'Authentication (%s) successful!' % self.auth_method) + self.authenticated = True + self.transport._auth_trigger() + if self.auth_event != None: + self.auth_event.set() + + def _parse_userauth_failure(self, m): + authlist = m.get_list() + partial = m.get_boolean() + if partial: + self.transport._log(INFO, 'Authentication continues...') + self.transport._log(DEBUG, 'Methods: ' + str(authlist)) + self.transport.saved_exception = PartialAuthentication(authlist) + elif self.auth_method not in authlist: + self.transport._log(INFO, 'Authentication type (%s) not permitted.' % self.auth_method) + self.transport._log(DEBUG, 'Allowed methods: ' + str(authlist)) + self.transport.saved_exception = BadAuthenticationType('Bad authentication type', authlist) + else: + self.transport._log(INFO, 'Authentication (%s) failed.' % self.auth_method) + self.authenticated = False + self.username = None + if self.auth_event != None: + self.auth_event.set() + + def _parse_userauth_banner(self, m): + banner = m.get_string() + lang = m.get_string() + self.transport._log(INFO, 'Auth banner: ' + banner) + # who cares. + + def _parse_userauth_info_request(self, m): + if self.auth_method != 'keyboard-interactive': + raise SSHException('Illegal info request from server') + title = m.get_string() + instructions = m.get_string() + m.get_string() # lang + prompts = m.get_int() + prompt_list = [] + for i in range(prompts): + prompt_list.append((m.get_string(), m.get_boolean())) + response_list = self.interactive_handler(title, instructions, prompt_list) + + m = Message() + m.add_byte(chr(MSG_USERAUTH_INFO_RESPONSE)) + m.add_int(len(response_list)) + for r in response_list: + m.add_string(r) + self.transport._send_message(m) + + def _parse_userauth_info_response(self, m): + if not self.transport.server_mode: + raise SSHException('Illegal info response from server') + n = m.get_int() + responses = [] + for i in range(n): + responses.append(m.get_string()) + result = self.transport.server_object.check_auth_interactive_response(responses) + if isinstance(type(result), InteractiveQuery): + # make interactive query instead of response + self._interactive_query(result) + return + self._send_auth_result(self.auth_username, 'keyboard-interactive', result) + + + _handler_table = { + MSG_SERVICE_REQUEST: _parse_service_request, + MSG_SERVICE_ACCEPT: _parse_service_accept, + MSG_USERAUTH_REQUEST: _parse_userauth_request, + MSG_USERAUTH_SUCCESS: _parse_userauth_success, + MSG_USERAUTH_FAILURE: _parse_userauth_failure, + MSG_USERAUTH_BANNER: _parse_userauth_banner, + MSG_USERAUTH_INFO_REQUEST: _parse_userauth_info_request, + MSG_USERAUTH_INFO_RESPONSE: _parse_userauth_info_response, + } + + ============================================================ --- paramiko/compress.py e4a132074bc6be534e3b02de956e2f33e5d0cf9d +++ paramiko/compress.py e4a132074bc6be534e3b02de956e2f33e5d0cf9d @@ -0,0 +1,39 @@ +# Copyright (C) 2003-2006 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Compression implementations for a Transport. +""" + +import zlib + + +class ZlibCompressor (object): + def __init__(self): + self.z = zlib.compressobj(9) + + def __call__(self, data): + return self.z.compress(data) + self.z.flush(zlib.Z_FULL_FLUSH) + + +class ZlibDecompressor (object): + def __init__(self): + self.z = zlib.decompressobj() + + def __call__(self, data): + return self.z.decompress(data) ============================================================ --- paramiko/hostkeys.py 6f0ab19d1e3fca58cad0427f7a3351f25638dd6d +++ paramiko/hostkeys.py 6f0ab19d1e3fca58cad0427f7a3351f25638dd6d @@ -0,0 +1,188 @@ +# Copyright (C) 2006 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{HostKeys} +""" + +import base64 +from Crypto.Hash import SHA, HMAC +import UserDict + +from paramiko.common import * +from paramiko.dsskey import DSSKey +from paramiko.rsakey import RSAKey + + +class HostKeys (UserDict.DictMixin): + """ + Representation of an openssh-style "known hosts" file. Host keys can be + read from one or more files, and then individual hosts can be looked up to + verify server keys during SSH negotiation. + + A HostKeys object can be treated like a dict; any dict lookup is equivalent + to calling L{lookup}. + + @since: 1.5.3 + """ + + def __init__(self, filename=None): + """ + Create a new HostKeys object, optionally loading keys from an openssh + style host-key file. + + @param filename: filename to load host keys from, or C{None} + @type filename: str + """ + # hostname -> keytype -> PKey + self._keys = {} + self.contains_hashes = False + if filename is not None: + self.load(filename) + + def add(self, hostname, keytype, key): + """ + Add a host key entry to the table. Any existing entry for a + C{(hostname, keytype)} pair will be replaced. + + @param hostname: + @type hostname: str + @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"}) + @type keytype: str + @param key: the key to add + @type key: L{PKey} + """ + if not hostname in self._keys: + self._keys[hostname] = {} + if hostname.startswith('|1|'): + self.contains_hashes = True + self._keys[hostname][keytype] = key + + def load(self, filename): + """ + Read a file of known SSH host keys, in the format used by openssh. + This type of file unfortunately doesn't exist on Windows, but on + posix, it will usually be stored in + C{os.path.expanduser("~/.ssh/known_hosts")}. + + @param filename: name of the file to read host keys from + @type filename: str + """ + f = file(filename, 'r') + for line in f: + line = line.strip() + if (len(line) == 0) or (line[0] == '#'): + continue + keylist = line.split(' ') + if len(keylist) != 3: + # don't understand this line + continue + hostlist, keytype, key = keylist + for host in hostlist.split(','): + if keytype == 'ssh-rsa': + self.add(host, keytype, RSAKey(data=base64.decodestring(key))) + elif keytype == 'ssh-dss': + self.add(host, keytype, DSSKey(data=base64.decodestring(key))) + f.close() + + def lookup(self, hostname): + """ + Find a hostkey entry for a given hostname or IP. If no entry is found, + C{None} is returned. Otherwise a dictionary of keytype to key is + returned. + + @param hostname: the hostname to lookup + @type hostname: str + @return: keys associated with this host (or C{None}) + @rtype: dict(str, L{PKey}) + """ + if hostname in self._keys: + return self._keys[hostname] + if not self.contains_hashes: + return None + for h in self._keys.keys(): + if h.startswith('|1|'): + hmac = self.hash_host(hostname, h) + if hmac == h: + return self._keys[h] + return None + + def check(self, hostname, key): + """ + Return True if the given key is associated with the given hostname + in this dictionary. + + @param hostname: hostname (or IP) of the SSH server + @type hostname: str + @param key: the key to check + @type key: L{PKey} + @return: C{True} if the key is associated with the hostname; C{False} + if not + @rtype: bool + """ + k = self.lookup(hostname) + if k is None: + return False + host_key = k.get(key.get_name(), None) + if host_key is None: + return False + return str(host_key) == str(key) + + def clear(self): + """ + Remove all host keys from the dictionary. + """ + self._keys = {} + self.contains_hashes = False + + def __getitem__(self, key): + ret = self.lookup(key) + if ret is None: + raise KeyError(key) + return ret + + def keys(self): + return self._keys.keys() + + def values(self): + return self._keys.values(); + + def hash_host(hostname, salt=None): + """ + Return a "hashed" form of the hostname, as used by openssh when storing + hashed hostnames in the known_hosts file. + + @param hostname: the hostname to hash + @type hostname: str + @param salt: optional salt to use when hashing (must be 20 bytes long) + @type salt: str + @return: the hashed hostname + @rtype: str + """ + if salt is None: + salt = randpool.get_bytes(SHA.digest_size) + else: + if salt.startswith('|1|'): + salt = salt.split('|')[2] + salt = base64.decodestring(salt) + assert len(salt) == SHA.digest_size + hmac = HMAC.HMAC(salt, hostname, SHA).digest() + hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac)) + return hostkey.replace('\n', '') + hash_host = staticmethod(hash_host) + ============================================================ --- paramiko/__init__.py 5ac689af375a52cd8667f6f2c00242d9e321b78f +++ paramiko/__init__.py 6b9cd0faf4fe8fa477eb5ed62da535302915f870 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -27,8 +27,8 @@ services across an encrypted tunnel. (This is how C{sftp} works, for example.) To use this package, pass a socket (or socket-like object) to a L{Transport}, -and use L{start_server } or -L{start_client } to negoatite +and use L{start_server } or +L{start_client } to negoatite with the remote host as either a server or client. As a client, you are responsible for authenticating using a password or private key, and checking the server's host key. I{(Key signature and verification is done by paramiko, @@ -46,7 +46,7 @@ Website: U{http://www.lag.net/paramiko/} address@hidden: 1.4 (oddish) address@hidden: 1.5.4 (tentacool) @author: Robey Pointer @contact: address@hidden @license: GNU Lesser General Public License (LGPL) @@ -59,37 +59,41 @@ __author__ = "Robey Pointer " -__date__ = "18 Jul 2005" -__version__ = "1.4 (oddish)" +__date__ = "11 Mar 2005" +__version__ = "1.5.4 (tentacool)" +__version_info__ = (1, 5, 4) __license__ = "GNU Lesser General Public License (LGPL)" -import transport, auth_transport, channel, rsakey, dsskey, message -import ssh_exception, file, packet, agent, server, util -import sftp_client, sftp_attr, sftp_handle, sftp_server, sftp_si +from transport import randpool, SecurityOptions, Transport +from auth_handler import AuthHandler +from channel import Channel, ChannelFile +from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType +from server import ServerInterface, SubsystemHandler, InteractiveQuery +from rsakey import RSAKey +from dsskey import DSSKey +from sftp import SFTPError, BaseSFTP +from sftp_client import SFTP, SFTPClient +from sftp_server import SFTPServer +from sftp_attr import SFTPAttributes +from sftp_handle import SFTPHandle +from sftp_si import SFTPServerInterface +from sftp_file import SFTPFile +from message import Message +from packet import Packetizer +from file import BufferedFile +from agent import Agent, AgentKey +from pkey import PKey +from hostkeys import HostKeys -randpool = transport.randpool -Transport = auth_transport.Transport -Channel = channel.Channel -RSAKey = rsakey.RSAKey -DSSKey = dsskey.DSSKey -SSHException = ssh_exception.SSHException -Message = message.Message -PasswordRequiredException = ssh_exception.PasswordRequiredException -BadAuthenticationType = ssh_exception.BadAuthenticationType -SFTP = sftp_client.SFTP -SFTPClient = sftp_client.SFTPClient -SFTPServer = sftp_server.SFTPServer -from sftp import SFTPError -SFTPAttributes = sftp_attr.SFTPAttributes -SFTPHandle = sftp_handle.SFTPHandle -SFTPServerInterface = sftp_si.SFTPServerInterface -ServerInterface = server.ServerInterface -SubsystemHandler = server.SubsystemHandler -SecurityOptions = transport.SecurityOptions -BufferedFile = file.BufferedFile -Packetizer = packet.Packetizer -Agent = agent.Agent +# fix module names for epydoc +for x in [Transport, SecurityOptions, Channel, SFTPServer, SSHException, \ + PasswordRequiredException, BadAuthenticationType, ChannelFile, \ + SubsystemHandler, AuthHandler, RSAKey, DSSKey, SFTPError, \ + SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes, \ + SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey, \ + PKey, BaseSFTP, SFTPFile, ServerInterface, HostKeys]: + x.__module__ = 'paramiko' from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \ OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \ @@ -110,6 +114,7 @@ 'PasswordRequiredException', 'BadAuthenticationType', 'SFTP', + 'SFTPFile', 'SFTPHandle', 'SFTPClient', 'SFTPServer', @@ -119,22 +124,6 @@ 'ServerInterface', 'BufferedFile', 'Agent', - 'transport', - 'auth_transport', - 'channel', - 'rsakey', - 'dsskey', - 'pkey', - 'message', - 'ssh_exception', - 'sftp', - 'sftp_client', - 'sftp_server', - 'sftp_attr', - 'sftp_file', - 'sftp_si', - 'sftp_handle', - 'server', - 'file', - 'agent', + 'AgentKey', + 'HostKeys', 'util' ] ============================================================ --- paramiko/agent.py 74990e45a83b4f796cacfaa3fcb061c99c6e0952 +++ paramiko/agent.py 44065bf0813f3dbdf87dc35ad9c7e17ebcf58b1c @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 John Rochester +# Copyright (C) 2003-2006 John Rochester # # This file is part of paramiko. # @@ -20,15 +20,20 @@ SSH Agent interface for Unix clients. """ -import os, socket, struct +import os +import socket +import struct +import sys -from ssh_exception import SSHException -from message import Message -from pkey import PKey +from paramiko.ssh_exception import SSHException +from paramiko.message import Message +from paramiko.pkey import PKey + SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \ SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15) + class Agent: """ Client interface for using private keys from an SSH agent running on the @@ -50,12 +55,12 @@ @raise SSHException: if an SSH agent is found, but speaks an incompatible protocol """ - if 'SSH_AUTH_SOCK' in os.environ: + if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) conn.connect(os.environ['SSH_AUTH_SOCK']) self.conn = conn - type, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) - if type != SSH2_AGENT_IDENTITIES_ANSWER: + ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) + if ptype != SSH2_AGENT_IDENTITIES_ANSWER: raise SSHException('could not get keys from ssh-agent') keys = [] for i in range(result.get_int()): @@ -127,7 +132,7 @@ msg.add_string(self.blob) msg.add_string(data) msg.add_int(0) - type, result = self.agent._send_message(msg) - if type != SSH2_AGENT_SIGN_RESPONSE: + ptype, result = self.agent._send_message(msg) + if ptype != SSH2_AGENT_SIGN_RESPONSE: raise SSHException('key cannot be used for signing') return result.get_string() ============================================================ --- paramiko/ber.py b78e828bffc7f7dc34ff6528119b1478c64032d1 +++ paramiko/ber.py b55cb95a4951718a0294c5c5000488049fd43635 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -18,12 +16,14 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -import struct, util +import util + class BERException (Exception): pass + class BER(object): """ Robey's tiny little attempt at a BER decoder. @@ -91,8 +91,9 @@ while True: x = b.decode_next() if x is None: - return out + break out.append(x) + return out decode_sequence = staticmethod(decode_sequence) def encode_tlv(self, ident, val): ============================================================ --- paramiko/channel.py 0596f26cb50730f9de7c3ee14e36781cd41bf4a0 +++ paramiko/channel.py c5f45a7e75ce23f71733a51b848126f964ca6460 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -20,14 +20,18 @@ Abstraction for an SSH2 channel. """ -import sys, time, threading, socket, os +import sys +import time +import threading +import socket +import os -from common import * -import util -from message import Message -from ssh_exception import SSHException -from file import BufferedFile -import pipe +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException +from paramiko.file import BufferedFile +from paramiko import pipe class Channel (object): @@ -144,12 +148,7 @@ m.add_string('') self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + return self._wait_for_event() def invoke_shell(self): """ @@ -176,12 +175,7 @@ m.add_boolean(1) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + return self._wait_for_event() def exec_command(self, command): """ @@ -208,12 +202,7 @@ m.add_string(command) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + return self._wait_for_event() def invoke_subsystem(self, subsystem): """ @@ -239,12 +228,7 @@ m.add_string(subsystem) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + return self._wait_for_event() def resize_pty(self, width=80, height=24): """ @@ -270,12 +254,7 @@ m.add_int(0).add_int(0) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + self._wait_for_event() def recv_exit_status(self): """ @@ -292,8 +271,9 @@ """ while True: if self.closed or self.status_event.isSet(): - return self.exit_status + break self.status_event.wait(0.1) + return self.exit_status def send_exit_status(self, status): """ @@ -356,8 +336,6 @@ @return: the ID of this channel. @rtype: int - - @since: ivysaur """ return self.chanid @@ -453,6 +431,18 @@ else: self.settimeout(0.0) + def getpeername(self): + """ + Return the address of the remote side of this Channel, if possible. + This is just a wrapper around C{'getpeername'} on the Transport, used + to provide enough of a socket-like interface to allow asyncore to work. + (asyncore likes to call C{'getpeername'}.) + + @return: the address if the remote host, if known + @rtype: tuple(str, int) + """ + return self.transport.getpeername() + def close(self): """ Close the channel. All future read/write operations on the channel @@ -464,7 +454,7 @@ try: if not self.active or self.closed: return - self._close_internal() + msgs = self._close_internal() # only close the pipe when the user explicitly closes the channel. # otherwise they will get unpleasant surprises. @@ -473,6 +463,9 @@ self.pipe = None finally: self.lock.release() + for m in msgs: + if m is not None: + self.transport._send_user_message(m) def recv_ready(self): """ @@ -529,15 +522,24 @@ if len(self.in_buffer) <= nbytes: out = self.in_buffer self.in_buffer = '' - if self.pipe is not None: + if (self.pipe is not None) and not (self.closed or self.eof_received): # clear the pipe, since no more data is buffered self.pipe.clear() else: out = self.in_buffer[:nbytes] self.in_buffer = self.in_buffer[nbytes:] - self._check_add_window(len(out)) + ack = self._check_add_window(len(out)) finally: self.lock.release() + + # no need to hold the channel lock when sending this + if ack > 0: + m = Message() + m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) + m.add_int(self.remote_chanid) + m.add_int(ack) + self.transport._send_user_message(m) + return out def recv_stderr_ready(self): @@ -576,7 +578,7 @@ @rtype: str @raise socket.timeout: if no data is ready before the timeout set by - L{settimeout}. + L{settimeout}. @since: 1.1 """ @@ -624,7 +626,7 @@ @rtype: int @raise socket.timeout: if no data could be sent before the timeout set - by L{settimeout}. + by L{settimeout}. """ size = len(s) self.lock.acquire() @@ -657,7 +659,7 @@ @rtype: int @raise socket.timeout: if no data could be sent before the timeout set - by L{settimeout}. + by L{settimeout}. @since: 1.1 """ @@ -804,9 +806,11 @@ if (how == 1) or (how == 2): self.lock.acquire() try: - self._send_eof() + m = self._send_eof() finally: self.lock.release() + if m is not None: + self.transport._send_user_message(m) def shutdown_read(self): """ @@ -863,9 +867,12 @@ def _request_failed(self, m): self.lock.acquire() try: - self._close_internal() + msgs = self._close_internal() finally: self.lock.release() + for m in msgs: + if m is not None: + self.transport._send_user_message(m) def _feed(self, m): if type(m) is str: @@ -880,7 +887,7 @@ if self.pipe is not None: self.pipe.set() self.in_buffer += s - self.in_buffer_cv.notifyAll() + self.in_buffer_cv.notifyAll() finally: self.lock.release() @@ -983,7 +990,7 @@ self.in_buffer_cv.notifyAll() self.in_stderr_buffer_cv.notifyAll() if self.pipe is not None: - self.pipe.set() + self.pipe.set_forever() finally: self.lock.release() self._log(DEBUG, 'EOF received') @@ -991,10 +998,13 @@ def _handle_close(self, m): self.lock.acquire() try: - self._close_internal() + msgs = self._close_internal() self.transport._unlink_channel(self.chanid) finally: self.lock.release() + for m in msgs: + if m is not None: + self.transport._send_user_message(m) ### internals... @@ -1003,40 +1013,47 @@ def _log(self, level, msg): self.logger.log(level, msg) + def _wait_for_event(self): + while True: + self.event.wait(0.1) + if self.event.isSet(): + break + if self.closed: + return False + return True + def _set_closed(self): # you are holding the lock. self.closed = True self.in_buffer_cv.notifyAll() self.in_stderr_buffer_cv.notifyAll() self.out_buffer_cv.notifyAll() + if self.pipe is not None: + self.pipe.set_forever() def _send_eof(self): # you are holding the lock. if self.eof_sent: - return + return None m = Message() m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_int(self.remote_chanid) - self.transport._send_user_message(m) self.eof_sent = True self._log(DEBUG, 'EOF sent') - return + return m def _close_internal(self): # you are holding the lock. if not self.active or self.closed: - return - try: - self._send_eof() - m = Message() - m.add_byte(chr(MSG_CHANNEL_CLOSE)) - m.add_int(self.remote_chanid) - self.transport._send_user_message(m) - except EOFError: - pass + return None, None + m1 = self._send_eof() + m2 = Message() + m2.add_byte(chr(MSG_CHANNEL_CLOSE)) + m2.add_int(self.remote_chanid) self._set_closed() # can't unlink from the Transport yet -- the remote side may still # try to send meta-data (exit-status, etc) + return m1, m2 def _unlink(self): # server connection could die before we become active: still signal the close! @@ -1052,19 +1069,17 @@ def _check_add_window(self, n): # already holding the lock! if self.closed or self.eof_received or not self.active: - return + return 0 if self.ultra_debug: self._log(DEBUG, 'addwindow %d' % n) self.in_window_sofar += n - if self.in_window_sofar > self.in_window_threshold: - if self.ultra_debug: - self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar) - m = Message() - m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) - m.add_int(self.remote_chanid) - m.add_int(self.in_window_sofar) - self.transport._send_user_message(m) - self.in_window_sofar = 0 + if self.in_window_sofar <= self.in_window_threshold: + return 0 + if self.ultra_debug: + self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar) + out = self.in_window_sofar + self.in_window_sofar = 0 + return out def _wait_for_send_window(self, size): """ @@ -1105,8 +1120,6 @@ return size - - class ChannelFile (BufferedFile): """ A file-like wrapper around L{Channel}. A ChannelFile is created by calling @@ -1137,8 +1150,6 @@ def _write(self, data): self.channel.sendall(data) return len(data) - - seek = BufferedFile.seek class ChannelStderrFile (ChannelFile): ============================================================ --- paramiko/common.py 1518442b23a261749bc0eb223d7ebc2b54eaf3f1 +++ paramiko/common.py 734aa55bfaabc6b0902db4d432ccfb0dc22666b3 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -28,12 +26,14 @@ MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \ MSG_USERAUTH_BANNER = range(50, 54) MSG_USERAUTH_PK_OK = 60 +MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62) MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83) MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \ MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) + # for debugging: MSG_NAMES = { MSG_DISCONNECT: 'disconnect', @@ -53,7 +53,8 @@ MSG_USERAUTH_FAILURE: 'userauth-failure', MSG_USERAUTH_SUCCESS: 'userauth-success', MSG_USERAUTH_BANNER: 'userauth--banner', - MSG_USERAUTH_PK_OK: 'userauth-pk-ok', + MSG_USERAUTH_PK_OK: 'userauth-60(pk-ok/info-request)', + MSG_USERAUTH_INFO_RESPONSE: 'userauth-info-response', MSG_GLOBAL_REQUEST: 'global-request', MSG_REQUEST_SUCCESS: 'request-success', MSG_REQUEST_FAILURE: 'request-failure', @@ -70,19 +71,19 @@ MSG_CHANNEL_FAILURE: 'channel-failure' } -# authentication request return codes: +# authentication request return codes: AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3) # channel request failed reasons: - (OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE) = range(0, 5) + CONNECTION_FAILED_CODE = { 1: 'Administratively prohibited', 2: 'Connect failed', @@ -95,19 +96,22 @@ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14 - from Crypto.Util.randpool import PersistentRandomPool, RandomPool # keep a crypto-strong PRNG nearby +import os try: randpool = PersistentRandomPool(os.path.join(os.path.expanduser('~'), '/.randpool')) except: # the above will likely fail on Windows - fall back to non-persistent random pool randpool = RandomPool() -randpool.randomize() +try: + randpool.randomize() +except: + # earlier versions of pyCrypto (pre-2.0) don't have randomize() + pass - import sys if sys.version_info < (2, 3): try: @@ -126,6 +130,7 @@ import logging PY22 = False + DEBUG = logging.DEBUG INFO = logging.INFO WARNING = logging.WARNING ============================================================ --- paramiko/dsskey.py 92d4afa51af1c4a11fbdc52f1ac49516bd17a68f +++ paramiko/dsskey.py 96eafc3818ff93ec9b49432ad77b490baabd76e3 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -25,18 +23,25 @@ from Crypto.PublicKey import DSA from Crypto.Hash import SHA -from common import * -import util -from ssh_exception import SSHException -from message import Message -from ber import BER, BERException -from pkey import PKey +from paramiko.common import * +from paramiko import util +from paramiko.ssh_exception import SSHException +from paramiko.message import Message +from paramiko.ber import BER, BERException +from paramiko.pkey import PKey + class DSSKey (PKey): """ Representation of a DSS key which can be used to sign an verify SSH2 data. """ + + p = None + q = None + g = None + y = None + x = None def __init__(self, msg=None, data=None, filename=None, password=None, vals=None): if filename is not None: @@ -82,21 +87,28 @@ return self.size def can_sign(self): - return hasattr(self, 'x') + return self.x is not None def sign_ssh_data(self, rpool, data): digest = SHA.new(data).digest() dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q), long(self.x))) # generate a suitable k qsize = len(util.deflate_long(self.q, 0)) - while 1: + while True: k = util.inflate_long(rpool.get_bytes(qsize), 1) if (k > 2) and (k < self.q): break r, s = dss.sign(util.inflate_long(digest, 1), k) m = Message() m.add_string('ssh-dss') - m.add_string(util.deflate_long(r, 0) + util.deflate_long(s, 0)) + # apparently, in rare cases, r or s may be shorter than 20 bytes! + rstr = util.deflate_long(r, 0) + sstr = util.deflate_long(s, 0) + if len(rstr) < 20: + rstr = '\x00' * (20 - len(rstr)) + rstr + if len(sstr) < 20: + sstr = '\x00' * (20 - len(sstr)) + sstr + m.add_string(rstr + sstr) return m def verify_ssh_sig(self, data, msg): @@ -118,6 +130,8 @@ return dss.verify(sigM, (sigR, sigS)) def write_private_key_file(self, filename, password=None): + if self.x is None: + raise SSHException('Not enough key information') keylist = [ 0, self.p, self.q, self.g, self.y, self.x ] try: b = BER() @@ -134,13 +148,12 @@ @param bits: number of bits the generated key should be. @type bits: int @param progress_func: an optional function to call at key points in - key generation (used by C{pyCrypto.PublicKey}). + key generation (used by C{pyCrypto.PublicKey}). @type progress_func: function @return: new private key @rtype: L{DSSKey} - - @since: fearow """ + randpool.stir() dsa = DSA.generate(bits, randpool.get_bytes, progress_func) key = DSSKey(vals=(dsa.p, dsa.q, dsa.g, dsa.y)) key.x = dsa.x @@ -167,4 +180,3 @@ self.y = keylist[4] self.x = keylist[5] self.size = util.bit_length(self.p) - ============================================================ --- paramiko/file.py 036024e929348a5bce74525520e0c48465ee79a4 +++ paramiko/file.py d2478f2bd6e9643a4a54ae8ee0573400d5e96e11 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -24,6 +22,7 @@ from cStringIO import StringIO + _FLAG_READ = 0x1 _FLAG_WRITE = 0x2 _FLAG_APPEND = 0x4 @@ -32,6 +31,7 @@ _FLAG_LINE_BUFFERED = 0x40 _FLAG_UNIVERSAL_NEWLINE = 0x80 + class BufferedFile (object): """ Reusable base class to implement python-style file buffering around a @@ -45,6 +45,7 @@ SEEK_END = 2 def __init__(self): + self.newlines = None self._flags = 0 self._bufsize = self._DEFAULT_BUFSIZE self._wbuffer = StringIO() @@ -55,6 +56,8 @@ # realpos - position according the OS # (these may be different because we buffer for line reading) self._pos = self._realpos = 0 + # size only matters for seekable files + self._size = 0 def __del__(self): self.close() @@ -144,8 +147,11 @@ self._pos += len(result) return result while len(self._rbuffer) < size: + read_size = size - len(self._rbuffer) + if self._flags & _FLAG_BUFFERED: + read_size = max(self._bufsize, read_size) try: - new_data = self._read(max(self._bufsize, size - len(self._rbuffer))) + new_data = self._read(read_size) except EOFError: new_data = None if (new_data is None) or (len(new_data) == 0): @@ -202,7 +208,7 @@ return line n = size - len(line) else: - n = self._DEFAULT_BUFSIZE + n = self._bufsize if ('\n' in line) or ((self._flags & _FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): break try: @@ -321,7 +327,7 @@ return # even if we're line buffering, if the buffer has grown past the # buffer size, force a flush. - if len(self._wbuffer.getvalue()) >= self._bufsize: + if self._wbuffer.tell() >= self._bufsize: self.flush() return ============================================================ --- paramiko/kex_gex.py f6e75b3a0fe6e7f9cb7f5ef1a1ee2fe85c72a8c8 +++ paramiko/kex_gex.py 0e065c7cf5e983857f321004230c2992fb5b3f90 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -27,10 +25,10 @@ from Crypto.Hash import SHA from Crypto.Util import number -from common import * -from message import Message -import util -from ssh_exception import SSHException +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(31, 35) @@ -45,6 +43,12 @@ def __init__(self, transport): self.transport = transport + self.p = None + self.q = None + self.g = None + self.x = None + self.e = None + self.f = None def start_kex(self): if self.transport.server_mode: @@ -119,6 +123,7 @@ pack = self.transport._get_modulus_pack() if pack is None: raise SSHException('Can\'t do server-side gex with no modulus pack') + self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits)) self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) m = Message() m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) @@ -201,6 +206,3 @@ self.transport._set_K_H(K, SHA.new(str(hm)).digest()) self.transport._verify_key(host_key, sig) self.transport._activate_outbound() - - - ============================================================ --- paramiko/kex_group1.py e105620e8553e1920a2cad121992a11a1335cc68 +++ paramiko/kex_group1.py 927b0fcbf8ab634cba6ff363bea219e888012df9 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -23,11 +23,12 @@ from Crypto.Hash import SHA -from common import * -import util -from message import Message -from ssh_exception import SSHException +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException + _MSG_KEXDH_INIT, _MSG_KEXDH_REPLY = range(30, 32) # draft-ietf-secsh-transport-09.txt, page 17 ============================================================ --- paramiko/logging22.py d2073500a0e5fcebcd51a1abaf3d64363d7fc8bd +++ paramiko/logging22.py a5f736263bc246fd839855952a9c7d1b6fbbccdc @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -22,15 +20,18 @@ Stub out logging on python < 2.3. """ + DEBUG = 10 INFO = 20 WARNING = 30 ERROR = 40 CRITICAL = 50 + def getLogger(name): return _logger + class logger (object): def __init__(self): self.handlers = [ ] @@ -42,6 +43,9 @@ def addHandler(self, h): self.handlers.append(h) + def addFilter(self, filter): + pass + def log(self, level, text): if level >= self.level: for h in self.handlers: ============================================================ --- paramiko/message.py 677e83f4c16648442cf78ae84464bdd60e528033 +++ paramiko/message.py 5d9ed3d6fb92eb3e0968b9be7b114a7d7771060e @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -20,15 +20,21 @@ Implementation of an SSH2 "message". """ -import struct, cStringIO -import util +import struct +import cStringIO +from paramiko import util + class Message (object): """ An SSH2 I{Message} is a stream of bytes that encodes some combination of strings, integers, bools, and infinite-precision integers (known in python as I{long}s). This class builds or breaks down such a byte stream. + + Normally you don't need to deal with anything this low-level, but it's + exposed for people implementing custom extensions, or features that + paramiko doesn't support yet. """ def __init__(self, content=None): @@ -178,14 +184,32 @@ return self.get_string().split(',') def add_bytes(self, b): + """ + Write bytes to the stream, without any formatting. + + @param b: bytes to add + @type b: str + """ self.packet.write(b) return self def add_byte(self, b): + """ + Write a single byte to the stream, without any formatting. + + @param b: byte to add + @type b: str + """ self.packet.write(b) return self def add_boolean(self, b): + """ + Add a boolean value to the stream. + + @param b: boolean value to add + @type b: bool + """ if b: self.add_byte('\x01') else: @@ -193,6 +217,12 @@ return self def add_int(self, n): + """ + Add an integer to the stream. + + @param n: integer to add + @type n: int + """ self.packet.write(struct.pack('>I', n)) return self @@ -200,23 +230,43 @@ """ Add a 64-bit int to the stream. - @param n: long int to add. + @param n: long int to add @type n: long """ self.packet.write(struct.pack('>Q', n)) return self def add_mpint(self, z): - "this only works on positive numbers" + """ + Add a long int to the stream, encoded as an infinite-precision + integer. This method only works on positive numbers. + + @param z: long int to add + @type z: long + """ self.add_string(util.deflate_long(z)) return self def add_string(self, s): + """ + Add a string to the stream. + + @param s: string to add + @type s: str + """ self.add_int(len(s)) self.packet.write(s) return self def add_list(self, l): + """ + Add a list of strings to the stream. They are encoded identically to + a single string of values separated by commas. (Yes, really, that's + how SSH2 does it.) + + @param l: list of strings to add + @type l: list(str) + """ self.add_string(','.join(l)) return self @@ -235,8 +285,17 @@ elif type(i) is list: return self.add_list(i) else: - raise exception('Unknown type') + raise Exception('Unknown type') def add(self, *seq): + """ + Add a sequence of items to the stream. The values are encoded based + on their type: str, int, bool, list, or long. + + @param seq: the sequence of items + @type seq: sequence + + @bug: longs are encoded non-deterministically. Don't use this method. + """ for item in seq: self._add(item) ============================================================ --- paramiko/packet.py 92dd579d5f50b35110c71e86311bbf172c6804f1 +++ paramiko/packet.py 370e998149a0518964582af9357936c51c84c685 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -20,14 +20,35 @@ Packetizer. """ -import select, socket, struct, threading, time -from Crypto.Hash import HMAC -from common import * -from ssh_exception import SSHException -from message import Message -import util +import select +import socket +import struct +import threading +import time +from paramiko.common import * +from paramiko import util +from paramiko.ssh_exception import SSHException +from paramiko.message import Message + +got_r_hmac = False +try: + import r_hmac + got_r_hmac = True +except ImportError: + pass +def compute_hmac(key, message, digest_class): + if got_r_hmac: + return r_hmac.HMAC(key, message, digest_class).digest() + from Crypto.Hash import HMAC + return HMAC.HMAC(key, message, digest_class).digest() + + +class NeedRekeyException (Exception): + pass + + class Packetizer (object): """ Implementation of the base SSH packet protocol. @@ -45,6 +66,7 @@ self.__dump_packets = False self.__need_rekey = False self.__init_count = 0 + self.__remainder = '' # used for noticing when to re-key: self.__sent_bytes = 0 @@ -52,7 +74,7 @@ self.__received_bytes = 0 self.__received_packets = 0 self.__received_packets_overflow = 0 - + # current inbound/outbound ciphering: self.__block_size_out = 8 self.__block_size_in = 8 @@ -64,6 +86,8 @@ self.__mac_engine_in = None self.__mac_key_out = '' self.__mac_key_in = '' + self.__compress_engine_out = None + self.__compress_engine_in = None self.__sequence_number_out = 0L self.__sequence_number_in = 0L @@ -75,13 +99,6 @@ self.__keepalive_last = time.time() self.__keepalive_callback = None - def __del__(self): - # this is not guaranteed to be called, but we should try. - try: - self.__socket.close() - except: - pass - def set_log(self, log): """ Set the python log object to use for logging. @@ -123,6 +140,12 @@ self.__init_count = 0 self.__need_rekey = False + def set_outbound_compressor(self, compressor): + self.__compress_engine_out = compressor + + def set_inbound_compressor(self, compressor): + self.__compress_engine_in = compressor + def close(self): self.__closed = True @@ -158,7 +181,7 @@ self.__keepalive_callback = callback self.__keepalive_last = time.time() - def read_all(self, n): + def read_all(self, n, check_rekey=False): """ Read as close to N bytes as possible, blocking as long as necessary. @@ -169,9 +192,14 @@ @raise EOFError: if the socket was closed before all the bytes could be read """ + out = '' + # handle over-reading from reading the banner line + if len(self.__remainder) > 0: + out = self.__remainder[:n] + self.__remainder = self.__remainder[n:] + n -= len(out) if PY22: - return self._py22_read_all(n) - out = '' + return self._py22_read_all(n, out) while n > 0: try: x = self.__socket.recv(n) @@ -182,6 +210,8 @@ except socket.timeout: if self.__closed: raise EOFError() + if check_rekey and (len(out) == 0) and self.__need_rekey: + raise NeedRekeyException() self._check_keepalive() return out @@ -200,20 +230,21 @@ if n < 0: raise EOFError() if n == len(out): - return + break out = out[n:] return def readline(self, timeout): """ - Read a line from the socket. This is done in a fairly inefficient - way, but is only used for initial banner negotiation so it's not worth - optimising. + Read a line from the socket. We assume no data is pending after the + line, so it's okay to attempt large reads. """ buf = '' while not '\n' in buf: buf += self._read_timeout(timeout) - buf = buf[:-1] + n = buf.index('\n') + self.__remainder += buf[n+1:] + buf = buf[:n] if (len(buf) > 0) and (buf[-1] == '\r'): buf = buf[:-1] return buf @@ -229,9 +260,12 @@ cmd_name = MSG_NAMES[cmd] else: cmd_name = '$%x' % cmd - self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data))) + orig_len = len(data) + if self.__compress_engine_out is not None: + data = self.__compress_engine_out(data) packet = self._build_packet(data) if self.__dump_packets: + self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, orig_len)) self._log(DEBUG, util.format_binary(packet, 'OUT: ')) self.__write_lock.acquire() try: @@ -242,12 +276,15 @@ # + mac if self.__block_engine_out != None: payload = struct.pack('>I', self.__sequence_number_out) + packet - out += HMAC.HMAC(self.__mac_key_out, payload, self.__mac_engine_out).digest()[:self.__mac_size_out] + out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out] self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL self.write_all(out) self.__sent_bytes += len(out) self.__sent_packets += 1 + if (self.__sent_packets % 100) == 0: + # stirring the randpool takes 30ms on my ibook!! + randpool.stir() if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \ and not self.__need_rekey: # only ask once for rekeying @@ -264,8 +301,9 @@ done). @raise SSHException: if the packet is mangled + @raise NeedRekeyException: if the transport should rekey """ - header = self.read_all(self.__block_size_in) + header = self.read_all(self.__block_size_in, check_rekey=True) if self.__block_engine_in != None: header = self.__block_engine_in.decrypt(header) if self.__dump_packets: @@ -287,15 +325,18 @@ if self.__mac_size_in > 0: mac = post_packet[:self.__mac_size_in] mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet - my_mac = HMAC.HMAC(self.__mac_key_in, mac_payload, self.__mac_engine_in).digest()[:self.__mac_size_in] + my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in] if my_mac != mac: raise SSHException('Mismatched MAC') padding = ord(packet[0]) - payload = packet[1:packet_size - padding + 1] - randpool.add_event(packet[packet_size - padding + 1]) + payload = packet[1:packet_size - padding] + randpool.add_event() if self.__dump_packets: self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) + if self.__compress_engine_in is not None: + payload = self.__compress_engine_in(payload) + msg = Message(payload[1:]) msg.seqno = self.__sequence_number_in self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL @@ -322,7 +363,8 @@ cmd_name = MSG_NAMES[cmd] else: cmd_name = '$%x' % cmd - self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) + if self.__dump_packets: + self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) return cmd, msg @@ -348,8 +390,7 @@ self.__keepalive_callback() self.__keepalive_last = now - def _py22_read_all(self, n): - out = '' + def _py22_read_all(self, n, out): while n > 0: r, w, e = select.select([self.__socket], [], [], 0.1) if self.__socket not in r: @@ -372,23 +413,24 @@ x = self.__socket.recv(1) if len(x) == 0: raise EOFError() - return x + break if self.__closed: raise EOFError() now = time.time() if now - start >= timeout: raise socket.timeout() + return x def _read_timeout(self, timeout): if PY22: - return self._py22_read_timeout(n) + return self._py22_read_timeout(timeout) start = time.time() while True: try: - x = self.__socket.recv(1) + x = self.__socket.recv(128) if len(x) == 0: raise EOFError() - return x + break except socket.timeout: pass if self.__closed: @@ -396,6 +438,7 @@ now = time.time() if now - start >= timeout: raise socket.timeout() + return x def _build_packet(self, payload): # pad up at least 4 bytes, to nearest block-size (usually 8) ============================================================ --- paramiko/pipe.py 78054bd0d865e230464bd781abe2617c2f2c7136 +++ paramiko/pipe.py 07a33dcad6057372ed4e9be5743b81938753f390 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -28,14 +28,17 @@ def make_pipe (): if sys.platform[:3] != 'win': - return PosixPipe() - return WindowsPipe() + p = PosixPipe() + else: + p = WindowsPipe() + return p class PosixPipe (object): def __init__ (self): self._rfd, self._wfd = os.pipe() self._set = False + self._forever = False def close (self): os.close(self._rfd) @@ -45,7 +48,7 @@ return self._rfd def clear (self): - if not self._set: + if not self._set or self._forever: return os.read(self._rfd, 1) self._set = False @@ -55,6 +58,10 @@ return self._set = True os.write(self._wfd, '*') + + def set_forever (self): + self._forever = True + self.set() class WindowsPipe (object): @@ -67,13 +74,14 @@ serv.bind(('127.0.0.1', 0)) serv.listen(1) - # need to save sockets in pipe_rsock/pipe_wsock so they don't get closed + # need to save sockets in _rsock/_wsock so they don't get closed self._rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._rsock.connect(('127.0.0.1', serv.getsockname()[1])) self._wsock, addr = serv.accept() serv.close() self._set = False + self._forever = False def close (self): self._rsock.close() @@ -83,7 +91,7 @@ return self._rsock.fileno() def clear (self): - if not self._set: + if not self._set or self._forever: return self._rsock.recv(1) self._set = False @@ -93,3 +101,7 @@ return self._set = True self._wsock.send('*') + + def set_forever (self): + self._forever = True + self.set() ============================================================ --- paramiko/pkey.py 1c7280154144125bf970c9fabf4eead4fdddf0e3 +++ paramiko/pkey.py e6feb29295376dbe3a2a58ef37f1d60be4f0eb99 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -20,15 +20,16 @@ Common API for all public keys. """ -import os, base64 +import base64 +import os from Crypto.Hash import MD5 from Crypto.Cipher import DES3 -from common import * -from message import Message -from ssh_exception import SSHException, PasswordRequiredException -import util +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException, PasswordRequiredException class PKey (object): @@ -138,8 +139,6 @@ @return: a base64 string containing the public part of the key. @rtype: str - - @since: fearow """ return base64.encodestring(str(self)).replace('\n', '') @@ -172,7 +171,7 @@ """ return False - def from_private_key_file(cl, filename, password=None): + def from_private_key_file(cls, filename, password=None): """ Create a key object by reading a private key file. If the private key is encrypted and C{password} is not C{None}, the given password @@ -193,10 +192,8 @@ @raise PasswordRequiredException: if the private key file is encrypted, and C{password} is C{None}. @raise SSHException: if the key file is invalid. - - @since: fearow """ - key = cl(filename=filename, password=password) + key = cls(filename=filename, password=password) return key from_private_key_file = classmethod(from_private_key_file) @@ -212,10 +209,8 @@ @raise IOError: if there was an error writing the file. @raise SSHException: if the key is invalid. - - @since: fearow """ - raise exception('Not implemented in PKey') + raise Exception('Not implemented in PKey') def _read_private_key_file(self, tag, filename, password=None): """ @@ -264,7 +259,7 @@ # if we trudged to the end of the file, just try to cope. try: data = base64.decodestring(''.join(lines[start:end])) - except binascii.Error, e: + except base64.binascii.Error, e: raise SSHException('base64 decoding error: ' + str(e)) if not headers.has_key('proc-type'): # unencryped: done ============================================================ --- paramiko/primes.py aa1fc3675f3e154678e0c961dda8a3fee7bd9a82 +++ paramiko/primes.py 1b1bd9084fef5f567bad5a4de971e4a6e2be5d86 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -21,9 +21,11 @@ """ from Crypto.Util import number -import util +from paramiko import util +from paramiko.ssh_exception import SSHException + def _generate_prime(bits, randpool): "primtive attempt at prime generation" hbyte_mask = pow(2, bits % 8) - 1 @@ -38,7 +40,8 @@ while not number.isPrime(n): n += 2 if util.bit_length(n) == bits: - return n + break + return n def _roll_random(rpool, n): "returns a random # from 0 to N-1" @@ -58,7 +61,8 @@ x = chr(ord(x[0]) & hbyte_mask) + x[1:] num = util.inflate_long(x, 1) if num < n: - return num + break + return num class ModulusPack (object): @@ -74,8 +78,8 @@ self.randpool = rpool def _parse_modulus(self, line): - timestamp, type, tests, tries, size, generator, modulus = line.split() - type = int(type) + timestamp, mod_type, tests, tries, size, generator, modulus = line.split() + mod_type = int(mod_type) tests = int(tests) tries = int(tries) size = int(size) @@ -86,7 +90,7 @@ # type 2 (meets basic structural requirements) # test 4 (more than just a small-prime sieve) # tries < 100 if test & 4 (at least 100 tries of miller-rabin) - if (type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)): + if (mod_type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)): self.discarded.append((modulus, 'does not meet basic requirements')) return if generator == 0: @@ -145,4 +149,3 @@ # now pick a random modulus of this bitsize n = _roll_random(self.randpool, len(self.pack[good])) return self.pack[good][n] - ============================================================ --- paramiko/rsakey.py fdb464f41b215adbeee620603ccc1da76cbabaad +++ paramiko/rsakey.py 747c8f53df69713d965fdf1e92ad3e75bc73c606 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -24,19 +24,26 @@ from Crypto.Hash import SHA, MD5 from Crypto.Cipher import DES3 -from common import * -from message import Message -from ber import BER, BERException -import util -from pkey import PKey -from ssh_exception import SSHException +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ber import BER, BERException +from paramiko.pkey import PKey +from paramiko.ssh_exception import SSHException + class RSAKey (PKey): """ Representation of an RSA key which can be used to sign and verify SSH2 data. """ + n = None + e = None + d = None + p = None + q = None + def __init__(self, msg=None, data=None, filename=None, password=None, vals=None): if filename is not None: self._from_private_key_file(filename, password) @@ -74,7 +81,7 @@ return self.size def can_sign(self): - return hasattr(self, 'd') + return self.d is not None def sign_ssh_data(self, rpool, data): digest = SHA.new(data).digest() @@ -89,14 +96,16 @@ if msg.get_string() != 'ssh-rsa': return False sig = util.inflate_long(msg.get_string(), True) - # verify the signature by SHA'ing the data and encrypting it using theŒ + # verify the signature by SHA'ing the data and encrypting it using the # public key. some wackiness ensues where we "pkcs1imify" the 20-byte # hash into a string as long as the RSA key. - hash = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True) + hash_obj = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True) rsa = RSA.construct((long(self.n), long(self.e))) - return rsa.verify(hash, (sig,)) + return rsa.verify(hash_obj, (sig,)) def write_private_key_file(self, filename, password=None): + if (self.p is None) or (self.q is None): + raise SSHException('Not enough key info to write private key file') keylist = [ 0, self.n, self.e, self.d, self.p, self.q, self.d % (self.p - 1), self.d % (self.q - 1), util.mod_inverse(self.q, self.p) ] @@ -119,9 +128,8 @@ @type progress_func: function @return: new private key @rtype: L{RSAKey} - - @since: fearow """ + randpool.stir() rsa = RSA.generate(bits, randpool.get_bytes, progress_func) key = RSAKey(vals=(rsa.e, rsa.n)) key.d = rsa.d @@ -161,4 +169,3 @@ self.p = keylist[4] self.q = keylist[5] self.size = util.bit_length(self.n) - ============================================================ --- paramiko/server.py 684ff520b01acecaa50cba1e770539c889680db8 +++ paramiko/server.py 6027d5d22d1f043130e894e82251f2cd0d4b64be @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -23,11 +21,50 @@ """ import threading -from common import * -import util -from transport import BaseTransport -from auth_transport import Transport +from paramiko.common import * +from paramiko import util + +class InteractiveQuery (object): + """ + A query (set of prompts) for a user during interactive authentication. + """ + + def __init__(self, name='', instructions='', *prompts): + """ + Create a new interactive query to send to the client. The name and + instructions are optional, but are generally displayed to the end + user. A list of prompts may be included, or they may be added via + the L{add_prompt} method. + + @param name: name of this query + @type name: str + @param instructions: user instructions (usually short) about this query + @type instructions: str + """ + self.name = name + self.instructions = instructions + self.prompts = [] + for x in prompts: + if (type(x) is str) or (type(x) is unicode): + self.add_prompt(x) + else: + self.add_prompt(x[0], x[1]) + + def add_prompt(self, prompt, echo=True): + """ + Add a prompt to this query. The prompt should be a (reasonably short) + string. Multiple prompts can be added to the same query. + + @param prompt: the user prompt + @type prompt: str + @param echo: C{True} (default) if the user's response should be echoed; + C{False} if not (for a password or similar) + @type echo: bool + """ + self.prompts.append((prompt, echo)) + + class ServerInterface (object): """ This class defines an interface for controlling the behavior of paramiko @@ -156,7 +193,7 @@ Return L{AUTH_FAILED} if the key is not accepted, L{AUTH_SUCCESSFUL} if the key is accepted and completes the authentication, or L{AUTH_PARTIALLY_SUCCESSFUL} if your - authentication is stateful, and this key is accepted for + authentication is stateful, and this password is accepted for authentication, but more authentication is required. (In this latter case, L{get_allowed_auths} will be called to report to the client what options it has for continuing the authentication.) @@ -167,18 +204,75 @@ The default implementation always returns L{AUTH_FAILED}. - @param username: the username of the authenticating client. + @param username: the username of the authenticating client @type username: str - @param key: the key object provided by the client. + @param key: the key object provided by the client @type key: L{PKey } @return: L{AUTH_FAILED} if the client can't authenticate with this key; L{AUTH_SUCCESSFUL} if it can; L{AUTH_PARTIALLY_SUCCESSFUL} if it can authenticate with - this key but must continue with authentication. + this key but must continue with authentication @rtype: int """ return AUTH_FAILED + + def check_auth_interactive(self, username, submethods): + """ + Begin an interactive authentication challenge, if supported. You + should override this method in server mode if you want to support the + C{"keyboard-interactive"} auth type, which requires you to send a + series of questions for the client to answer. + Return L{AUTH_FAILED} if this auth method isn't supported. Otherwise, + you should return an L{InteractiveQuery} object containing the prompts + and instructions for the user. The response will be sent via a call + to L{check_auth_interactive_response}. + + The default implementation always returns L{AUTH_FAILED}. + + @param username: the username of the authenticating client + @type username: str + @param submethods: a comma-separated list of methods preferred by the + client (usually empty) + @type submethods: str + @return: L{AUTH_FAILED} if this auth method isn't supported; otherwise + an object containing queries for the user + @rtype: int or L{InteractiveQuery} + """ + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + """ + Continue or finish an interactive authentication challenge, if + supported. You should override this method in server mode if you want + to support the C{"keyboard-interactive"} auth type. + + Return L{AUTH_FAILED} if the responses are not accepted, + L{AUTH_SUCCESSFUL} if the responses are accepted and complete + the authentication, or L{AUTH_PARTIALLY_SUCCESSFUL} if your + authentication is stateful, and this set of responses is accepted for + authentication, but more authentication is required. (In this latter + case, L{get_allowed_auths} will be called to report to the client what + options it has for continuing the authentication.) + + If you wish to continue interactive authentication with more questions, + you may return an L{InteractiveQuery} object, which should cause the + client to respond with more answers, calling this method again. This + cycle can continue indefinitely. + + The default implementation always returns L{AUTH_FAILED}. + + @param responses: list of responses from the client + @type responses: list(str) + @return: L{AUTH_FAILED} if the authentication fails; + L{AUTH_SUCCESSFUL} if it succeeds; + L{AUTH_PARTIALLY_SUCCESSFUL} if the interactive auth is + successful, but authentication must continue; otherwise an object + containing queries for the user + @rtype: int or L{InteractiveQuery} + """ + return AUTH_FAILED + def check_global_request(self, kind, msg): """ Handle a global request of the given C{kind}. This method is called @@ -286,7 +380,7 @@ subsystem. An example of a subsystem is C{sftp}. The default implementation checks for a subsystem handler assigned via - L{Transport.set_subsystem_handler }. + L{Transport.set_subsystem_handler}. If one has been set, the handler is invoked and this method returns C{True}. Otherwise it returns C{False}. @@ -338,7 +432,7 @@ """ Handler for a subsytem in server mode. If you create a subclass of this class and pass it to - L{Transport.set_subsystem_handler }, + L{Transport.set_subsystem_handler}, an object of this class will be created for each request for this subsystem. Each new object will be executed within its own new thread by calling L{start_subsystem}. @@ -349,8 +443,6 @@ authenticated and requests subsytem C{"mp3"}, an object of class C{MP3Handler} will be created, and L{start_subsystem} will be called on it from a new thread. - - @since: ivysaur """ def __init__(self, channel, name, server): """ @@ -409,7 +501,7 @@ @note: It is the responsibility of this method to exit if the underlying L{Transport} is closed. This can be done by checking - L{Transport.is_active } or noticing an EOF + L{Transport.is_active} or noticing an EOF on the L{Channel}. If this method loops forever without checking for this case, your python interpreter may refuse to exit because this thread will still be running. ============================================================ --- paramiko/sftp.py 6c58bfe7730954a3ddf680b5b463e8c0fff295c7 +++ paramiko/sftp.py f1b11aa12a263d400f803a759e3f77f4bdbe47d7 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -16,12 +16,15 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -import struct, socket -from common import * -import util -from channel import Channel -from message import Message +import socket +import struct +from paramiko.common import * +from paramiko import util +from paramiko.channel import Channel +from paramiko.message import Message + + CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT, CMD_FSTAT, \ CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, CMD_REMOVE, CMD_MKDIR, \ CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, CMD_READLINK, CMD_SYMLINK \ @@ -110,16 +113,18 @@ return version def _send_server_version(self): + # winscp will freak out if the server sends version info before the + # client finishes sending INIT. + t, data = self._read_packet() + if t != CMD_INIT: + raise SFTPError('Incompatible sftp protocol') + version = struct.unpack('>I', data[:4])[0] # advertise that we support "check-file" extension_pairs = [ 'check-file', 'md5,sha1' ] msg = Message() msg.add_int(_VERSION) msg.add(*extension_pairs) self._send_packet(CMD_VERSION, str(msg)) - t, data = self._read_packet() - if t != CMD_INIT: - raise SFTPError('Incompatible sftp protocol') - version = struct.unpack('>I', data[:4])[0] return version def _log(self, level, msg): ============================================================ --- paramiko/sftp_attr.py 835dbd7baed2d3403bf555de6c2ea833a246c6f3 +++ paramiko/sftp_attr.py b78aaf2565868de77fde605dbabecce8091e2230 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -18,9 +16,10 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -import stat, time -from common import * -from sftp import * +import stat +import time +from paramiko.common import * +from paramiko.sftp import * class SFTPAttributes (object): @@ -52,6 +51,12 @@ Create a new (empty) SFTPAttributes object. All fields will be empty. """ self._flags = 0 + self.st_size = None + self.st_uid = None + self.st_gid = None + self.st_mode = None + self.st_atime = None + self.st_mtime = None self.attr = {} def from_stat(cls, obj, filename=None): @@ -115,13 +120,13 @@ def _pack(self, msg): self._flags = 0 - if hasattr(self, 'st_size'): + if self.st_size is not None: self._flags |= self.FLAG_SIZE - if hasattr(self, 'st_uid') or hasattr(self, 'st_gid'): + if (self.st_uid is not None) and (self.st_gid is not None): self._flags |= self.FLAG_UIDGID - if hasattr(self, 'st_mode'): + if self.st_mode is not None: self._flags |= self.FLAG_PERMISSIONS - if hasattr(self, 'st_atime') or hasattr(self, 'st_mtime'): + if (self.st_atime is not None) and (self.st_mtime is not None): self._flags |= self.FLAG_AMTIME if len(self.attr) > 0: self._flags |= self.FLAG_EXTENDED @@ -129,13 +134,13 @@ if self._flags & self.FLAG_SIZE: msg.add_int64(self.st_size) if self._flags & self.FLAG_UIDGID: - msg.add_int(getattr(self, 'st_uid', 0)) - msg.add_int(getattr(self, 'st_gid', 0)) + msg.add_int(self.st_uid) + msg.add_int(self.st_gid) if self._flags & self.FLAG_PERMISSIONS: msg.add_int(self.st_mode) if self._flags & self.FLAG_AMTIME: - msg.add_int(getattr(self, 'st_atime', 0)) - msg.add_int(getattr(self, 'st_mtime', 0)) + msg.add_int(self.st_atime) + msg.add_int(self.st_mtime) if self._flags & self.FLAG_EXTENDED: msg.add_int(len(self.attr)) for key, val in self.attr.iteritems(): @@ -145,15 +150,14 @@ def _debug_str(self): out = '[ ' - if hasattr(self, 'st_size'): + if self.st_size is not None: out += 'size=%d ' % self.st_size - if hasattr(self, 'st_uid') or hasattr(self, 'st_gid'): - out += 'uid=%d gid=%d ' % (getattr(self, 'st_uid', 0), getattr(self, 'st_gid', 0)) - if hasattr(self, 'st_mode'): + if (self.st_uid is not None) and (self.st_gid is not None): + out += 'uid=%d gid=%d ' % (self.st_uid, self.st_gid) + if self.st_mode is not None: out += 'mode=' + oct(self.st_mode) + ' ' - if hasattr(self, 'st_atime') or hasattr(self, 'st_mtime'): - out += 'atime=%d mtime=%d ' % (getattr(self, 'st_atime', 0), - getattr(self, 'st_mtime', 0)) + if (self.st_atime is not None) and (self.st_mtime is not None): + out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime) for k, v in self.attr.iteritems(): out += '"%s"=%r ' % (str(k), v) out += ']' @@ -172,8 +176,8 @@ def __str__(self): "create a unix-style long description of the file (like ls -l)" - if hasattr(self, 'permissions'): - kind = self.permissions & stat.S_IFMT + if self.st_mode is not None: + kind = stat.S_IFMT(self.st_mode) if kind == stat.S_IFIFO: ks = 'p' elif kind == stat.S_IFCHR: @@ -190,20 +194,21 @@ ks = 's' else: ks = '?' - ks += _rwx((self.permissions & 0700) >> 6, self.permissions & stat.S_ISUID) - ks += _rwx((self.permissions & 070) >> 3, self.permissions & stat.S_ISGID) - ks += _rwx(self.permissions & 7, self.permissions & stat.S_ISVTX, True) + ks += self._rwx((self.st_mode & 0700) >> 6, self.st_mode & stat.S_ISUID) + ks += self._rwx((self.st_mode & 070) >> 3, self.st_mode & stat.S_ISGID) + ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True) else: ks = '?---------' - uid = getattr(self, 'uid', -1) - gid = getattr(self, 'gid', -1) - size = getattr(self, 'size', -1) - mtime = getattr(self, 'mtime', 0) # compute display date - if abs(time.time() - mtime) > 15552000: - # (15552000 = 6 months) - datestr = time.strftime('%d %b %Y', time.localtime(mtime)) + if self.st_mtime is None: + # shouldn't really happen + datestr = '(unknown date)' else: - datestr = time.strftime('%d %b %H:%M', time.localtime(mtime)) + if abs(time.time() - self.st_mtime) > 15552000: + # (15552000 = 6 months) + datestr = time.strftime('%d %b %Y', time.localtime(self.st_mtime)) + else: + datestr = time.strftime('%d %b %H:%M', time.localtime(self.st_mtime)) filename = getattr(self, 'filename', '?') + return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, self.st_uid, self.st_gid, + self.st_size, datestr, filename) - return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, size, datestr, filename) ============================================================ --- paramiko/sftp_client.py 5bd2a4fe7cfc525df863b12e4dd585afd4d34bb9 +++ paramiko/sftp_client.py bcd91938416220067071abcec12e5b54c1278208 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -20,18 +20,29 @@ Client-mode SFTP support. """ +import errno import os -from sftp import * -from sftp_attr import SFTPAttributes -from sftp_file import SFTPFile +import threading +import time +import weakref +from paramiko.sftp import * +from paramiko.sftp_attr import SFTPAttributes +from paramiko.sftp_file import SFTPFile def _to_unicode(s): - "if a str is not ascii, decode its utf8 into unicode" + """ + decode a string as ascii or utf8 if possible (as required by the sftp + protocol). if neither works, just return a byte string because the server + probably doesn't know the filename's encoding. + """ try: return s.encode('ascii') - except: - return s.decode('utf-8') + except UnicodeError: + try: + return s.decode('utf-8') + except UnicodeError: + return s class SFTPClient (BaseSFTP): @@ -48,33 +59,35 @@ An alternate way to create an SFTP client context is by using L{from_transport}. - @param sock: an open L{Channel} using the C{"sftp"} subsystem. + @param sock: an open L{Channel} using the C{"sftp"} subsystem @type sock: L{Channel} """ BaseSFTP.__init__(self) self.sock = sock self.ultra_debug = False self.request_number = 1 + # lock for request_number + self._lock = threading.Lock() self._cwd = None + # request # -> SFTPFile + self._expecting = weakref.WeakValueDictionary() if type(sock) is Channel: # override default logger transport = self.sock.get_transport() self.logger = util.get_logger(transport.get_log_channel() + '.' + self.sock.get_name() + '.sftp') self.ultra_debug = transport.get_hexdump() - self._send_version() - - def __del__(self): - self.close() + server_version = self._send_version() + self._log(INFO, 'Opened sftp connection (server version %d)' % server_version) - def from_transport(selfclass, t): + def from_transport(cls, t): """ Create an SFTP client channel from an open L{Transport}. - @param t: an open L{Transport} which is already authenticated. + @param t: an open L{Transport} which is already authenticated @type t: L{Transport} @return: a new L{SFTPClient} object, referring to an sftp session - (channel) across the transport. + (channel) across the transport @rtype: L{SFTPClient} """ chan = t.open_session() @@ -82,7 +95,7 @@ return None if not chan.invoke_subsystem('sftp'): raise SFTPError('Failed to invoke sftp subsystem') - return selfclass(chan) + return cls(chan) from_transport = classmethod(from_transport) def close(self): @@ -91,6 +104,7 @@ @since: 1.4 """ + self._log(INFO, 'sftp session closed.') self.sock.close() def listdir(self, path='.'): @@ -123,6 +137,7 @@ @since: 1.2 """ path = self._adjust_cwd(path) + self._log(DEBUG, 'listdir(%r)' % path) t, msg = self._request(CMD_OPENDIR, path) if t != CMD_HANDLE: raise SFTPError('Expected handle') @@ -146,10 +161,10 @@ self._request(CMD_CLOSE, handle) return filelist - def open(self, filename, mode='r', bufsize=-1): + def file(self, filename, mode='r', bufsize=-1): """ Open a file on the remote server. The arguments are the same as for - python's built-in C{open} (aka C{file}). A file-like object is + python's built-in C{file} (aka C{open}). A file-like object is returned, which closely mimics the behavior of a normal python file object. @@ -159,55 +174,65 @@ existing file), C{'a+'} for reading/appending. The python C{'b'} flag is ignored, since SSH treats all files as binary. The C{'U'} flag is supported in a compatible way. + + Since 1.5.2, an C{'x'} flag indicates that the operation should only + succeed if the file was created and did not previously exist. This has + no direct mapping to python's file flags, but is commonly known as the + C{O_EXCL} flag in posix. The file will be buffered in standard python style by default, but can be altered with the C{bufsize} parameter. C{0} turns off buffering, C{1} uses line buffering, and any number greater than 1 (C{>1}) uses that specific buffer size. - @param filename: name of the file to open. - @type filename: string - @param mode: mode (python-style) to open in. - @type mode: string + @param filename: name of the file to open + @type filename: str + @param mode: mode (python-style) to open in + @type mode: str @param bufsize: desired buffering (-1 = default buffer size) @type bufsize: int - @return: a file object representing the open file. + @return: a file object representing the open file @rtype: SFTPFile @raise IOError: if the file could not be opened. """ filename = self._adjust_cwd(filename) + self._log(DEBUG, 'open(%r, %r)' % (filename, mode)) imode = 0 if ('r' in mode) or ('+' in mode): imode |= SFTP_FLAG_READ - if ('w' in mode) or ('+' in mode): + if ('w' in mode) or ('+' in mode) or ('a' in mode): imode |= SFTP_FLAG_WRITE if ('w' in mode): imode |= SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC if ('a' in mode): - imode |= SFTP_FLAG_APPEND | SFTP_FLAG_CREATE | SFTP_FLAG_WRITE + imode |= SFTP_FLAG_CREATE | SFTP_FLAG_APPEND + if ('x' in mode): + imode |= SFTP_FLAG_CREATE | SFTP_FLAG_EXCL attrblock = SFTPAttributes() t, msg = self._request(CMD_OPEN, filename, imode, attrblock) if t != CMD_HANDLE: raise SFTPError('Expected handle') handle = msg.get_string() + self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, util.hexify(handle))) return SFTPFile(self, handle, mode, bufsize) # python has migrated toward file() instead of open(). # and really, that's more easily identifiable. - file = open + open = file def remove(self, path): """ - Remove the file at the given path. + Remove the file at the given path. This only works on files; for + removing folders (directories), use L{rmdir}. - @param path: path (absolute or relative) of the file to remove. - @type path: string + @param path: path (absolute or relative) of the file to remove + @type path: str - @raise IOError: if the path refers to a folder (directory). Use - L{rmdir} to remove a folder. + @raise IOError: if the path refers to a folder (directory) """ path = self._adjust_cwd(path) + self._log(DEBUG, 'remove(%r)' % path) self._request(CMD_REMOVE, path) unlink = remove @@ -216,16 +241,17 @@ """ Rename a file or folder from C{oldpath} to C{newpath}. - @param oldpath: existing name of the file or folder. - @type oldpath: string - @param newpath: new name for the file or folder. - @type newpath: string + @param oldpath: existing name of the file or folder + @type oldpath: str + @param newpath: new name for the file or folder + @type newpath: str @raise IOError: if C{newpath} is a folder, or something else goes - wrong. + wrong """ oldpath = self._adjust_cwd(oldpath) newpath = self._adjust_cwd(newpath) + self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath)) self._request(CMD_RENAME, oldpath, newpath) def mkdir(self, path, mode=0777): @@ -234,12 +260,13 @@ The default mode is 0777 (octal). On some systems, mode is ignored. Where it is used, the current umask value is first masked out. - @param path: name of the folder to create. - @type path: string - @param mode: permissions (posix-style) for the newly-created folder. + @param path: name of the folder to create + @type path: str + @param mode: permissions (posix-style) for the newly-created folder @type mode: int """ path = self._adjust_cwd(path) + self._log(DEBUG, 'mkdir(%r)' % path) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_MKDIR, path, attr) @@ -248,10 +275,11 @@ """ Remove the folder named C{path}. - @param path: name of the folder to remove. - @type path: string + @param path: name of the folder to remove + @type path: str """ path = self._adjust_cwd(path) + self._log(DEBUG, 'rmdir(%r)' % path) self._request(CMD_RMDIR, path) def stat(self, path): @@ -268,12 +296,13 @@ The fields supported are: C{st_mode}, C{st_size}, C{st_uid}, C{st_gid}, C{st_atime}, and C{st_mtime}. - @param path: the filename to stat. - @type path: string - @return: an object containing attributes about the given file. + @param path: the filename to stat + @type path: str + @return: an object containing attributes about the given file @rtype: SFTPAttributes """ path = self._adjust_cwd(path) + self._log(DEBUG, 'stat(%r)' % path) t, msg = self._request(CMD_STAT, path) if t != CMD_ATTRS: raise SFTPError('Expected attributes') @@ -285,12 +314,13 @@ following symbolic links (shortcuts). This otherwise behaves exactly the same as L{stat}. - @param path: the filename to stat. - @type path: string - @return: an object containing attributes about the given file. + @param path: the filename to stat + @type path: str + @return: an object containing attributes about the given file @rtype: SFTPAttributes """ path = self._adjust_cwd(path) + self._log(DEBUG, 'lstat(%r)' % path) t, msg = self._request(CMD_LSTAT, path) if t != CMD_ATTRS: raise SFTPError('Expected attributes') @@ -301,12 +331,13 @@ Create a symbolic link (shortcut) of the C{source} path at C{destination}. - @param source: path of the original file. - @type source: string - @param dest: path of the newly created symlink. - @type dest: string + @param source: path of the original file + @type source: str + @param dest: path of the newly created symlink + @type dest: str """ dest = self._adjust_cwd(dest) + self._log(DEBUG, 'symlink(%r, %r)' % (source, dest)) if type(source) is unicode: source = source.encode('utf-8') self._request(CMD_SYMLINK, source, dest) @@ -317,12 +348,13 @@ unix-style and identical to those used by python's C{os.chmod} function. - @param path: path of the file to change the permissions of. - @type path: string - @param mode: new permissions. + @param path: path of the file to change the permissions of + @type path: str + @param mode: new permissions @type mode: int """ path = self._adjust_cwd(path) + self._log(DEBUG, 'chmod(%r, %r)' % (path, mode)) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_SETSTAT, path, attr) @@ -334,14 +366,15 @@ only want to change one, use L{stat} first to retrieve the current owner and group. - @param path: path of the file to change the owner and group of. - @type path: string + @param path: path of the file to change the owner and group of + @type path: str @param uid: new owner's uid @type uid: int @param gid: new group id @type gid: int """ path = self._adjust_cwd(path) + self._log(DEBUG, 'chown(%r, %r, %r)' % (path, uid, gid)) attr = SFTPAttributes() attr.st_uid, attr.st_gid = uid, gid self._request(CMD_SETSTAT, path, attr) @@ -355,31 +388,50 @@ modified times, respectively. This bizarre API is mimicked from python for the sake of consistency -- I apologize. - @param path: path of the file to modify. - @type path: string + @param path: path of the file to modify + @type path: str @param times: C{None} or a tuple of (access time, modified time) in - standard internet epoch time (seconds since 01 January 1970 GMT). - @type times: tuple of int + standard internet epoch time (seconds since 01 January 1970 GMT) + @type times: tuple(int) """ path = self._adjust_cwd(path) if times is None: times = (time.time(), time.time()) + self._log(DEBUG, 'utime(%r, %r)' % (path, times)) attr = SFTPAttributes() attr.st_atime, attr.st_mtime = times self._request(CMD_SETSTAT, path, attr) + def truncate(self, path, size): + """ + Change the size of the file specified by C{path}. This usually extends + or shrinks the size of the file, just like the C{truncate()} method on + python file objects. + + @param path: path of the file to modify + @type path: str + @param size: the new size of the file + @type size: int or long + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'truncate(%r, %r)' % (path, size)) + attr = SFTPAttributes() + attr.st_size = size + self._request(CMD_SETSTAT, path, attr) + def readlink(self, path): """ Return the target of a symbolic link (shortcut). You can use L{symlink} to create these. The result may be either an absolute or relative pathname. - @param path: path of the symbolic link file. + @param path: path of the symbolic link file @type path: str - @return: target path. + @return: target path @rtype: str """ path = self._adjust_cwd(path) + self._log(DEBUG, 'readlink(%r)' % path) t, msg = self._request(CMD_READLINK, path) if t != CMD_NAME: raise SFTPError('Expected name response') @@ -397,14 +449,15 @@ server is considering to be the "current folder" (by passing C{'.'} as C{path}). - @param path: path to be normalized. + @param path: path to be normalized @type path: str - @return: normalized form of the given path. + @return: normalized form of the given path @rtype: str @raise IOError: if the path can't be resolved on the server """ path = self._adjust_cwd(path) + self._log(DEBUG, 'normalize(%r)' % path) t, msg = self._request(CMD_REALPATH, path) if t != CMD_NAME: raise SFTPError('Expected name response') @@ -449,6 +502,8 @@ Any exception raised by operations will be passed through. This method is primarily provided as a convenience. + The SFTP operations use pipelining for speed. + @param localpath: the local file to copy @type localpath: str @param remotepath: the destination path on the SFTP server @@ -458,9 +513,10 @@ """ fl = file(localpath, 'rb') fr = self.file(remotepath, 'wb') + fr.set_pipelined(True) size = 0 while True: - data = fl.read(16384) + data = fl.read(32768) if len(data) == 0: break fr.write(data) @@ -485,10 +541,11 @@ @since: 1.4 """ fr = self.file(remotepath, 'rb') + fr.prefetch() fl = file(localpath, 'wb') size = 0 while True: - data = fr.read(16384) + data = fr.read(32768) if len(data) == 0: break fl.write(data) @@ -504,30 +561,65 @@ def _request(self, t, *arg): - msg = Message() - msg.add_int(self.request_number) - for item in arg: - if type(item) is int: - msg.add_int(item) - elif type(item) is long: - msg.add_int64(item) - elif type(item) is str: - msg.add_string(item) - elif type(item) is SFTPAttributes: - item._pack(msg) - else: - raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item))) - self._send_packet(t, str(msg)) - t, data = self._read_packet() - msg = Message(data) - num = msg.get_int() - if num != self.request_number: - raise SFTPError('Expected response #%d, got response #%d' % (self.request_number, num)) - self.request_number += 1 - if t == CMD_STATUS: - self._convert_status(msg) - return t, msg + num = self._async_request(type(None), t, *arg) + return self._read_response(num) + + def _async_request(self, fileobj, t, *arg): + # this method may be called from other threads (prefetch) + self._lock.acquire() + try: + msg = Message() + msg.add_int(self.request_number) + for item in arg: + if type(item) is int: + msg.add_int(item) + elif type(item) is long: + msg.add_int64(item) + elif type(item) is str: + msg.add_string(item) + elif type(item) is SFTPAttributes: + item._pack(msg) + else: + raise Exception('unknown type for %r type %r' % (item, type(item))) + num = self.request_number + self._expecting[num] = fileobj + self._send_packet(t, str(msg)) + self.request_number += 1 + finally: + self._lock.release() + return num + def _read_response(self, waitfor=None): + while True: + t, data = self._read_packet() + msg = Message(data) + num = msg.get_int() + if num not in self._expecting: + # might be response for a file that was closed before responses came back + self._log(DEBUG, 'Unexpected response #%d' % (num,)) + if waitfor is None: + # just doing a single check + break + continue + fileobj = self._expecting[num] + del self._expecting[num] + if num == waitfor: + # synchronous + if t == CMD_STATUS: + self._convert_status(msg) + return t, msg + if fileobj is not type(None): + fileobj._async_response(t, msg) + if waitfor is None: + # just doing a single check + break + return (None, None) + + def _finish_responses(self, fileobj): + while fileobj in self._expecting.values(): + self._read_response() + fileobj._check_exception() + def _convert_status(self, msg): """ Raises EOFError or IOError on error status; otherwise does nothing. @@ -538,6 +630,11 @@ return elif code == SFTP_EOF: raise EOFError(text) + elif code == SFTP_NO_SUCH_FILE: + # clever idea from john a. meinel: map the error codes to errno + raise IOError(errno.ENOENT, text) + elif code == SFTP_PERMISSION_DENIED: + raise IOError(errno.EACCES, text) else: raise IOError(text) ============================================================ --- paramiko/sftp_file.py b6da074f26d4d2e8f891d0eb823d8ac9e9d50121 +++ paramiko/sftp_file.py 1670b714a0025b13eea2e9d40d58dc46bdc508fe @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -20,10 +20,12 @@ L{SFTPFile} """ -from common import * -from sftp import * -from file import BufferedFile -from sftp_attr import SFTPAttributes +import socket +import threading +from paramiko.common import * +from paramiko.sftp import * +from paramiko.file import BufferedFile +from paramiko.sftp_attr import SFTPAttributes class SFTPFile (BufferedFile): @@ -40,11 +42,20 @@ self.sftp = sftp self.handle = handle BufferedFile._set_mode(self, mode, bufsize) + self.pipelined = False + self._prefetching = False + self._prefetch_done = False + self._prefetch_so_far = 0 + self._prefetch_data = {} + self._saved_exception = None def __del__(self): - self.close() + self._close(async=True) + + def close(self): + self._close(async=False) - def close(self): + def _close(self, async=False): # We allow double-close without signaling an error, because real # Python file objects do. However, we must protect against actually # sending multiple CMD_CLOSE packets, because after we close our @@ -54,18 +65,65 @@ # __del__.) if self._closed: return + self.sftp._log(DEBUG, 'close(%s)' % util.hexify(self.handle)) + if self.pipelined: + self.sftp._finish_responses(self) BufferedFile.close(self) try: - self.sftp._request(CMD_CLOSE, self.handle) + if async: + # GC'd file handle could be called from an arbitrary thread -- don't wait for a response + self.sftp._async_request(type(None), CMD_CLOSE, self.handle) + else: + self.sftp._request(CMD_CLOSE, self.handle) except EOFError: # may have outlived the Transport connection pass - except IOError: + except (IOError, socket.error): # may have outlived the Transport connection pass + def _read_prefetch(self, size): + """ + read data out of the prefetch buffer, if possible. if the data isn't + in the buffer, return None. otherwise, behaves like a normal read. + """ + # while not closed, and haven't fetched past the current position, and haven't reached EOF... + while (self._prefetch_so_far <= self._realpos) and not self._closed: + if self._prefetch_done: + return None + self.sftp._read_response() + self._check_exception() + k = self._prefetch_data.keys() + if len(k) == 0: + self._prefetching = False + return '' + + # find largest offset < realpos + pos_list = [i for i in k if i <= self._realpos] + if len(pos_list) == 0: + return None + index = max(pos_list) + prefetch = self._prefetch_data[index] + del self._prefetch_data[index] + + buf_offset = self._realpos - index + if buf_offset > 0: + self._prefetch_data[index] = prefetch[:buf_offset] + prefetch = prefetch[buf_offset:] + if buf_offset >= len(prefetch): + # it's not here. + return None + if size < len(prefetch): + self._prefetch_data[self._realpos + size] = prefetch[size:] + prefetch = prefetch[:size] + return prefetch + def _read(self, size): size = min(size, self.MAX_REQUEST_SIZE) + if self._prefetching: + data = self._read_prefetch(size) + if data is not None: + return data t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size)) if t != CMD_DATA: raise SFTPError('Expected data') @@ -74,11 +132,13 @@ def _write(self, data): # may write less than requested if it would exceed max packet size chunk = min(len(data), self.MAX_REQUEST_SIZE) - t, msg = self.sftp._request(CMD_WRITE, self.handle, long(self._realpos), - str(data[:chunk])) - if t != CMD_STATUS: - raise SFTPError('Expected status') - self.sftp._convert_status(msg) + req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), + str(data[:chunk])) + if not self.pipelined or self.sftp.sock.recv_ready(): + t, msg = self.sftp._read_response(req) + if t != CMD_STATUS: + raise SFTPError('Expected status') + # convert_status already called return chunk def settimeout(self, timeout): @@ -120,8 +180,8 @@ if whence == self.SEEK_SET: self._realpos = self._pos = offset elif whence == self.SEEK_CUR: - self._realpos += offset self._pos += offset + self._realpos = self._pos else: self._realpos = self._pos = self._get_size() + offset self._rbuffer = '' @@ -139,6 +199,71 @@ if t != CMD_ATTRS: raise SFTPError('Expected attributes') return SFTPAttributes._from_msg(msg) + + def chmod(self, mode): + """ + Change the mode (permissions) of this file. The permissions are + unix-style and identical to those used by python's C{os.chmod} + function. + + @param mode: new permissions + @type mode: int + """ + self.sftp._log(DEBUG, 'chmod(%s, %r)' % (util.hexify(self.handle), mode)) + attr = SFTPAttributes() + attr.st_mode = mode + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def chown(self, uid, gid): + """ + Change the owner (C{uid}) and group (C{gid}) of this file. As with + python's C{os.chown} function, you must pass both arguments, so if you + only want to change one, use L{stat} first to retrieve the current + owner and group. + + @param uid: new owner's uid + @type uid: int + @param gid: new group id + @type gid: int + """ + self.sftp._log(DEBUG, 'chown(%s, %r, %r)' % (util.hexify(self.handle), uid, gid)) + attr = SFTPAttributes() + attr.st_uid, attr.st_gid = uid, gid + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def utime(self, times): + """ + Set the access and modified times of this file. If + C{times} is C{None}, then the file's access and modified times are set + to the current time. Otherwise, C{times} must be a 2-tuple of numbers, + of the form C{(atime, mtime)}, which is used to set the access and + modified times, respectively. This bizarre API is mimicked from python + for the sake of consistency -- I apologize. + + @param times: C{None} or a tuple of (access time, modified time) in + standard internet epoch time (seconds since 01 January 1970 GMT) + @type times: tuple(int) + """ + if times is None: + times = (time.time(), time.time()) + self.sftp._log(DEBUG, 'utime(%s, %r)' % (util.hexify(self.handle), times)) + attr = SFTPAttributes() + attr.st_atime, attr.st_mtime = times + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def truncate(self, size): + """ + Change the size of this file. This usually extends + or shrinks the size of the file, just like the C{truncate()} method on + python file objects. + + @param size: the new size of the file + @type size: int or long + """ + self.sftp._log(DEBUG, 'truncate(%s, %r)' % (util.hexify(self.handle), size)) + attr = SFTPAttributes() + attr.st_size = size + self.sftp._request(CMD_FSETSTAT, self.handle, attr) def check(self, hash_algorithm, offset=0, length=0, block_size=0): """ @@ -193,7 +318,78 @@ alg = msg.get_string() data = msg.get_remainder() return data + + def set_pipelined(self, pipelined=True): + """ + Turn on/off the pipelining of write operations to this file. When + pipelining is on, paramiko won't wait for the server response after + each write operation. Instead, they're collected as they come in. + At the first non-write operation (including L{close}), all remaining + server responses are collected. This means that if there was an error + with one of your later writes, an exception might be thrown from + within L{close} instead of L{write}. + + By default, files are I{not} pipelined. + + @param pipelined: C{True} if pipelining should be turned on for this + file; C{False} otherwise + @type pipelined: bool + + @since: 1.5 + """ + self.pipelined = pipelined + + def prefetch(self): + """ + Pre-fetch the remaining contents of this file in anticipation of + future L{read} calls. If reading the entire file, pre-fetching can + dramatically improve the download speed by avoiding roundtrip latency. + The file's contents are incrementally buffered in a background thread. + + The prefetched data is stored in a buffer until read via the L{read} + method. Once data has been read, it's removed from the buffer. The + data may be read in a random order (using L{seek}); chunks of the + buffer that haven't been read will continue to be buffered. + @since: 1.5.1 + """ + size = self.stat().st_size + # queue up async reads for the rest of the file + chunks = [] + n = self._realpos + while n < size: + chunk = min(self.MAX_REQUEST_SIZE, size - n) + chunks.append((n, chunk)) + n += chunk + self._start_prefetch(chunks) + + def readv(self, chunks): + """ + Read a set of blocks from the file by (offset, length). This is more + efficient than doing a series of L{seek} and L{read} calls, since the + prefetch machinery is used to retrieve all the requested blocks at + once. + + @param chunks: a list of (offset, length) tuples indicating which + sections of the file to read + @ptype chunks: list(tuple(long, int)) + @return: a list of blocks read, in the same order as in C{chunks} + @rtype: list(str) + + @since: 1.5.4 + """ + # put the offsets in order, since we depend on that for determining + # when the reads have finished. + ordered_chunks = chunks[:] + ordered_chunks.sort(lambda x, y: cmp(x[0], y[0])) + self._start_prefetch(ordered_chunks) + # now we can just devolve to a bunch of read()s :) + out = [] + for x in chunks: + self.seek(x[0]) + out.append(self.read(x[1])) + return out + ### internals... @@ -203,3 +399,45 @@ return self.stat().st_size except: return 0 + + def _start_prefetch(self, chunks): + self._prefetching = True + self._prefetch_done = False + self._prefetch_so_far = chunks[0][0] + self._prefetch_data = {} + self._prefetch_reads = chunks[:] + + t = threading.Thread(target=self._prefetch_thread, args=(chunks,)) + t.setDaemon(True) + t.start() + + def _prefetch_thread(self, chunks): + # do these read requests in a temporary thread because there may be + # a lot of them, so it may block. + for offset, length in chunks: + self.sftp._async_request(self, CMD_READ, self.handle, long(offset), int(length)) + + def _async_response(self, t, msg): + if t == CMD_STATUS: + # save exception and re-raise it on next file operation + try: + self.sftp._convert_status(msg) + except Exception, x: + self._saved_exception = x + return + if t != CMD_DATA: + raise SFTPError('Expected data') + data = msg.get_string() + offset, length = self._prefetch_reads.pop(0) + assert length == len(data) + self._prefetch_data[offset] = data + self._prefetch_so_far = offset + length + if len(self._prefetch_reads) == 0: + self._prefetch_done = True + + def _check_exception(self): + "if there's a saved exception, raise & clear it" + if self._saved_exception is not None: + x = self._saved_exception + self._saved_exception = None + raise x ============================================================ --- paramiko/sftp_handle.py 839b45b7b379f6e726ff0c1017c4b2f0ea285344 +++ paramiko/sftp_handle.py f41290936813308293937538626ab33350e01f82 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -21,10 +21,11 @@ """ import os -from common import * -from sftp import * +from paramiko.common import * +from paramiko.sftp import * + class SFTPHandle (object): """ Abstract object representing a handle to an open file (or folder) in an @@ -80,15 +81,16 @@ @return: data read from the file, or an SFTP error code. @rtype: str """ - if not hasattr(self, 'readfile') or (self.readfile is None): + readfile = getattr(self, 'readfile', None) + if readfile is None: return SFTP_OP_UNSUPPORTED try: if self.__tell is None: - self.__tell = self.readfile.tell() + self.__tell = readfile.tell() if offset != self.__tell: - self.readfile.seek(offset) + readfile.seek(offset) self.__tell = offset - data = self.readfile.read(length) + data = readfile.read(length) except IOError, e: self.__tell = None return SFTPServer.convert_errno(e.errno) @@ -115,16 +117,17 @@ @type data: str @return: an SFTP error code like L{SFTP_OK}. """ - if not hasattr(self, 'writefile') or (self.writefile is None): + writefile = getattr(self, 'writefile', None) + if writefile is None: return SFTP_OP_UNSUPPORTED try: if self.__tell is None: - self.__tell = self.writefile.tell() + self.__tell = writefile.tell() if offset != self.__tell: - self.writefile.seek(offset) + writefile.seek(offset) self.__tell = offset - self.writefile.write(data) - self.writefile.flush() + writefile.write(data) + writefile.flush() except IOError, e: self.__tell = None return SFTPServer.convert_errno(e.errno) @@ -184,4 +187,4 @@ self.__name = name +from paramiko.sftp_server import SFTPServer -from sftp_server import SFTPServer ============================================================ --- paramiko/sftp_server.py fc13033c75ae152a5c3c5758ebb9ecc7f2b5edb8 +++ paramiko/sftp_server.py 73185d628c5417fcc59c965125f1aa0426d7e471 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -22,13 +20,15 @@ Server-mode SFTP support. """ -import os, errno +import os +import errno + from Crypto.Hash import MD5, SHA -from common import * -from server import SubsystemHandler -from sftp import * -from sftp_si import * -from sftp_attr import * +from paramiko.common import * +from paramiko.server import SubsystemHandler +from paramiko.sftp import * +from paramiko.sftp_si import * +from paramiko.sftp_attr import * # known hash algorithms for the "check-file" extension @@ -106,7 +106,7 @@ def convert_errno(e): """ - Convert an errno value (as from an C{OSError} or C{IOError} into a + Convert an errno value (as from an C{OSError} or C{IOError}) into a standard SFTP result code. This is a convenience function for trapping exceptions in server code and returning an appropriate result. @@ -118,7 +118,7 @@ if e == errno.EACCES: # permission denied return SFTP_PERMISSION_DENIED - elif e == errno.ENOENT: + elif (e == errno.ENOENT) or (e == errno.ENOTDIR): # no such file return SFTP_NO_SUCH_FILE else: @@ -147,6 +147,8 @@ os.chown(filename, attr.st_uid, attr.st_gid) if attr._flags & attr.FLAG_AMTIME: os.utime(filename, (attr.st_atime, attr.st_mtime)) + if attr._flags & attr.FLAG_SIZE: + open(filename, 'w+').truncate(attr.st_size) set_file_attr = staticmethod(set_file_attr) @@ -184,7 +186,10 @@ def _send_status(self, request_number, code, desc=None): if desc is None: - desc = SFTP_DESC[code] + try: + desc = SFTP_DESC[code] + except IndexError: + desc = 'Unknown' self._response(request_number, CMD_STATUS, code, desc) def _open_folder(self, request_number, path): @@ -246,29 +251,29 @@ self._send_status(request_number, SFTP_FAILURE, 'Block size too small') return - sum = '' + sum_out = '' offset = start while offset < start + length: blocklen = min(block_size, start + length - offset) # don't try to read more than about 64KB at a time chunklen = min(blocklen, 65536) count = 0 - hash = alg.new() + hash_obj = alg.new() while count < blocklen: data = f.read(offset, chunklen) if not type(data) is str: self._send_status(request_number, data, 'Unable to hash file') return - hash.update(data) + hash_obj.update(data) count += len(data) offset += count - sum += hash.digest() + sum_out += hash_obj.digest() msg = Message() msg.add_int(request_number) msg.add_string('check-file') msg.add_string(algname) - msg.add_bytes(sum) + msg.add_bytes(sum_out) self._send_packet(CMD_EXTENDED_REPLY, str(msg)) def _convert_pflags(self, pflags): @@ -412,9 +417,9 @@ if tag == 'check-file': self._check_file(request_number, msg) else: - send._send_status(request_number, SFTP_OP_UNSUPPORTED) + self._send_status(request_number, SFTP_OP_UNSUPPORTED) else: self._send_status(request_number, SFTP_OP_UNSUPPORTED) +from paramiko.sftp_handle import SFTPHandle -from sftp_handle import SFTPHandle ============================================================ --- paramiko/sftp_si.py 41df6e4b221b6e8a79aa43bc4e52a59e170e75ac +++ paramiko/sftp_si.py fc14cccbee4d37e67230139e4e1d45dcb13ac358 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -23,9 +21,11 @@ """ import os -from common import * -from sftp import * +from paramiko.common import * +from paramiko.sftp import * + + class SFTPServerInterface (object): """ This class defines an interface for controlling the behavior of paramiko @@ -36,6 +36,9 @@ SFTP sessions). However, raising an exception will usually cause the SFTP session to abruptly end, so you will usually want to catch exceptions and return an appropriate error code. + + All paths are in string form instead of unicode because not all SFTP + clients & servers obey the requirement that paths be encoded in UTF-8. """ def __init__ (self, server, *largs, **kwargs): @@ -301,4 +304,3 @@ @rtype: int """ return SFTP_OP_UNSUPPORTED - ============================================================ --- paramiko/ssh_exception.py 4a74db486cf4bf8d5480cf850d9caa85efe8d79c +++ paramiko/ssh_exception.py c8757ab9e6c8ed88b69199279d747f64d31d2897 @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -29,12 +27,14 @@ """ pass + class PasswordRequiredException (SSHException): """ Exception raised when a password is needed to unlock a private key file. """ pass + class BadAuthenticationType (SSHException): """ Exception raised when an authentication type (like password) is used, but @@ -53,7 +53,11 @@ def __init__(self, explanation, types): SSHException.__init__(self, explanation) self.allowed_types = types + + def __str__(self): + return SSHException.__str__(self) + ' (allowed_types=%r)' % self.allowed_types + class PartialAuthentication (SSHException): """ An internal exception thrown in the case of partial authentication. ============================================================ --- paramiko/transport.py 28f37067e15b678b878637b19ea550d116038df0 +++ paramiko/transport.py eddb4e639b1ff6966ea2dea4d19102b102d8ebb2 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -17,24 +17,33 @@ # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. """ -L{BaseTransport} handles the core SSH2 protocol. +L{Transport} handles the core SSH2 protocol. """ -import sys, os, string, threading, socket, struct, time +import os +import socket +import string +import struct +import sys +import threading +import time import weakref -from common import * -from ssh_exception import SSHException -from message import Message -from channel import Channel -from sftp_client import SFTPClient -import util -from packet import Packetizer -from rsakey import RSAKey -from dsskey import DSSKey -from kex_group1 import KexGroup1 -from kex_gex import KexGex -from primes import ModulusPack +from paramiko import util +from paramiko.auth_handler import AuthHandler +from paramiko.channel import Channel +from paramiko.common import * +from paramiko.compress import ZlibCompressor, ZlibDecompressor +from paramiko.dsskey import DSSKey +from paramiko.kex_gex import KexGex +from paramiko.kex_group1 import KexGroup1 +from paramiko.message import Message +from paramiko.packet import Packetizer, NeedRekeyException +from paramiko.primes import ModulusPack +from paramiko.rsakey import RSAKey +from paramiko.server import ServerInterface +from paramiko.sftp_client import SFTPClient +from paramiko.ssh_exception import SSHException, BadAuthenticationType # these come from PyCrypt # http://www.amk.ca/python/writing/pycrypt/ @@ -42,7 +51,7 @@ # PyCrypt compiled for Win32 can be downloaded from the HashTar homepage: # http://nitace.bsd.uchicago.edu:8080/hashtar from Crypto.Cipher import Blowfish, AES, DES3 -from Crypto.Hash import SHA, MD5, HMAC +from Crypto.Hash import SHA, MD5 # for thread cleanup @@ -65,10 +74,8 @@ If you try to add an algorithm that paramiko doesn't recognize, C{ValueError} will be raised. If you try to assign something besides a tuple to one of the fields, C{TypeError} will be raised. - - @since: ivysaur """ - __slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', '_transport' ] + __slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ] def __init__(self, transport): self._transport = transport @@ -92,6 +99,9 @@ def _get_kex(self): return self._transport._preferred_kex + + def _get_compression(self): + return self._transport._preferred_compression def _set(self, name, orig, x): if type(x) is list: @@ -99,7 +109,8 @@ if type(x) is not tuple: raise TypeError('expected tuple or list') possible = getattr(self._transport, orig).keys() - if len(filter(lambda n: n not in possible, x)) > 0: + forbidden = filter(lambda n: n not in possible, x) + if len(forbidden) > 0: raise ValueError('unknown cipher') setattr(self._transport, name, x) @@ -114,6 +125,9 @@ def _set_kex(self, x): self._set('_preferred_kex', '_kex_info', x) + + def _set_compression(self, x): + self._set('_preferred_compression', '_compression_info', x) ciphers = property(_get_ciphers, _set_ciphers, None, "Symmetric encryption ciphers") @@ -122,22 +136,27 @@ key_types = property(_get_key_types, _set_key_types, None, "Public-key algorithms") kex = property(_get_kex, _set_kex, None, "Key exchange algorithms") + compression = property(_get_compression, _set_compression, None, + "Compression algorithms") -class BaseTransport (threading.Thread): +class Transport (threading.Thread): """ - Handles protocol negotiation, key exchange, encryption, and the creation - of channels across an SSH session. Basically everything but authentication - is done here. + An SSH Transport attaches to a stream (usually a socket), negotiates an + encrypted session, authenticates, and then creates stream tunnels, called + L{Channel}s, across the session. Multiple channels can be multiplexed + across a single session (and often are, in the case of port forwardings). """ + _PROTO_ID = '2.0' - _CLIENT_ID = 'paramiko_1.4' + _CLIENT_ID = 'paramiko_1.5.4' _preferred_ciphers = ( 'aes128-cbc', 'blowfish-cbc', 'aes256-cbc', '3des-cbc' ) _preferred_macs = ( 'hmac-sha1', 'hmac-md5', 'hmac-sha1-96', 'hmac-md5-96' ) _preferred_keys = ( 'ssh-rsa', 'ssh-dss' ) _preferred_kex = ( 'diffie-hellman-group1-sha1', 'diffie-hellman-group-exchange-sha1' ) - + _preferred_compression = ( 'none', ) + _cipher_info = { 'blowfish-cbc': { 'class': Blowfish, 'mode': Blowfish.MODE_CBC, 'block-size': 8, 'key-size': 16 }, 'aes128-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 16 }, @@ -161,6 +180,15 @@ 'diffie-hellman-group1-sha1': KexGroup1, 'diffie-hellman-group-exchange-sha1': KexGex, } + + _compression_info = { + # address@hidden is just zlib, but only turned on after a successful + # authentication. openssh servers may only offer this type because + # they've had troubles with security holes in zlib in the past. + 'address@hidden': ( ZlibCompressor, ZlibDecompressor ), + 'zlib': ( ZlibCompressor, ZlibDecompressor ), + 'none': ( None, None ), + } _modulus_pack = None @@ -217,31 +245,49 @@ self.sock.settimeout(0.1) except AttributeError: pass + # negotiated crypto parameters self.packetizer = Packetizer(sock) self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID self.remote_version = '' self.local_cipher = self.remote_cipher = '' self.local_kex_init = self.remote_kex_init = None + self.local_mac = self.remote_mac = None + self.local_compression = self.remote_compression = None self.session_id = None - # /negotiated crypto parameters - self.expected_packet = 0 + self.host_key_type = None + self.host_key = None + + # state used during negotiation + self.kex_engine = None + self.H = None + self.K = None + self.active = False self.initial_kex_done = False self.in_kex = False + self.authenticated = False + self.expected_packet = 0 self.lock = threading.Lock() # synchronization (always higher level than write_lock) + + # tracking open channels self.channels = weakref.WeakValueDictionary() # (id -> Channel) self.channel_events = { } # (id -> Event) + self.channels_seen = { } # (id -> True) self.channel_counter = 1 self.window_size = 65536 - self.max_packet_size = 32768 + self.max_packet_size = 34816 + self.saved_exception = None self.clear_to_send = threading.Event() + self.clear_to_send_lock = threading.Lock() self.log_name = 'paramiko.transport' self.logger = util.get_logger(self.log_name) self.packetizer.set_log(self.logger) - # user-defined event callbacks: - self.completion_event = None + self.auth_handler = None + self.global_response = None # response Message from an arbitrary global request + self.completion_event = None # user-defined event callbacks + # server mode: self.server_mode = False self.server_object = None @@ -250,28 +296,43 @@ self.server_accept_cv = threading.Condition(self.lock) self.subsystem_table = { } - def __del__(self): - self.close() - def __repr__(self): """ Returns a string representation of this object, for debugging. @rtype: str """ - out = ' 1: + raise SSHException('Fallback authentication failed.') + if len(fields) == 0: + # for some reason, at least on os x, a 2nd request will + # be made with zero fields requested. maybe it's just + # to try to fake out automated scripting of the exact + # type we're doing here. *shrug* :) + return [] + return [ password ] + return self.auth_interactive(username, handler) + except SSHException, ignored: + # attempt failed; just raise the original exception + raise x + return None + + def auth_publickey(self, username, key, event=None): + """ + Authenticate to the server using a private key. The key is used to + sign data from the server, so it must include the private part. + + If an C{event} is passed in, this method will return immediately, and + the event will be triggered once authentication succeeds or fails. On + success, L{is_authenticated} will return C{True}. On failure, you may + use L{get_exception} to get more detailed error information. + + Since 1.1, if no event is passed, this method will block until the + authentication succeeds or fails. On failure, an exception is raised. + Otherwise, the method simply returns. + + If the server requires multi-step authentication (which is very rare), + this method will return a list of auth types permissible for the next + step. Otherwise, in the normal case, an empty list is returned. + + @param username: the username to authenticate as + @type username: string + @param key: the private key to authenticate with + @type key: L{PKey } + @param event: an event to trigger when the authentication attempt is + complete (whether it was successful or not) + @type event: threading.Event + @return: list of auth types permissible for the next stage of + authentication (normally empty). + @rtype: list + + @raise BadAuthenticationType: if public-key authentication isn't + allowed by the server for this user (and no event was passed in). + @raise SSHException: if the authentication failed (and no event was + passed in). + """ + if (not self.active) or (not self.initial_kex_done): + # we should never try to authenticate unless we're on a secure link + raise SSHException('No existing session') + if event is None: + my_event = threading.Event() + else: + my_event = event + self.auth_handler = AuthHandler(self) + self.auth_handler.auth_publickey(username, key, my_event) + if event is not None: + # caller wants to wait for event themselves + return [] + return self.auth_handler.wait_for_response(my_event) + + def auth_interactive(self, username, handler, submethods=''): + """ + Authenticate to the server interactively. A handler is used to answer + arbitrary questions from the server. On many servers, this is just a + dumb wrapper around PAM. + + This method will block until the authentication succeeds or fails, + peroidically calling the handler asynchronously to get answers to + authentication questions. The handler may be called more than once + if the server continues to ask questions. + + The handler is expected to be a callable that will handle calls of the + form: C{handler(title, instructions, prompt_list)}. The C{title} is + meant to be a dialog-window title, and the C{instructions} are user + instructions (both are strings). C{prompt_list} will be a list of + prompts, each prompt being a tuple of C{(str, bool)}. The string is + the prompt and the boolean indicates whether the user text should be + echoed. + + A sample call would thus be: + C{handler('title', 'instructions', [('Password:', False)])}. + + The handler should return a list or tuple of answers to the server's + questions. + + If the server requires multi-step authentication (which is very rare), + this method will return a list of auth types permissible for the next + step. Otherwise, in the normal case, an empty list is returned. + + @param username: the username to authenticate as + @type username: string + @param handler: a handler for responding to server questions + @type handler: callable + @param submethods: a string list of desired submethods (optional) + @type submethods: str + @return: list of auth types permissible for the next stage of + authentication (normally empty). + @rtype: list + + @raise BadAuthenticationType: if public-key authentication isn't + allowed by the server for this user + @raise SSHException: if the authentication failed + + @since: 1.5 + """ + if (not self.active) or (not self.initial_kex_done): + # we should never try to authenticate unless we're on a secure link + raise SSHException('No existing session') + my_event = threading.Event() + self.auth_handler = AuthHandler(self) + self.auth_handler.auth_interactive(username, handler, my_event, submethods) + return self.auth_handler.wait_for_response(my_event) + def set_log_channel(self, name): """ Set the channel for this transport's logging. The default is @@ -848,6 +1133,7 @@ """ self.log_name = name self.logger = util.get_logger(name) + self.packetizer.set_log(self.logger) def get_log_channel(self): """ @@ -883,6 +1169,39 @@ @since: 1.4 """ return self.packetizer.get_hexdump() + + def use_compression(self, compress=True): + """ + Turn on/off compression. This will only have an affect before starting + the transport (ie before calling L{connect}, etc). By default, + compression is off since it negatively affects interactive sessions + and is not fully tested. + + @param compress: C{True} to ask the remote client/server to compress + traffic; C{False} to refuse compression + @type compress: bool + + @since: 1.5.2 + """ + if compress: + self._preferred_compression = ( 'address@hidden', 'zlib', 'none' ) + else: + self._preferred_compression = ( 'none', ) + + def getpeername(self): + """ + Return the address of the remote side of this Transport, if possible. + This is effectively a wrapper around C{'getpeername'} on the underlying + socket. If the socket-like object has no C{'getpeername'} method, + then C{("unknown", 0)} is returned. + + @return: the address if the remote host, if known + @rtype: tuple(str, int) + """ + gp = getattr(self.sock, 'getpeername', None) + if gp is None: + return ('unknown', 0) + return gp() def stop_thread(self): self.active = False @@ -914,8 +1233,6 @@ def _send_message(self, data): self.packetizer.send_message(data) - if self.packetizer.need_rekey() and not self.in_kex: - self._send_kex_init() def _send_user_message(self, data): """ @@ -927,9 +1244,14 @@ if not self.active: self._log(DEBUG, 'Dropping user packet because connection is dead.') return + self.clear_to_send_lock.acquire() if self.clear_to_send.isSet(): break - self._send_message(data) + self.clear_to_send_lock.release() + try: + self._send_message(data) + finally: + self.clear_to_send_lock.release() def _set_K_H(self, k, h): "used by a kex object to set the K (root key) and H (exchange hash)" @@ -963,9 +1285,9 @@ m.add_mpint(self.K) m.add_bytes(self.H) m.add_bytes(sofar) - hash = SHA.new(str(m)).digest() - out += hash - sofar += hash + digest = SHA.new(str(m)).digest() + out += digest + sofar += digest return out[:nbytes] def _get_cipher(self, name, key, iv): @@ -994,7 +1316,10 @@ while self.active: if self.packetizer.need_rekey() and not self.in_kex: self._send_kex_init() - ptype, m = self.packetizer.read_message() + try: + ptype, m = self.packetizer.read_message() + except NeedRekeyException: + continue if ptype == MSG_IGNORE: continue elif ptype == MSG_DISCONNECT: @@ -1019,10 +1344,14 @@ chanid = m.get_int() if self.channels.has_key(chanid): self._channel_handler_table[ptype](self.channels[chanid], m) + elif self.channels_seen.has_key(chanid): + self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) else: self._log(ERROR, 'Channel request for unknown channel %d' % chanid) self.active = False self.packetizer.close() + elif (self.auth_handler is not None) and self.auth_handler._handler_table.has_key(ptype): + self.auth_handler._handler_table[ptype](self.auth_handler, m) else: self._log(WARNING, 'Oops, unhandled type %d' % ptype) msg = Message() @@ -1056,8 +1385,8 @@ self.packetizer.close() if self.completion_event != None: self.completion_event.set() - if self.auth_event != None: - self.auth_event.set() + if self.auth_handler is not None: + self.auth_handler.abort() for event in self.channel_events.values(): event.set() self.sock.close() @@ -1068,7 +1397,11 @@ def _negotiate_keys(self, m): # throws SSHException on anything unusual - self.clear_to_send.clear() + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.clear() + finally: + self.clear_to_send_lock.release() if self.local_kex_init == None: # remote side wants to renegotiate self._send_kex_init() @@ -1084,24 +1417,24 @@ else: timeout = 2 try: - buffer = self.packetizer.readline(timeout) + buf = self.packetizer.readline(timeout) except Exception, x: raise SSHException('Error reading SSH protocol banner' + str(x)) - if buffer[:4] == 'SSH-': + if buf[:4] == 'SSH-': break - self._log(DEBUG, 'Banner: ' + buffer) - if buffer[:4] != 'SSH-': - raise SSHException('Indecipherable protocol version "' + buffer + '"') + self._log(DEBUG, 'Banner: ' + buf) + if buf[:4] != 'SSH-': + raise SSHException('Indecipherable protocol version "' + buf + '"') # save this server version string for later - self.remote_version = buffer + self.remote_version = buf # pull off any attached comment comment = '' - i = string.find(buffer, ' ') + i = string.find(buf, ' ') if i >= 0: - comment = buffer[i+1:] - buffer = buffer[:i] + comment = buf[i+1:] + buf = buf[:i] # parse out version string and make sure it matches - segs = buffer.split('-', 2) + segs = buf.split('-', 2) if len(segs) < 3: raise SSHException('Invalid SSH banner') version = segs[1] @@ -1115,7 +1448,11 @@ announce to the other side that we'd like to negotiate keys, and what kind of key negotiation we support. """ - self.clear_to_send.clear() + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.clear() + finally: + self.clear_to_send_lock.release() self.in_kex = True if self.server_mode: if (self._modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self._preferred_kex): @@ -1128,6 +1465,7 @@ else: available_server_keys = self._preferred_keys + randpool.stir() m = Message() m.add_byte(chr(MSG_KEXINIT)) m.add_bytes(randpool.get_bytes(16)) @@ -1137,8 +1475,8 @@ m.add_list(self._preferred_ciphers) m.add_list(self._preferred_macs) m.add_list(self._preferred_macs) - m.add_string('none') - m.add_string('none') + m.add_list(self._preferred_compression) + m.add_list(self._preferred_compression) m.add_string('') m.add_string('') m.add_boolean(False) @@ -1162,10 +1500,16 @@ kex_follows = m.get_boolean() unused = m.get_int() - # no compression support (yet?) - if (not('none' in client_compress_algo_list) or - not('none' in server_compress_algo_list)): - raise SSHException('Incompatible ssh peer.') + self._log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ + ' client encrypt:' + str(client_encrypt_algo_list) + \ + ' server encrypt:' + str(server_encrypt_algo_list) + \ + ' client mac:' + str(client_mac_algo_list) + \ + ' server mac:' + str(server_mac_algo_list) + \ + ' client compress:' + str(client_compress_algo_list) + \ + ' server compress:' + str(server_compress_algo_list) + \ + ' client lang:' + str(client_lang_list) + \ + ' server lang:' + str(server_lang_list) + \ + ' kex follows?' + str(kex_follows)) # as a server, we pick the first item in the client's list that we support. # as a client, we pick the first item in our list that the server supports. @@ -1216,19 +1560,20 @@ self.local_mac = agreed_local_macs[0] self.remote_mac = agreed_remote_macs[0] - self._log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ - ' client encrypt:' + str(client_encrypt_algo_list) + \ - ' server encrypt:' + str(server_encrypt_algo_list) + \ - ' client mac:' + str(client_mac_algo_list) + \ - ' server mac:' + str(server_mac_algo_list) + \ - ' client compress:' + str(client_compress_algo_list) + \ - ' server compress:' + str(server_compress_algo_list) + \ - ' client lang:' + str(client_lang_list) + \ - ' server lang:' + str(server_lang_list) + \ - ' kex follows?' + str(kex_follows)) - self._log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' % + if self.server_mode: + agreed_remote_compression = filter(self._preferred_compression.__contains__, client_compress_algo_list) + agreed_local_compression = filter(self._preferred_compression.__contains__, server_compress_algo_list) + else: + agreed_local_compression = filter(client_compress_algo_list.__contains__, self._preferred_compression) + agreed_remote_compression = filter(server_compress_algo_list.__contains__, self._preferred_compression) + if (len(agreed_local_compression) == 0) or (len(agreed_remote_compression) == 0): + raise SSHException('Incompatible ssh server (no acceptable compression) %r %r %r' % (agreed_local_compression, agreed_remote_compression, self._preferred_compression)) + self.local_compression = agreed_local_compression[0] + self.remote_compression = agreed_remote_compression[0] + + self._log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s; compression: local %s, remote %s' % (agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac, - self.remote_mac)) + self.remote_mac, self.local_compression, self.remote_compression)) # save for computing hash later... # now wait! openssh has a bug (and others might too) where there are @@ -1256,6 +1601,10 @@ else: mac_key = self._compute_key('F', mac_engine.digest_size) self.packetizer.set_inbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) + compress_in = self._compression_info[self.remote_compression][1] + if (compress_in is not None) and ((self.remote_compression != 'address@hidden') or self.authenticated): + self._log(DEBUG, 'Switching on inbound compression ...') + self.packetizer.set_inbound_compressor(compress_in()) def _activate_outbound(self): "switch on newly negotiated encryption parameters for outbound traffic" @@ -1279,11 +1628,27 @@ else: mac_key = self._compute_key('E', mac_engine.digest_size) self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) + compress_out = self._compression_info[self.local_compression][0] + if (compress_out is not None) and ((self.local_compression != 'address@hidden') or self.authenticated): + self._log(DEBUG, 'Switching on outbound compression ...') + self.packetizer.set_outbound_compressor(compress_out()) if not self.packetizer.need_rekey(): self.in_kex = False # we always expect to receive NEWKEYS now self.expected_packet = MSG_NEWKEYS + def _auth_trigger(self): + self.authenticated = True + # delayed initiation of compression + if self.local_compression == 'address@hidden': + compress_out = self._compression_info[self.local_compression][0] + self._log(DEBUG, 'Switching on outbound compression ...') + self.packetizer.set_outbound_compressor(compress_out()) + if self.remote_compression == 'address@hidden': + compress_in = self._compression_info[self.remote_compression][1] + self._log(DEBUG, 'Switching on inbound compression ...') + self.packetizer.set_inbound_compressor(compress_in()) + def _parse_newkeys(self, m): self._log(DEBUG, 'Switch to new keys ...') self._activate_inbound() @@ -1291,6 +1656,9 @@ self.local_kex_init = self.remote_kex_init = None self.K = None self.kex_engine = None + if self.server_mode and (self.auth_handler is None): + # create auth handler for server mode + self.auth_handler = AuthHandler(self) if not self.initial_kex_done: # this was the first key exchange self.initial_kex_done = True @@ -1300,7 +1668,11 @@ # it's now okay to send data again (if this was a re-key) if not self.packetizer.need_rekey(): self.in_kex = False - self.clear_to_send.set() + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.set() + finally: + self.clear_to_send_lock.release() return def _parse_disconnect(self, m): @@ -1416,6 +1788,7 @@ try: self.lock.acquire() self.channels[my_chanid] = chan + self.channels_seen[my_chanid] = True chan._set_transport(self) chan._set_window(self.window_size, self.max_packet_size) chan._set_remote_channel(chanid, initial_window_size, max_packet_size) @@ -1472,5 +1845,3 @@ MSG_CHANNEL_EOF: Channel._handle_eof, MSG_CHANNEL_CLOSE: Channel._handle_close, } - -from server import ServerInterface ============================================================ --- paramiko/util.py 368b5dc43be94fdc356a45c5aab43065771e76a4 +++ paramiko/util.py 4b7ebbb619b6ef6e46b4287d1b78d8a09b87ba99 @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2006 Robey Pointer # # This file is part of paramiko. # @@ -16,15 +16,21 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -from __future__ import generators - """ Useful functions used by the rest of paramiko. """ -import sys, struct, traceback, threading -from common import * +from __future__ import generators +import fnmatch +import sys +import struct +import traceback +import threading + +from paramiko.common import * + + # Change by RogerB - python < 2.3 doesn't have enumerate so we implement it if sys.version_info < (2,3): class enumerate: @@ -36,6 +42,7 @@ yield (count, item) count += 1 + def inflate_long(s, always_positive=False): "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" out = 0L @@ -161,12 +168,12 @@ if len(salt) > 8: salt = salt[:8] while nbytes > 0: - hash = hashclass.new() + hash_obj = hashclass.new() if len(digest) > 0: - hash.update(digest) - hash.update(key) - hash.update(salt) - digest = hash.digest() + hash_obj.update(digest) + hash_obj.update(key) + hash_obj.update(salt) + digest = hash_obj.digest() size = min(nbytes, len(digest)) keydata += digest[:size] nbytes -= size @@ -182,36 +189,98 @@ This type of file unfortunately doesn't exist on Windows, but on posix, it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}. + Since 1.5.3, this is just a wrapper around L{HostKeys}. + @param filename: name of the file to read host keys from @type filename: str @return: dict of host keys, indexed by hostname and then keytype @rtype: dict(hostname, dict(keytype, L{PKey })) """ - import base64 - from rsakey import RSAKey - from dsskey import DSSKey - - keys = {} - f = file(filename, 'r') - for line in f: - line = line.strip() - if (len(line) == 0) or (line[0] == '#'): + from paramiko.hostkeys import HostKeys + return HostKeys(filename) + +def parse_ssh_config(file_obj): + """ + Parse a config file of the format used by OpenSSH, and return an object + that can be used to make queries to L{lookup_ssh_host_config}. The + format is described in OpenSSH's C{ssh_config} man page. This method is + provided primarily as a convenience to posix users (since the OpenSSH + format is a de-facto standard on posix) but should work fine on Windows + too. + + The return value is currently a list of dictionaries, each containing + host-specific configuration, but this is considered an implementation + detail and may be subject to change in later versions. + + @param file_obj: a file-like object to read the config file from + @type file_obj: file + @return: opaque configuration object + @rtype: object + """ + ret = [] + config = { 'host': '*' } + ret.append(config) + + for line in file_obj: + line = line.rstrip('\n').lstrip() + if (line == '') or (line[0] == '#'): continue - keylist = line.split(' ') - if len(keylist) != 3: - continue - hostlist, keytype, key = keylist - hosts = hostlist.split(',') - for host in hosts: - if not keys.has_key(host): - keys[host] = {} - if keytype == 'ssh-rsa': - keys[host][keytype] = RSAKey(data=base64.decodestring(key)) - elif keytype == 'ssh-dss': - keys[host][keytype] = DSSKey(data=base64.decodestring(key)) - f.close() - return keys + if '=' in line: + key, value = line.split('=', 1) + key = key.strip().lower() + else: + # find first whitespace, and split there + i = 0 + while (i < len(line)) and not line[i].isspace(): + i += 1 + if i == len(line): + raise Exception('Unparsable line: %r' % line) + key = line[:i].lower() + value = line[i:].lstrip() + if key == 'host': + # do we have a pre-existing host config to append to? + matches = [c for c in ret if c['host'] == value] + if len(matches) > 0: + config = matches[0] + else: + config = { 'host': value } + ret.append(config) + else: + config[key] = value + + return ret + +def lookup_ssh_host_config(hostname, config): + """ + Return a dict of config options for a given hostname. The C{config} object + must come from L{parse_ssh_config}. + + The host-matching rules of OpenSSH's C{ssh_config} man page are used, which + means that all configuration options from matching host specifications are + merged, with more specific hostmasks taking precedence. In other words, if + C{"Port"} is set under C{"Host *"} and also C{"Host *.example.com"}, and + the lookup is for C{"ssh.example.com"}, then the port entry for + C{"Host *.example.com"} will win out. + + The keys in the returned dict are all normalized to lowercase (look for + C{"port"}, not C{"Port"}. No other processing is done to the keys or + values. + + @param hostname: the hostname to lookup + @type hostname: str + @param config: the config object to search + @type config: object + """ + matches = [x for x in config if fnmatch.fnmatch(hostname, x['host'])] + # sort in order of shortest match (usually '*') to longest + matches.sort(lambda x,y: cmp(len(x['host']), len(y['host']))) + ret = {} + for m in matches: + ret.update(m) + del ret['host'] + return ret + def mod_inverse(x, m): # it's crazy how small python can make this function. u1, u2, u3 = 1, 0, m @@ -226,21 +295,21 @@ u2 += m return u2 -g_thread_ids = {} -g_thread_counter = 0 -g_thread_lock = threading.Lock() +_g_thread_ids = {} +_g_thread_counter = 0 +_g_thread_lock = threading.Lock() def get_thread_id(): - global g_thread_ids, g_thread_counter + global _g_thread_ids, _g_thread_counter, _g_thread_lock tid = id(threading.currentThread()) try: - return g_thread_ids[tid] + return _g_thread_ids[tid] except KeyError: - g_thread_lock.acquire() + _g_thread_lock.acquire() try: - g_thread_counter += 1 - ret = g_thread_ids[tid] = g_thread_counter + _g_thread_counter += 1 + ret = _g_thread_ids[tid] = _g_thread_counter finally: - g_thread_lock.release() + _g_thread_lock.release() return ret def log_to_file(filename, level=DEBUG): @@ -251,15 +320,20 @@ l.setLevel(level) f = open(filename, 'w') lh = logging.StreamHandler(f) - lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s] thr=%(_threadid)-3d %(name)s: %(message)s', + lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s', '%Y%m%d-%H:%M:%S')) l.addHandler(lh) +# make only one filter object, so it doesn't get applied more than once +class PFilter (object): + def filter(self, record): + record._threadid = get_thread_id() + return True +_pfilter = PFilter() + def get_logger(name): l = logging.getLogger(name) - class PFilter (object): - def filter(self, record): - record._threadid = get_thread_id() - return True - l.addFilter(PFilter()) + l.addFilter(_pfilter) return l + +