"""Connection pooling helpers used by SSHFS to share SFTP sessions."""

import contextlib
from queue import Queue
from threading import local
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only referenced by type comments / annotations, so paramiko does not
    # have to be imported at runtime just to load this module.
    from paramiko import SSHClient, SFTPClient


class ConnectionPool(object):
    """A generic connection pool backed by a bounded ``queue.Queue``.

    The pool owns ``max_connections`` slots.  Every slot starts out empty
    (``None``); whatever is stored in a slot is passed to ``open_func`` when
    the slot is acquired, so connections are only opened on first use.

    Arguments:
        open_func (Callable): Function that opens a new connection. It is
            called with the content of the acquired slot (``None`` or a
            previously released connection) and returns the connection to
            hand out.
        max_connections (int): Maximum number of open connections
        timeout (float): Maximum time to wait for a connection
            (``None`` blocks until one is free).

    Raises:
        ValueError: if ``max_connections`` is smaller than 1.
    """

    def __init__(self, open_func, max_connections=4, timeout=None):
        if max_connections < 1:
            # Queue(maxsize=0) is unbounded: full() would never become true
            # and the fill loop below would never terminate.
            raise ValueError("max_connections must be at least 1")

        self._open_func = open_func
        self.timeout = timeout

        # Per-pool, per-thread record of the connection currently handed out
        # by connection(); keeping it on the instance prevents two pools used
        # by the same thread from seeing each other's connection.
        self._local = local()

        self._q = Queue(maxsize=max_connections)
        while not self._q.full():
            self._q.put(None)

    @contextlib.contextmanager
    def connection(self):
        """Return a context manager yielding an open connection.

        The connection is released back to the pool when the outermost
        ``with`` block of the calling thread exits.  Calling this method
        again on the same thread while a connection is held yields that
        same connection instead of acquiring another slot.

        Errors raised by `acquire` (e.g. ``queue.Empty`` on timeout) and
        errors raised inside the ``with`` body propagate unchanged.
        """
        if hasattr(self._local, "conn"):
            # Re-entrant use on this thread: the outer block owns the release.
            yield self._local.conn
            return

        # Acquire before entering the try block: if this raises there is
        # nothing to release and the original exception reaches the caller.
        conn = self.acquire()
        self._local.conn = conn
        try:
            yield conn
        finally:
            del self._local.conn
            self.release(conn)

    def acquire(self):
        """Take a slot from the pool and return the connection for it.

        Blocks for at most ``timeout`` seconds; ``queue.Empty`` is raised
        when no slot becomes free in time.  If ``open_func`` fails, the slot
        is put back before the error is re-raised so the pool does not
        shrink.
        """
        conn = self._q.get(timeout=self.timeout)
        try:
            return self._open_func(conn)
        except Exception:
            self.release(conn)
            raise

    def release(self, conn):
        """Return ``conn`` to the pool without blocking.

        Raises ``queue.Full`` if more connections are released than were
        acquired.
        """
        self._q.put(conn, block=False)


class SFTPClientPool(ConnectionPool):
    """A pool of SFTPClient sessions

    Arguments:
        client (SSHClient): ssh client
        argdict (dict): client.connect parameters
        max_connections (int): Maximum number of open sessions
        timeout (float): Maximum time to wait for a session
    """

    def __init__(self, client, argdict, max_connections=4, timeout=None):
        # type: (SSHClient, dict, int, float | None) -> None
        def open_sftp(conn):
            # type: (SFTPClient | None) -> SFTPClient
            transport = client.get_transport()
            if not transport.is_active():
                # NOTE(review): relies on ``transport.hostname`` being set and
                # on ``get_transport()`` not returning None - confirm against
                # the paramiko version in use.
                client.connect(transport.hostname, **argdict)

            # Empty slot, or a session whose channel has gone away:
            # open a fresh session on the (re)connected client.
            if conn is None or conn.get_channel().closed:
                return client.open_sftp()
            return conn

        super().__init__(open_sftp, max_connections, timeout)

    def acquire(self) -> "SFTPClient":
        """
        Acquire an SFTPClient.
        If timeout is None this function blocks until a free session is
        available. Otherwise ``queue.Empty`` is raised once the timeout
        expires.
        """
        return super().acquire()
