From e620b9e8d52766db10618a125ad611af14e04a09 Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 00:13:31 +0200 Subject: [PATCH 1/9] Add pool of SFTP sessions --- .gitignore | 1 + fs/sshfs/pool.py | 96 ++++++++++++++++++++++++++++++++++++++++++ fs/sshfs/sshfs.py | 91 +++++++++++++++++++++++---------------- tests/requirements.txt | 2 +- tests/test_sshfs.py | 43 +++++++++++-------- tests/utils.py | 4 +- 6 files changed, 180 insertions(+), 57 deletions(-) create mode 100644 fs/sshfs/pool.py diff --git a/.gitignore b/.gitignore index bd59574..65f0105 100644 --- a/.gitignore +++ b/.gitignore @@ -79,6 +79,7 @@ celerybeat-schedule .env # virtualenv +.venv/ venv/ ENV/ diff --git a/fs/sshfs/pool.py b/fs/sshfs/pool.py new file mode 100644 index 0000000..b5768f5 --- /dev/null +++ b/fs/sshfs/pool.py @@ -0,0 +1,96 @@ + +import contextlib +from queue import Queue +from threading import RLock, local +from paramiko import SSHClient, SFTPClient + +_local = local() + +class ConnectionPool(object): + + def __init__(self, open_func, close_func, healty_func, max_connections=4, timeout=None): + """ + A generic connection pool. + + :param open_func: Function that opens a new connection. + :param close_func: Function that closes a given connection. + :param healty_func: Function to check if an active connection is still healty. + :param max_connections: Maximum number of open connections + :param timeout: Maximum time to wait for a connection + """ + + self._open_func = open_func + self._close_func = close_func + self._healthy = healty_func + self.timeout = timeout + self._lock = RLock() + # self.unused_timeout = unused_timeout :param unused_timeout: Time after which an unused connection should be closed + + self.active = 0 + self._q = Queue(maxsize=max_connections) + + @contextlib.contextmanager + def connection(self): + if getattr(_local, "conn", None) is None: + try: + _local.conn = self.acquire() + yield _local.conn + finally: + self.release(_local.conn) + _local.conn = None + else: + yield _local.conn + + + def _open(self): + self.active += 1 + return self._open_func() + + def _close(self, conn): + self.active -= 1 + self._close_func(conn) + + def acquire(self): + while True: + with self._lock: + if self._q.empty() and self.active < self._q.maxsize: + return self._open() + + conn = self._q.get(timeout=self.timeout) + if not self._healthy(conn): + self._close() + continue + + return conn + + def release(self, conn): + self._q.put(conn, block=False) + +class SFTPClientPool(ConnectionPool): + + def __init__(self, client:SSHClient, max_connections:int=4, timeout=None): + + def open_sftp(): + return client.open_sftp() + + def close_sftp(sftp:SFTPClient): + sftp.close() + + def is_healthy(sftp:SFTPClient) -> bool: + return not sftp.get_channel().closed + + super().__init__(open_sftp, close_sftp, is_healthy, max_connections, timeout) + + def acquire(self) -> SFTPClient: + conn: SFTPClient = super().acquire() + channel = conn.get_channel() + transport = channel.get_transport() + + if channel.closed: + raise Exception("Channel is closed") + + if not transport.is_active(): + raise Exception("Transport is closed") + + return conn + diff --git a/fs/sshfs/sshfs.py b/fs/sshfs/sshfs.py index aaa4cda..8eb4bfd 100644 --- a/fs/sshfs/sshfs.py +++ b/fs/sshfs/sshfs.py @@ -25,6 +25,7 @@ from .file import SSHFile from .error_tools import convert_sshfs_errors +from .pool import SFTPClientPool class SSHFS(FS): @@ -145,7 +146,7 @@ def __init__( if keepalive > 0: client.get_transport().set_keepalive(keepalive) - self._sftp = client.open_sftp() + self._pool = SFTPClientPool(client) except (paramiko.ssh_exception.SSHException, # protocol errors paramiko.ssh_exception.NoValidConnectionsError, # connexion errors @@ -166,17 +167,21 @@ def close(self): # noqa: D102 self._client.close() super().close() + + def _sftp(self): + return self._pool.connection() + def getinfo(self, path, namespaces=None): # noqa: D102 self.check() namespaces = namespaces or () _path = self.validatepath(path) - with convert_sshfs_errors('getinfo', path): - _stat = self._sftp.stat(_path) + with convert_sshfs_errors('getinfo', path), self._sftp() as sftp: + _stat = sftp.stat(_path) info = self._make_raw_info(basename(_path), _stat, namespaces) if "lstat" in namespaces or "link" in namespaces: - _lstat = self._sftp.lstat(_path) + _lstat = sftp.lstat(_path) if "lstat" in namespaces: info["lstat"] = { k: getattr(_lstat, k) @@ -185,7 +190,7 @@ def getinfo(self, path, namespaces=None): # noqa: D102 } if "link" in namespaces: if OSFS._get_type_from_stat(_lstat) == ResourceType.symlink: - target = self._sftp.readlink(_path) + target = sftp.readlink(_path) info["link"] = {"target": target} else: info["link"] = {"target": None} @@ -210,8 +215,8 @@ def listdir(self, path): # noqa: D102 if _type is not ResourceType.directory: raise errors.DirectoryExpected(path) - with convert_sshfs_errors('listdir', path): - return self._sftp.listdir(_path) + with convert_sshfs_errors('listdir', path), self._sftp() as sftp: + return sftp.listdir(_path) def scandir(self, path, namespaces=None, page=None): # noqa: D102 self.check() @@ -219,11 +224,11 @@ def scandir(self, path, namespaces=None, page=None): # noqa: D102 _namespaces = namespaces or () start, stop = page or (None, None) try: - with convert_sshfs_errors('scandir', path, directory=True): + with convert_sshfs_errors('scandir', path, directory=True), self._sftp() as sftp: # We can't use listdir_iter here because it doesn't support # concurrent iteration over multiple directories, which can # happen during a search="depth" walk. - listing = self._sftp.listdir_attr(_path) + listing = sftp.listdir_attr(_path) for _stat in itertools.islice(listing, start, stop): yield Info(self._make_raw_info(_stat.filename, _stat, _namespaces)) except errors.ResourceNotFound: @@ -244,8 +249,8 @@ def makedir(self, path, permissions=None, recreate=False): # noqa: D102 info = self.getinfo(_path) except errors.ResourceNotFound: with self._lock: - with convert_sshfs_errors('makedir', path): - self._sftp.mkdir(_path, _permissions.mode) + with convert_sshfs_errors('makedir', path), self._sftp() as sftp: + sftp.mkdir(_path, _permissions.mode) else: if (info.is_dir and not recreate) or info.is_file: six.raise_from(errors.DirectoryExists(path), None) @@ -259,7 +264,7 @@ def move(self, src_path, dst_path, overwrite=False, preserve_time=False): _src_path = self.validatepath(src_path) _dst_path = self.validatepath(dst_path) - with self._lock: + with self._lock, self._sftp() as sftp: # check src exists and is a file src_info = self.getinfo(_src_path, namespaces=info_ns) if src_info.is_dir: @@ -274,9 +279,9 @@ def move(self, src_path, dst_path, overwrite=False, preserve_time=False): if not overwrite: raise errors.DestinationExists(dst_path) with convert_sshfs_errors('move', dst_path): - self._sftp.remove(_dst_path) + sftp.remove(_dst_path) # rename the file through SFTP's 'RENAME' - self._sftp.rename(_src_path, _dst_path) + sftp.rename(_src_path, _dst_path) # preserve times if required if preserve_time: self._utime( @@ -329,12 +334,21 @@ def openbin(self, path, mode='r', buffering=-1, **options): # noqa: D102 elif self.isdir(_path): raise errors.FileExpected(path) with convert_sshfs_errors('openbin', path): - _sftp = self._client.open_sftp() - handle = _sftp.open( + sftp = self._pool.acquire() + handle = sftp.open( _path, mode=_mode.to_platform_bin(), bufsize=buffering ) + + # release sftp client on close + def wrap_close(fn): + def close(): + fn() + self._pool.release(sftp) + return close + handle.close = wrap_close(handle.close) + handle.set_pipelined(options.get("pipelined", True)) if options.get("prefetch", True): if _mode.reading and not _mode.writing: @@ -350,9 +364,9 @@ def remove(self, path): # noqa: D102 if self.getinfo(_path).is_dir: raise errors.FileExpected(path) - with convert_sshfs_errors('remove', path): + with convert_sshfs_errors('remove', path), self._sftp() as sftp: with self._lock: - self._sftp.remove(_path) + sftp.remove(_path) def removedir(self, path): # noqa: D102 self.check() @@ -364,9 +378,8 @@ def removedir(self, path): # noqa: D102 if not self.isempty(_path): raise errors.DirectoryNotEmpty(path) - with convert_sshfs_errors('removedir', path): - with self._lock: - self._sftp.rmdir(_path) + with convert_sshfs_errors('removedir', path), self._lock, self._sftp() as sftp: + sftp.rmdir(_path) def setinfo(self, path, info): # noqa: D102 self.check() @@ -423,8 +436,8 @@ def download(self, path, file, chunk_size=None, callback=None, **options): raise errors.ResourceNotFound(path) elif self.isdir(_path): raise errors.FileExpected(path) - with convert_sshfs_errors('download', path): - self._sftp.getfo(_path, file, callback=callback) + with convert_sshfs_errors('download', path), self._sftp() as sftp: + sftp.getfo(_path, file, callback=callback) def upload(self, path, file, chunk_size=None, callback=None, file_size=None, confirm=True, **options): """Set a file to the contents of a binary file object. @@ -466,8 +479,8 @@ def upload(self, path, file, chunk_size=None, callback=None, file_size=None, con raise errors.ResourceNotFound(path) elif self.isdir(_path): raise errors.FileExpected(path) - with convert_sshfs_errors('upload', path): - self._sftp.putfo( + with convert_sshfs_errors('upload', path), self._sftp() as sftp: + sftp.putfo( file, _path, file_size=file_size, @@ -580,23 +593,27 @@ def entry_name(db, _id): def _chmod(self, path, mode): """Change the *mode* of a resource. """ - self._sftp.chmod(path, mode) + with self._sftp() as sftp: + sftp.chmod(path, mode) def _chown(self, path, uid, gid): """Change the *owner* of a resource. """ - if uid is None or gid is None: - info = self.getinfo(path, namespaces=('access',)) - uid = uid or info.get('access', 'uid') - gid = gid or info.get('access', 'gid') - self._sftp.chown(path, uid, gid) + with self._sftp() as sftp: + if uid is None or gid is None: + info = self.getinfo(path, namespaces=('access',)) + uid = uid or info.get('access', 'uid') + gid = gid or info.get('access', 'gid') + + sftp.chown(path, uid, gid) def _utime(self, path, modified, accessed): """Set the *accessed* and *modified* times of a resource. """ - if accessed is not None or modified is not None: - accessed = float(modified if accessed is None else accessed) - modified = float(accessed if modified is None else modified) - self._sftp.utime(path, (accessed, modified)) - else: - self._sftp.utime(path, None) + with self._sftp() as sftp: + if accessed is not None or modified is not None: + accessed = float(modified if accessed is None else accessed) + modified = float(accessed if modified is None else modified) + sftp.utime(path, (accessed, modified)) + else: + sftp.utime(path, None) diff --git a/tests/requirements.txt b/tests/requirements.txt index 3053826..02c0e1a 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,5 +1,5 @@ codecov ~=2.1 -docker ~=4.4 +docker ~=7.1.0 urllib3 <2 semantic-version ~=2.6 mock ~=3.0.5 ; python_version < '3.3' diff --git a/tests/test_sshfs.py b/tests/test_sshfs.py index 77d48c2..01b89e5 100644 --- a/tests/test_sshfs.py +++ b/tests/test_sshfs.py @@ -45,7 +45,7 @@ def destroy_fs(fs): def make_fs(self): self.ssh_fs = SSHFS('localhost', self.user, self.pasw, port=self.port) - self.test_folder = fs.path.join('/home', self.user, uuid.uuid4().hex) + self.test_folder = fs.path.join('/config', uuid.uuid4().hex) self.ssh_fs.makedir(self.test_folder, recreate=True) return self.ssh_fs.opendir(self.test_folder, factory=ClosingSubFS) @@ -81,6 +81,9 @@ def test_upload_2(self): def test_upload_4(self): super(TestSSHFS, self).test_upload_4() + def _sftp(self): + return self.fs.delegate_fs()._sftp() + def test_chmod(self): self.fs.touch("test.txt") remote_path = fs.path.join(self.test_folder, "test.txt") @@ -88,14 +91,16 @@ def test_chmod(self): # Initial permissions info = self.fs.getinfo("test.txt", ["access"]) self.assertEqual(info.permissions.mode, 0o644) - st = self.fs.delegate_fs()._sftp.stat(remote_path) + with self._sftp() as sftp: + st = sftp.stat(remote_path) self.assertEqual(stat.S_IMODE(st.st_mode), 0o644) # Change permissions with SSHFS._chown self.fs.delegate_fs()._chmod(remote_path, 0o744) info = self.fs.getinfo("test.txt", ["access"]) self.assertEqual(info.permissions.mode, 0o744) - st = self.fs.delegate_fs()._sftp.stat(remote_path) + with self._sftp() as sftp: + st = sftp.stat(remote_path) self.assertEqual(stat.S_IMODE(st.st_mode), 0o744) # Change permissions with SSHFS.setinfo @@ -103,7 +108,8 @@ def test_chmod(self): {"access": {"permissions": Permissions(mode=0o600)}}) info = self.fs.getinfo("test.txt", ["access"]) self.assertEqual(info.permissions.mode, 0o600) - st = self.fs.delegate_fs()._sftp.stat(remote_path) + with self._sftp() as sftp: + st = sftp.stat(remote_path) self.assertEqual(stat.S_IMODE(st.st_mode), 0o600) with self.assertRaises(fs.errors.PermissionDenied): @@ -111,6 +117,7 @@ def test_chmod(self): "access": {"permissions": Permissions(mode=0o777)} }) + def test_chown(self): self.fs.touch("test.txt") @@ -118,18 +125,19 @@ def test_chown(self): info = self.fs.getinfo("test.txt", namespaces=["access"]) gid, uid = info.get('access', 'uid'), info.get('access', 'gid') - with utils.mock.patch.object(self.fs.delegate_fs()._sftp, 'chown') as chown: - self.fs.setinfo("test.txt", {'access': {'uid': None}}) - chown.assert_called_with(remote_path, uid, gid) + with self._sftp() as sftp: + with utils.mock.patch.object(sftp, 'chown') as chown: + self.fs.setinfo("test.txt", {'access': {'uid': None}}) + chown.assert_called_with(remote_path, uid, gid) - self.fs.setinfo("test.txt", {'access': {'gid': None}}) - chown.assert_called_with(remote_path, uid, gid) + self.fs.setinfo("test.txt", {'access': {'gid': None}}) + chown.assert_called_with(remote_path, uid, gid) - self.fs.setinfo("test.txt", {'access': {'gid': 8000}}) - chown.assert_called_with(remote_path, uid, 8000) + self.fs.setinfo("test.txt", {'access': {'gid': 8000}}) + chown.assert_called_with(remote_path, uid, 8000) - self.fs.setinfo("test.txt", {'access': {'uid': 1001, 'gid':1002}}) - chown.assert_called_with(remote_path, 1001, 1002) + self.fs.setinfo("test.txt", {'access': {'uid': 1001, 'gid':1002}}) + chown.assert_called_with(remote_path, 1001, 1002) def test_exec_command_exception(self): ssh = self.fs.delegate_fs() @@ -175,10 +183,11 @@ def test_symlinks(self): with self.fs.openbin("foo", "wb") as f: f.write(b"foobar") - self.fs.delegate_fs()._sftp.symlink( - fs.path.join(self.test_folder, "foo"), - fs.path.join(self.test_folder, "bar") - ) + with self._sftp() as sftp: + sftp.symlink( + fs.path.join(self.test_folder, "foo"), + fs.path.join(self.test_folder, "bar") + ) # os.symlink(self._get_real_path("foo"), self._get_real_path("bar")) self.assertFalse(self.fs.islink("foo")) diff --git a/tests/utils.py b/tests/utils.py index 10d0f13..197c183 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,8 +29,8 @@ def fs_version(): def startServer(docker_client, user, pasw, port): sftp_container = docker_client.containers.run( - "sjourdan/alpine-sshd", detach=True, ports={'22/tcp': port}, - environment={'USER': user, 'PASSWORD': pasw}, + "lscr.io/linuxserver/openssh-server", detach=True, ports={'2222/tcp': port}, + environment={'USER_NAME': user, 'USER_PASSWORD': pasw, "PASSWORD_ACCESS": "true"}, ) time.sleep(1) return sftp_container From e63cbf9bbe7e356902dec9deacec090b8b7fcc6e Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 11:06:30 +0200 Subject: [PATCH 2/9] Simplify pool --- fs/sshfs/pool.py | 77 ++++++++++++++++++------------------------------ 1 file changed, 28 insertions(+), 49 deletions(-) diff --git a/fs/sshfs/pool.py b/fs/sshfs/pool.py index b5768f5..6594cec 100644 --- a/fs/sshfs/pool.py +++ b/fs/sshfs/pool.py @@ -1,36 +1,36 @@ import contextlib from queue import Queue -from threading import RLock, local +from threading import local from paramiko import SSHClient, SFTPClient _local = local() class ConnectionPool(object): - def __init__(self, open_func, close_func, healty_func, max_connections=4, timeout=None): + def __init__(self, open_func, max_connections=4, timeout=None): """ A generic connection pool. :param open_func: Function that opens a new connection. - :param close_func: Function that closes a given connection. - :param healty_func: Function to check if an active connection is still healty. :param max_connections: Maximum number of open connections :param timeout: Maximum time to wait for a connection """ self._open_func = open_func - self._close_func = close_func - self._healthy = healty_func self.timeout = timeout - self._lock = RLock() - # self.unused_timeout = unused_timeout :param unused_timeout: Time after which an unused connection should be closed - self.active = 0 self._q = Queue(maxsize=max_connections) + while not self._q.full(): + self._q.put(None) + @contextlib.contextmanager def connection(self): + ''' + Returns a ContextManager with an open connection. + This function returns the same connection when called recursively. + ''' if getattr(_local, "conn", None) is None: try: _local.conn = self.acquire() @@ -41,27 +41,9 @@ def connection(self): else: yield _local.conn - - def _open(self): - self.active += 1 - return self._open_func() - - def _close(self, conn): - self.active -= 1 - self._close_func(conn) - def acquire(self): - while True: - with self._lock: - if self._q.empty() and self.active < self._q.maxsize: - return self._open() - - conn = self._q.get(timeout=self.timeout) - if not self._healthy(conn): - self._close() - continue - - return conn + conn = self._q.get(timeout=self.timeout) + return self._open_func(conn) def release(self, conn): self._q.put(conn, block=False) @@ -69,28 +51,25 @@ def release(self, conn): class SFTPClientPool(ConnectionPool): def __init__(self, client:SSHClient, max_connections:int=4, timeout=None): - - def open_sftp(): - return client.open_sftp() - - def close_sftp(sftp:SFTPClient): - sftp.close() - - def is_healthy(sftp:SFTPClient) -> bool: - return not sftp.get_channel().closed - - super().__init__(open_sftp, close_sftp, is_healthy, max_connections, timeout) + """ + A pool of SFTPClient sessions - def acquire(self) -> SFTPClient: - conn: SFTPClient = super().acquire() - channel = conn.get_channel() - transport = channel.get_transport() + :param max_connections: Maximum number of open sessions + :param timeout: Maximum time to wait for a session + """ - if channel.closed: - raise Exception("Channel is closed") + def open_sftp(conn: SFTPClient | None): + if conn is None or conn.get_channel().closed: + return client.open_sftp() + return conn - if not transport.is_active(): - raise Exception("Transport is closed") + super().__init__(open_sftp, max_connections, timeout) - return conn + def acquire(self) -> SFTPClient: + """ + Acquire an SFTPClient. + If timeout is None this functions blocks until a free session is available. + Otherwise an Empty exception is raised. + """ + return super().acquire() From 5589292e74b9e3ccfa15b7eb153e6495c64a3cbf Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 11:07:16 +0200 Subject: [PATCH 3/9] Add connection pool parameters to SSHFS --- fs/sshfs/sshfs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fs/sshfs/sshfs.py b/fs/sshfs/sshfs.py index 8eb4bfd..87d8555 100644 --- a/fs/sshfs/sshfs.py +++ b/fs/sshfs/sshfs.py @@ -59,6 +59,10 @@ class SSHFS(FS): policy (paramiko.MissingHostKeyPolicy): The policy to use to resolve missing host keys. Defaults to ``None``, which will create a `paramiko.AutoAddPolicy` instance. + max_connections (int): Maximum number of concurrent SFTPClient sessions + (defaults to 4) + conn_timeout (int): The maximum time to wait for a free session. + Defaults to the value of ``timeout``. Raises: fs.errors.CreateFailed: when the filesystem could not be created. The @@ -105,6 +109,8 @@ def __init__( config_path='~/.ssh/config', exec_timeout=None, policy=None, + max_connections=4, + conn_timeout=None, **kwargs ): # noqa: D102 super(SSHFS, self).__init__() @@ -122,6 +128,7 @@ def __init__( self._client = client = paramiko.SSHClient() self._timeout = timeout self._exec_timeout = timeout if exec_timeout is None else exec_timeout + self._conn_timeout = timeout if conn_timeout is None else conn_timeout _policy = paramiko.AutoAddPolicy() if policy is None else policy @@ -146,7 +153,7 @@ def __init__( if keepalive > 0: client.get_transport().set_keepalive(keepalive) - self._pool = SFTPClientPool(client) + self._pool = SFTPClientPool(client, max_connections, timeout=self._conn_timeout) except (paramiko.ssh_exception.SSHException, # protocol errors paramiko.ssh_exception.NoValidConnectionsError, # connexion errors From 0e1754f007efd4973ffeac36341003e59e43c41d Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 11:07:28 +0200 Subject: [PATCH 4/9] Update tests --- tests/test_opener.py | 22 ++++++++++++---------- tests/test_sshfs.py | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/tests/test_opener.py b/tests/test_opener.py index ad516df..03568e5 100644 --- a/tests/test_opener.py +++ b/tests/test_opener.py @@ -45,13 +45,14 @@ def tearDownClass(cls): @classmethod def addKeyToServer(cls, pkey): + home = '/config' with paramiko.SSHClient() as client: client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) client.connect('localhost', cls.port, cls.user, cls.pasw) with client.open_sftp() as sftp: - if not '.ssh' in sftp.listdir('/home/{}'.format(cls.user)): - sftp.mkdir('/home/{}/.ssh'.format(cls.user)) - with sftp.open('/home/{}/.ssh/authorized_keys'.format(cls.user), 'w') as f: + if '.ssh' not in sftp.listdir(home): + sftp.mkdir('{}/.ssh'.format(home)) + with sftp.open('{}/.ssh/authorized_keys'.format(home), 'w') as f: f.write("{} {}\n".format( pkey.get_name(), pkey.get_base64()).encode('utf-8')) @@ -72,7 +73,7 @@ def tearDown(self): os.remove(self.config_file) def assertFunctional(self, ssh_fs): - test_folder = '/home/{}/{}'.format(self.user, uuid.uuid4().hex) + test_folder = '/config/{}'.format(uuid.uuid4().hex) ssh_fs.makedir(test_folder) with ssh_fs.opendir(test_folder) as test_fs: @@ -141,7 +142,7 @@ def test_sshconfig_notfound(self): def test_create(self): - directory = fs.path.join("home", self.user, "test", "directory") + directory = fs.path.join("config", "test", "directory") base = "ssh://{}:{}@localhost:{}".format(self.user, self.pasw, self.port) url = "{}/{}".format(base, directory) @@ -167,20 +168,21 @@ def test_create(self): self.assertTrue(ssh_fs.isfile("foo")) def test_open_symlink(self): + home = '/config' # create a symlink in the home directory with paramiko.SSHClient() as client: client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) client.connect('localhost', self.port, self.user, self.pasw) with client.open_sftp() as sftp: - sftp.mkdir('/home/{}/directory'.format(self.user)) + sftp.mkdir('{}/directory'.format(home)) sftp.symlink( - '/home/{}/directory'.format(self.user), - '/home/{}/link'.format(self.user) + '{}/directory'.format(home), + '{}/link'.format(home) ) - sftp.open('/home/{}/directory/test'.format(self.user), 'w').close() + sftp.open('{}/directory/test'.format(home), 'w').close() # check the symlink can be opened - directory = fs.path.join("home", self.user, "link") + directory = fs.path.join(home, "link") base = "ssh://{}:{}@localhost:{}".format(self.user, self.pasw, self.port) url = "{}/{}".format(base, directory) with fs.open_fs(url) as ssh_fs: diff --git a/tests/test_sshfs.py b/tests/test_sshfs.py index 01b89e5..38dc80d 100644 --- a/tests/test_sshfs.py +++ b/tests/test_sshfs.py @@ -2,11 +2,13 @@ from __future__ import absolute_import from __future__ import unicode_literals +import concurrent.futures import stat import sys import time import uuid import unittest +from concurrent.futures import ThreadPoolExecutor import paramiko.ssh_exception @@ -211,3 +213,18 @@ def test_setinfo(self): now = int(time.time()) with utils.mock.patch("time.time", lambda: now): super(TestSSHFS, self).test_setinfo() + + def test_thread_safty(self): + text = "Thread Safty Test" + self.fs.writetext("thread_safty.txt", text) + + def getinfo(): + return self.fs.getinfo("thread_safty.txt", namespaces=["basic", "details"]) + + info = getinfo() + self.assertEqual(len(text), info.size) + + with ThreadPoolExecutor(10) as e: + futures = [e.submit(getinfo) for _ in range(100)] + for f in futures: + self.assertEqual(info.size, f.result().size) From db5df08a981daff78b58b4ebc254e81a175f641f Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 11:11:12 +0200 Subject: [PATCH 5/9] test workflow: use linuxserver/openssh-server --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index abb58c8..a5d8d6b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: #- pypy-3.7 steps: - name: Checkout SSH docker container - run: docker pull sjourdan/alpine-sshd + run: docker pull linuxserver/openssh-server - name: Checkout code uses: actions/checkout@v3 - name: Setup Python ${{ matrix.python-version }} From 19e59f1b78576921d4b1ebf32061da3d7f3c6978 Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 11:20:53 +0200 Subject: [PATCH 6/9] Fix test dependencies --- tests/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 02c0e1a..883df9b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,5 +1,6 @@ codecov ~=2.1 -docker ~=7.1.0 +docker ~=6.1.3 +requests ==2.31.0 # https://github.com/docker/docker-py/issues/3256 urllib3 <2 semantic-version ~=2.6 mock ~=3.0.5 ; python_version < '3.3' From e114a2027634499c6c92187c9b430dd52b5d24f6 Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 11:43:39 +0200 Subject: [PATCH 7/9] Update docstrings and type hints --- fs/sshfs/pool.py | 38 ++++++++++++++++++++------------------ fs/sshfs/sshfs.py | 2 +- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/fs/sshfs/pool.py b/fs/sshfs/pool.py index 6594cec..c1835e2 100644 --- a/fs/sshfs/pool.py +++ b/fs/sshfs/pool.py @@ -7,16 +7,15 @@ _local = local() class ConnectionPool(object): + """A generic connection pool. - def __init__(self, open_func, max_connections=4, timeout=None): - """ - A generic connection pool. - - :param open_func: Function that opens a new connection. - :param max_connections: Maximum number of open connections - :param timeout: Maximum time to wait for a connection - """ + Arguments: + open_func (Callable): Function that opens a new connection. + max_connections (int): Maximum number of open connections + timeout (float): Maximum time to wait for a connection + """ + def __init__(self, open_func, max_connections=4, timeout=None): self._open_func = open_func self.timeout = timeout @@ -49,16 +48,18 @@ def release(self, conn): self._q.put(conn, block=False) class SFTPClientPool(ConnectionPool): - - def __init__(self, client:SSHClient, max_connections:int=4, timeout=None): - """ - A pool of SFTPClient sessions - - :param max_connections: Maximum number of open sessions - :param timeout: Maximum time to wait for a session - """ - - def open_sftp(conn: SFTPClient | None): + """A pool of SFTPClient sessions + + Arguments: + client (SSHClient): ssh client + max_connections (int): Maximum number of open sessions + timeout (float): Maximum time to wait for a session + """ + + def __init__(self, client, max_connections=4, timeout=None): + # type: (SSHClient, int, float | None) -> SFTPClientPool + def open_sftp(conn): + # type: (SFTPClient | None) -> SFTPClient if conn is None or conn.get_channel().closed: return client.open_sftp() return conn @@ -66,6 +67,7 @@ def open_sftp(conn: SFTPClient | None): super().__init__(open_sftp, max_connections, timeout) def acquire(self) -> SFTPClient: + # type: () -> SFTPClient """ Acquire an SFTPClient. If timeout is None this functions blocks until a free session is available. diff --git a/fs/sshfs/sshfs.py b/fs/sshfs/sshfs.py index 87d8555..a9c2d83 100644 --- a/fs/sshfs/sshfs.py +++ b/fs/sshfs/sshfs.py @@ -61,7 +61,7 @@ class SSHFS(FS): `paramiko.AutoAddPolicy` instance. max_connections (int): Maximum number of concurrent SFTPClient sessions (defaults to 4) - conn_timeout (int): The maximum time to wait for a free session. + conn_timeout (float): The maximum time to wait for a free session. Defaults to the value of ``timeout``. Raises: From c736739efff4d9ca09399f71307ed7515e68e4ca Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Tue, 24 Sep 2024 21:20:45 +0200 Subject: [PATCH 8/9] Fix ConnectionPool --- fs/sshfs/pool.py | 13 ++++++++++--- fs/sshfs/sshfs.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/fs/sshfs/pool.py b/fs/sshfs/pool.py index c1835e2..6f81043 100644 --- a/fs/sshfs/pool.py +++ b/fs/sshfs/pool.py @@ -30,19 +30,26 @@ def connection(self): Returns a ContextManager with an open connection. This function returns the same connection when called recursively. ''' - if getattr(_local, "conn", None) is None: + if not hasattr(_local, "conn"): try: _local.conn = self.acquire() yield _local.conn finally: + if not hasattr(_local, "conn"): + # nothing has been acquired + return self.release(_local.conn) - _local.conn = None + del _local.conn else: yield _local.conn def acquire(self): conn = self._q.get(timeout=self.timeout) - return self._open_func(conn) + try: + return self._open_func(conn) + except Exception: + self.release(conn) + raise def release(self, conn): self._q.put(conn, block=False) diff --git a/fs/sshfs/sshfs.py b/fs/sshfs/sshfs.py index a9c2d83..dda032c 100644 --- a/fs/sshfs/sshfs.py +++ b/fs/sshfs/sshfs.py @@ -292,7 +292,7 @@ def move(self, src_path, dst_path, overwrite=False, preserve_time=False): # preserve times if required if preserve_time: self._utime( - _path, + _dst_path, src_info.raw["details"]["modified"], src_info.raw["details"]["accessed"], ) From 7606808877bd0cc32f8733e4fe95fd01f9e5a238 Mon Sep 17 00:00:00 2001 From: Andreas Bichinger Date: Sat, 23 Aug 2025 17:56:23 +0200 Subject: [PATCH 9/9] Reconnect inactive SSH session --- fs/sshfs/pool.py | 13 +++++++++++-- fs/sshfs/sshfs.py | 11 +++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/fs/sshfs/pool.py b/fs/sshfs/pool.py index 6f81043..39e49f8 100644 --- a/fs/sshfs/pool.py +++ b/fs/sshfs/pool.py @@ -3,6 +3,7 @@ from queue import Queue from threading import local from paramiko import SSHClient, SFTPClient +import traceback _local = local() @@ -34,6 +35,9 @@ def connection(self): try: _local.conn = self.acquire() yield _local.conn + except Exception: + print(traceback.format_exc()) + raise finally: if not hasattr(_local, "conn"): # nothing has been acquired @@ -59,14 +63,19 @@ class SFTPClientPool(ConnectionPool): Arguments: client (SSHClient): ssh client + argdict (dict): client.connect parameters max_connections (int): Maximum number of open sessions timeout (float): Maximum time to wait for a session """ - def __init__(self, client, max_connections=4, timeout=None): - # type: (SSHClient, int, float | None) -> SFTPClientPool + def __init__(self, client, argdict, max_connections=4, timeout=None): + # type: (SSHClient, dict, int, float | None) -> SFTPClientPool def open_sftp(conn): # type: (SFTPClient | None) -> SFTPClient + transport = client.get_transport() + if not transport.is_active(): + client.connect(transport.hostname, **argdict) + if conn is None or conn.get_channel().closed: return client.open_sftp() return conn diff --git a/fs/sshfs/sshfs.py b/fs/sshfs/sshfs.py index dda032c..b87555b 100644 --- a/fs/sshfs/sshfs.py +++ b/fs/sshfs/sshfs.py @@ -137,23 +137,22 @@ def __init__( client.load_system_host_keys() client.set_missing_host_key_policy(_policy) argdict = { + "port": port, + "username": user, + "password": passwd, "pkey": pkey, "key_filename": keyfile, "look_for_keys": True if (pkey and keyfile) is None else False, "compress": compress, "timeout": timeout } - argdict.update(kwargs) - client.connect( - socket.gethostbyname(host), port, user, passwd, - **argdict - ) + client.connect(socket.gethostbyname(host), **argdict) if keepalive > 0: client.get_transport().set_keepalive(keepalive) - self._pool = SFTPClientPool(client, max_connections, timeout=self._conn_timeout) + self._pool = SFTPClientPool(client, argdict, max_connections, timeout=self._conn_timeout) except (paramiko.ssh_exception.SSHException, # protocol errors paramiko.ssh_exception.NoValidConnectionsError, # connexion errors