Defaults to ``None``, which will create a `paramiko.AutoAddPolicy` instance. + max_connections (int): Maximum number of concurrent SFTPClient sessions + (defaults to 4) + conn_timeout (float): The maximum time to wait for a free session. + Defaults to the value of ``timeout``. Raises: fs.errors.CreateFailed: when the filesystem could not be created. The @@ -104,6 +109,8 @@ def __init__( config_path='~/.ssh/config', exec_timeout=None, policy=None, + max_connections=4, + conn_timeout=None, **kwargs ): # noqa: D102 super(SSHFS, self).__init__() @@ -121,6 +128,7 @@ def __init__( self._client = client = paramiko.SSHClient() self._timeout = timeout self._exec_timeout = timeout if exec_timeout is None else exec_timeout + self._conn_timeout = timeout if conn_timeout is None else conn_timeout _policy = paramiko.AutoAddPolicy() if policy is None else policy @@ -129,23 +137,22 @@ def __init__( client.load_system_host_keys() client.set_missing_host_key_policy(_policy) argdict = { + "port": port, + "username": user, + "password": passwd, "pkey": pkey, "key_filename": keyfile, "look_for_keys": True if (pkey and keyfile) is None else False, "compress": compress, "timeout": timeout } - argdict.update(kwargs) - client.connect( - socket.gethostbyname(host), port, user, passwd, - **argdict - ) + client.connect(socket.gethostbyname(host), **argdict) if keepalive > 0: client.get_transport().set_keepalive(keepalive) - self._sftp = client.open_sftp() + self._pool = SFTPClientPool(client, argdict, max_connections, timeout=self._conn_timeout) except (paramiko.ssh_exception.SSHException, # protocol errors paramiko.ssh_exception.NoValidConnectionsError, # connexion errors @@ -166,17 +173,21 @@ def close(self): # noqa: D102 self._client.close() super().close() + + def _sftp(self): + return self._pool.connection() + def getinfo(self, path, namespaces=None): # noqa: D102 self.check() namespaces = namespaces or () _path = self.validatepath(path) - with convert_sshfs_errors('getinfo', path): - 
_stat = self._sftp.stat(_path) + with convert_sshfs_errors('getinfo', path), self._sftp() as sftp: + _stat = sftp.stat(_path) info = self._make_raw_info(basename(_path), _stat, namespaces) if "lstat" in namespaces or "link" in namespaces: - _lstat = self._sftp.lstat(_path) + _lstat = sftp.lstat(_path) if "lstat" in namespaces: info["lstat"] = { k: getattr(_lstat, k) @@ -185,7 +196,7 @@ def getinfo(self, path, namespaces=None): # noqa: D102 } if "link" in namespaces: if OSFS._get_type_from_stat(_lstat) == ResourceType.symlink: - target = self._sftp.readlink(_path) + target = sftp.readlink(_path) info["link"] = {"target": target} else: info["link"] = {"target": None} @@ -210,8 +221,8 @@ def listdir(self, path): # noqa: D102 if _type is not ResourceType.directory: raise errors.DirectoryExpected(path) - with convert_sshfs_errors('listdir', path): - return self._sftp.listdir(_path) + with convert_sshfs_errors('listdir', path), self._sftp() as sftp: + return sftp.listdir(_path) def scandir(self, path, namespaces=None, page=None): # noqa: D102 self.check() @@ -219,11 +230,11 @@ def scandir(self, path, namespaces=None, page=None): # noqa: D102 _namespaces = namespaces or () start, stop = page or (None, None) try: - with convert_sshfs_errors('scandir', path, directory=True): + with convert_sshfs_errors('scandir', path, directory=True), self._sftp() as sftp: # We can't use listdir_iter here because it doesn't support # concurrent iteration over multiple directories, which can # happen during a search="depth" walk. 
- listing = self._sftp.listdir_attr(_path) + listing = sftp.listdir_attr(_path) for _stat in itertools.islice(listing, start, stop): yield Info(self._make_raw_info(_stat.filename, _stat, _namespaces)) except errors.ResourceNotFound: @@ -244,8 +255,8 @@ def makedir(self, path, permissions=None, recreate=False): # noqa: D102 info = self.getinfo(_path) except errors.ResourceNotFound: with self._lock: - with convert_sshfs_errors('makedir', path): - self._sftp.mkdir(_path, _permissions.mode) + with convert_sshfs_errors('makedir', path), self._sftp() as sftp: + sftp.mkdir(_path, _permissions.mode) else: if (info.is_dir and not recreate) or info.is_file: six.raise_from(errors.DirectoryExists(path), None) @@ -259,7 +270,7 @@ def move(self, src_path, dst_path, overwrite=False, preserve_time=False): _src_path = self.validatepath(src_path) _dst_path = self.validatepath(dst_path) - with self._lock: + with self._lock, self._sftp() as sftp: # check src exists and is a file src_info = self.getinfo(_src_path, namespaces=info_ns) if src_info.is_dir: @@ -274,13 +285,13 @@ def move(self, src_path, dst_path, overwrite=False, preserve_time=False): if not overwrite: raise errors.DestinationExists(dst_path) with convert_sshfs_errors('move', dst_path): - self._sftp.remove(_dst_path) + sftp.remove(_dst_path) # rename the file through SFTP's 'RENAME' - self._sftp.rename(_src_path, _dst_path) + sftp.rename(_src_path, _dst_path) # preserve times if required if preserve_time: self._utime( - _path, + _dst_path, src_info.raw["details"]["modified"], src_info.raw["details"]["accessed"], ) @@ -329,12 +340,21 @@ def openbin(self, path, mode='r', buffering=-1, **options): # noqa: D102 elif self.isdir(_path): raise errors.FileExpected(path) with convert_sshfs_errors('openbin', path): - _sftp = self._client.open_sftp() - handle = _sftp.open( + sftp = self._pool.acquire() + handle = sftp.open( _path, mode=_mode.to_platform_bin(), bufsize=buffering ) + + # release sftp client on close + def 
wrap_close(fn): + def close(): + fn() + self._pool.release(sftp) + return close + handle.close = wrap_close(handle.close) + handle.set_pipelined(options.get("pipelined", True)) if options.get("prefetch", True): if _mode.reading and not _mode.writing: @@ -350,9 +370,9 @@ def remove(self, path): # noqa: D102 if self.getinfo(_path).is_dir: raise errors.FileExpected(path) - with convert_sshfs_errors('remove', path): + with convert_sshfs_errors('remove', path), self._sftp() as sftp: with self._lock: - self._sftp.remove(_path) + sftp.remove(_path) def removedir(self, path): # noqa: D102 self.check() @@ -364,9 +384,8 @@ def removedir(self, path): # noqa: D102 if not self.isempty(_path): raise errors.DirectoryNotEmpty(path) - with convert_sshfs_errors('removedir', path): - with self._lock: - self._sftp.rmdir(_path) + with convert_sshfs_errors('removedir', path), self._lock, self._sftp() as sftp: + sftp.rmdir(_path) def setinfo(self, path, info): # noqa: D102 self.check() @@ -423,8 +442,8 @@ def download(self, path, file, chunk_size=None, callback=None, **options): raise errors.ResourceNotFound(path) elif self.isdir(_path): raise errors.FileExpected(path) - with convert_sshfs_errors('download', path): - self._sftp.getfo(_path, file, callback=callback) + with convert_sshfs_errors('download', path), self._sftp() as sftp: + sftp.getfo(_path, file, callback=callback) def upload(self, path, file, chunk_size=None, callback=None, file_size=None, confirm=True, **options): """Set a file to the contents of a binary file object. 
def _chmod(self, path, mode):
    """Change the *mode* of a resource.

    Arguments:
        path (str): remote path, passed to the SFTP session as given.
        mode (int): permission bits to set.
    """
    with self._sftp() as sftp:
        sftp.chmod(path, mode)


def _chown(self, path, uid, gid):
    """Change the *owner* of a resource.

    Arguments:
        path (str): remote path, passed to the SFTP session as given.
        uid (int or None): new user id; ``None`` keeps the current one.
        gid (int or None): new group id; ``None`` keeps the current one.
    """
    with self._sftp() as sftp:
        if uid is None or gid is None:
            # SFTP chown always needs both ids, so the missing one is filled
            # in from the resource's current ownership.
            info = self.getinfo(path, namespaces=('access',))
            # Compare against None explicitly: 0 (root) is a valid id and
            # must not be replaced by the current owner.
            if uid is None:
                uid = info.get('access', 'uid')
            if gid is None:
                gid = info.get('access', 'gid')

        sftp.chown(path, uid, gid)


def _utime(self, path, modified, accessed):
    """Set the *accessed* and *modified* times of a resource.

    Arguments:
        path (str): remote path, passed to the SFTP session as given.
        modified (float or None): new modification time.
        accessed (float or None): new access time.

    When only one of the two times is given it is used for both.  When both
    are ``None``, ``utime`` is called with ``None`` instead of a time pair.
    """
    with self._sftp() as sftp:
        if accessed is None and modified is None:
            sftp.utime(path, None)
            return

        atime = float(modified if accessed is None else accessed)
        mtime = float(atime if modified is None else modified)
        sftp.utime(path, (atime, mtime))
ssh_fs.opendir(test_folder) as test_fs: @@ -141,7 +142,7 @@ def test_sshconfig_notfound(self): def test_create(self): - directory = fs.path.join("home", self.user, "test", "directory") + directory = fs.path.join("config", "test", "directory") base = "ssh://{}:{}@localhost:{}".format(self.user, self.pasw, self.port) url = "{}/{}".format(base, directory) @@ -167,20 +168,21 @@ def test_create(self): self.assertTrue(ssh_fs.isfile("foo")) def test_open_symlink(self): + home = '/config' # create a symlink in the home directory with paramiko.SSHClient() as client: client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) client.connect('localhost', self.port, self.user, self.pasw) with client.open_sftp() as sftp: - sftp.mkdir('/home/{}/directory'.format(self.user)) + sftp.mkdir('{}/directory'.format(home)) sftp.symlink( - '/home/{}/directory'.format(self.user), - '/home/{}/link'.format(self.user) + '{}/directory'.format(home), + '{}/link'.format(home) ) - sftp.open('/home/{}/directory/test'.format(self.user), 'w').close() + sftp.open('{}/directory/test'.format(home), 'w').close() # check the symlink can be opened - directory = fs.path.join("home", self.user, "link") + directory = fs.path.join(home, "link") base = "ssh://{}:{}@localhost:{}".format(self.user, self.pasw, self.port) url = "{}/{}".format(base, directory) with fs.open_fs(url) as ssh_fs: diff --git a/tests/test_sshfs.py b/tests/test_sshfs.py index 77d48c2..38dc80d 100644 --- a/tests/test_sshfs.py +++ b/tests/test_sshfs.py @@ -2,11 +2,13 @@ from __future__ import absolute_import from __future__ import unicode_literals +import concurrent.futures import stat import sys import time import uuid import unittest +from concurrent.futures import ThreadPoolExecutor import paramiko.ssh_exception @@ -45,7 +47,7 @@ def destroy_fs(fs): def make_fs(self): self.ssh_fs = SSHFS('localhost', self.user, self.pasw, port=self.port) - self.test_folder = fs.path.join('/home', self.user, uuid.uuid4().hex) + self.test_folder = 
fs.path.join('/config', uuid.uuid4().hex) self.ssh_fs.makedir(self.test_folder, recreate=True) return self.ssh_fs.opendir(self.test_folder, factory=ClosingSubFS) @@ -81,6 +83,9 @@ def test_upload_2(self): def test_upload_4(self): super(TestSSHFS, self).test_upload_4() + def _sftp(self): + return self.fs.delegate_fs()._sftp() + def test_chmod(self): self.fs.touch("test.txt") remote_path = fs.path.join(self.test_folder, "test.txt") @@ -88,14 +93,16 @@ def test_chmod(self): # Initial permissions info = self.fs.getinfo("test.txt", ["access"]) self.assertEqual(info.permissions.mode, 0o644) - st = self.fs.delegate_fs()._sftp.stat(remote_path) + with self._sftp() as sftp: + st = sftp.stat(remote_path) self.assertEqual(stat.S_IMODE(st.st_mode), 0o644) # Change permissions with SSHFS._chown self.fs.delegate_fs()._chmod(remote_path, 0o744) info = self.fs.getinfo("test.txt", ["access"]) self.assertEqual(info.permissions.mode, 0o744) - st = self.fs.delegate_fs()._sftp.stat(remote_path) + with self._sftp() as sftp: + st = sftp.stat(remote_path) self.assertEqual(stat.S_IMODE(st.st_mode), 0o744) # Change permissions with SSHFS.setinfo @@ -103,7 +110,8 @@ def test_chmod(self): {"access": {"permissions": Permissions(mode=0o600)}}) info = self.fs.getinfo("test.txt", ["access"]) self.assertEqual(info.permissions.mode, 0o600) - st = self.fs.delegate_fs()._sftp.stat(remote_path) + with self._sftp() as sftp: + st = sftp.stat(remote_path) self.assertEqual(stat.S_IMODE(st.st_mode), 0o600) with self.assertRaises(fs.errors.PermissionDenied): @@ -111,6 +119,7 @@ def test_chmod(self): "access": {"permissions": Permissions(mode=0o777)} }) + def test_chown(self): self.fs.touch("test.txt") @@ -118,18 +127,19 @@ def test_chown(self): info = self.fs.getinfo("test.txt", namespaces=["access"]) gid, uid = info.get('access', 'uid'), info.get('access', 'gid') - with utils.mock.patch.object(self.fs.delegate_fs()._sftp, 'chown') as chown: - self.fs.setinfo("test.txt", {'access': {'uid': None}}) - 
chown.assert_called_with(remote_path, uid, gid) + with self._sftp() as sftp: + with utils.mock.patch.object(sftp, 'chown') as chown: + self.fs.setinfo("test.txt", {'access': {'uid': None}}) + chown.assert_called_with(remote_path, uid, gid) - self.fs.setinfo("test.txt", {'access': {'gid': None}}) - chown.assert_called_with(remote_path, uid, gid) + self.fs.setinfo("test.txt", {'access': {'gid': None}}) + chown.assert_called_with(remote_path, uid, gid) - self.fs.setinfo("test.txt", {'access': {'gid': 8000}}) - chown.assert_called_with(remote_path, uid, 8000) + self.fs.setinfo("test.txt", {'access': {'gid': 8000}}) + chown.assert_called_with(remote_path, uid, 8000) - self.fs.setinfo("test.txt", {'access': {'uid': 1001, 'gid':1002}}) - chown.assert_called_with(remote_path, 1001, 1002) + self.fs.setinfo("test.txt", {'access': {'uid': 1001, 'gid':1002}}) + chown.assert_called_with(remote_path, 1001, 1002) def test_exec_command_exception(self): ssh = self.fs.delegate_fs() @@ -175,10 +185,11 @@ def test_symlinks(self): with self.fs.openbin("foo", "wb") as f: f.write(b"foobar") - self.fs.delegate_fs()._sftp.symlink( - fs.path.join(self.test_folder, "foo"), - fs.path.join(self.test_folder, "bar") - ) + with self._sftp() as sftp: + sftp.symlink( + fs.path.join(self.test_folder, "foo"), + fs.path.join(self.test_folder, "bar") + ) # os.symlink(self._get_real_path("foo"), self._get_real_path("bar")) self.assertFalse(self.fs.islink("foo")) @@ -202,3 +213,18 @@ def test_setinfo(self): now = int(time.time()) with utils.mock.patch("time.time", lambda: now): super(TestSSHFS, self).test_setinfo() + + def test_thread_safty(self): + text = "Thread Safty Test" + self.fs.writetext("thread_safty.txt", text) + + def getinfo(): + return self.fs.getinfo("thread_safty.txt", namespaces=["basic", "details"]) + + info = getinfo() + self.assertEqual(len(text), info.size) + + with ThreadPoolExecutor(10) as e: + futures = [e.submit(getinfo) for _ in range(100)] + for f in futures: + 
self.assertEqual(info.size, f.result().size) diff --git a/tests/utils.py b/tests/utils.py index 10d0f13..197c183 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,8 +29,8 @@ def fs_version(): def startServer(docker_client, user, pasw, port): sftp_container = docker_client.containers.run( - "sjourdan/alpine-sshd", detach=True, ports={'22/tcp': port}, - environment={'USER': user, 'PASSWORD': pasw}, + "lscr.io/linuxserver/openssh-server", detach=True, ports={'2222/tcp': port}, + environment={'USER_NAME': user, 'USER_PASSWORD': pasw, "PASSWORD_ACCESS": "true"}, ) time.sleep(1) return sftp_container