2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -27,7 +27,7 @@ jobs:
#- pypy-3.7
steps:
- name: Checkout SSH docker container
run: docker pull sjourdan/alpine-sshd
run: docker pull linuxserver/openssh-server
- name: Checkout code
uses: actions/checkout@v3
- name: Setup Python ${{ matrix.python-version }}
1 change: 1 addition & 0 deletions .gitignore
@@ -79,6 +79,7 @@ celerybeat-schedule
.env

# virtualenv
.venv/
venv/
ENV/

93 changes: 93 additions & 0 deletions fs/sshfs/pool.py
@@ -0,0 +1,93 @@

import contextlib
from queue import Queue
from threading import local
from paramiko import SSHClient, SFTPClient
import traceback

_local = local()

class ConnectionPool(object):
"""A generic connection pool.

Arguments:
    open_func (Callable): Function that receives a pooled connection
        (or ``None``) and returns an open connection.
    max_connections (int): Maximum number of open connections.
    timeout (float): Maximum time to wait for a connection.
"""

def __init__(self, open_func, max_connections=4, timeout=None):
self._open_func = open_func
self.timeout = timeout

self._q = Queue(maxsize=max_connections)

while not self._q.full():
self._q.put(None)

@contextlib.contextmanager
def connection(self):
"""Return a context manager that yields an open connection.

Nested calls from the same thread reuse the same connection.
"""
if not hasattr(_local, "conn"):
try:
_local.conn = self.acquire()
yield _local.conn
except Exception:
print(traceback.format_exc())
raise
finally:
if not hasattr(_local, "conn"):
# nothing has been acquired
return
self.release(_local.conn)
del _local.conn
else:
yield _local.conn

def acquire(self):
conn = self._q.get(timeout=self.timeout)
try:
return self._open_func(conn)
except Exception:
self.release(conn)
raise

def release(self, conn):
self._q.put(conn, block=False)

class SFTPClientPool(ConnectionPool):
"""A pool of SFTPClient sessions

Arguments:
client (SSHClient): ssh client
argdict (dict): client.connect parameters
max_connections (int): Maximum number of open sessions
timeout (float): Maximum time to wait for a session
"""

def __init__(self, client, argdict, max_connections=4, timeout=None):
# type: (SSHClient, dict, int, float | None) -> None
def open_sftp(conn):
# type: (SFTPClient | None) -> SFTPClient
transport = client.get_transport()
if not transport.is_active():
client.connect(transport.hostname, **argdict)

if conn is None or conn.get_channel().closed:
return client.open_sftp()
return conn

super().__init__(open_sftp, max_connections, timeout)

def acquire(self) -> SFTPClient:
    """Acquire an SFTPClient.

    If ``timeout`` is None, this function blocks until a free session is
    available; otherwise ``queue.Empty`` is raised when no session becomes
    available within ``timeout`` seconds.
    """
return super().acquire()

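For review context, here is a minimal sketch of how the new pool is meant to be driven. The `open_connection` helper and the stand-in connection object are hypothetical, not part of this diff; only `ConnectionPool` and its behaviour come from `fs/sshfs/pool.py` above.

```python
from fs.sshfs.pool import ConnectionPool

# Hypothetical opener: receives the previously pooled value (or None)
# and must return a usable connection object.
def open_connection(conn):
    if conn is None:
        conn = object()  # stand-in for a real connection
    return conn

pool = ConnectionPool(open_connection, max_connections=2, timeout=5)

# Nested use on the same thread reuses the single acquired connection;
# other threads block until a slot frees up or `timeout` expires,
# at which point queue.Empty is raised.
with pool.connection() as outer:
    with pool.connection() as inner:
        assert inner is outer
```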
109 changes: 66 additions & 43 deletions fs/sshfs/sshfs.py
@@ -25,6 +25,7 @@

from .file import SSHFile
from .error_tools import convert_sshfs_errors
from .pool import SFTPClientPool


class SSHFS(FS):
@@ -58,6 +59,10 @@ class SSHFS(FS):
policy (paramiko.MissingHostKeyPolicy): The policy to use to resolve
missing host keys. Defaults to ``None``, which will create a
`paramiko.AutoAddPolicy` instance.
max_connections (int): Maximum number of concurrent SFTPClient sessions.
    Defaults to 4.
conn_timeout (float): The maximum time to wait for a free session.
Defaults to the value of ``timeout``.

Raises:
fs.errors.CreateFailed: when the filesystem could not be created. The
@@ -104,6 +109,8 @@ def __init__(
config_path='~/.ssh/config',
exec_timeout=None,
policy=None,
max_connections=4,
conn_timeout=None,
**kwargs
): # noqa: D102
super(SSHFS, self).__init__()
@@ -121,6 +128,7 @@ def __init__(
self._client = client = paramiko.SSHClient()
self._timeout = timeout
self._exec_timeout = timeout if exec_timeout is None else exec_timeout
self._conn_timeout = timeout if conn_timeout is None else conn_timeout

_policy = paramiko.AutoAddPolicy() if policy is None else policy

@@ -129,23 +137,22 @@ def __init__(
client.load_system_host_keys()
client.set_missing_host_key_policy(_policy)
argdict = {
"port": port,
"username": user,
"password": passwd,
"pkey": pkey,
"key_filename": keyfile,
"look_for_keys": True if (pkey and keyfile) is None else False,
"compress": compress,
"timeout": timeout
}

argdict.update(kwargs)

client.connect(
socket.gethostbyname(host), port, user, passwd,
**argdict
)
client.connect(socket.gethostbyname(host), **argdict)

if keepalive > 0:
client.get_transport().set_keepalive(keepalive)
self._sftp = client.open_sftp()
self._pool = SFTPClientPool(client, argdict, max_connections, timeout=self._conn_timeout)

except (paramiko.ssh_exception.SSHException, # protocol errors
paramiko.ssh_exception.NoValidConnectionsError, # connection errors
@@ -166,17 +173,21 @@ def close(self): # noqa: D102
self._client.close()
super().close()


def _sftp(self):
return self._pool.connection()

def getinfo(self, path, namespaces=None): # noqa: D102
self.check()
namespaces = namespaces or ()
_path = self.validatepath(path)

with convert_sshfs_errors('getinfo', path):
_stat = self._sftp.stat(_path)
with convert_sshfs_errors('getinfo', path), self._sftp() as sftp:
_stat = sftp.stat(_path)
info = self._make_raw_info(basename(_path), _stat, namespaces)

if "lstat" in namespaces or "link" in namespaces:
_lstat = self._sftp.lstat(_path)
_lstat = sftp.lstat(_path)
if "lstat" in namespaces:
info["lstat"] = {
k: getattr(_lstat, k)
Expand All @@ -185,7 +196,7 @@ def getinfo(self, path, namespaces=None): # noqa: D102
}
if "link" in namespaces:
if OSFS._get_type_from_stat(_lstat) == ResourceType.symlink:
target = self._sftp.readlink(_path)
target = sftp.readlink(_path)
info["link"] = {"target": target}
else:
info["link"] = {"target": None}
@@ -210,20 +221,20 @@ def listdir(self, path): # noqa: D102
if _type is not ResourceType.directory:
raise errors.DirectoryExpected(path)

with convert_sshfs_errors('listdir', path):
return self._sftp.listdir(_path)
with convert_sshfs_errors('listdir', path), self._sftp() as sftp:
return sftp.listdir(_path)

def scandir(self, path, namespaces=None, page=None): # noqa: D102
self.check()
_path = self.validatepath(path)
_namespaces = namespaces or ()
start, stop = page or (None, None)
try:
with convert_sshfs_errors('scandir', path, directory=True):
with convert_sshfs_errors('scandir', path, directory=True), self._sftp() as sftp:
# We can't use listdir_iter here because it doesn't support
# concurrent iteration over multiple directories, which can
# happen during a search="depth" walk.
listing = self._sftp.listdir_attr(_path)
listing = sftp.listdir_attr(_path)
for _stat in itertools.islice(listing, start, stop):
yield Info(self._make_raw_info(_stat.filename, _stat, _namespaces))
except errors.ResourceNotFound:
@@ -244,8 +255,8 @@ def makedir(self, path, permissions=None, recreate=False): # noqa: D102
info = self.getinfo(_path)
except errors.ResourceNotFound:
with self._lock:
with convert_sshfs_errors('makedir', path):
self._sftp.mkdir(_path, _permissions.mode)
with convert_sshfs_errors('makedir', path), self._sftp() as sftp:
sftp.mkdir(_path, _permissions.mode)
else:
if (info.is_dir and not recreate) or info.is_file:
six.raise_from(errors.DirectoryExists(path), None)
@@ -259,7 +270,7 @@ def move(self, src_path, dst_path, overwrite=False, preserve_time=False):
_src_path = self.validatepath(src_path)
_dst_path = self.validatepath(dst_path)

with self._lock:
with self._lock, self._sftp() as sftp:
# check src exists and is a file
src_info = self.getinfo(_src_path, namespaces=info_ns)
if src_info.is_dir:
Expand All @@ -274,13 +285,13 @@ def move(self, src_path, dst_path, overwrite=False, preserve_time=False):
if not overwrite:
raise errors.DestinationExists(dst_path)
with convert_sshfs_errors('move', dst_path):
self._sftp.remove(_dst_path)
sftp.remove(_dst_path)
# rename the file through SFTP's 'RENAME'
self._sftp.rename(_src_path, _dst_path)
sftp.rename(_src_path, _dst_path)
# preserve times if required
if preserve_time:
self._utime(
_path,
_dst_path,
src_info.raw["details"]["modified"],
src_info.raw["details"]["accessed"],
)
@@ -329,12 +340,21 @@ def openbin(self, path, mode='r', buffering=-1, **options): # noqa: D102
elif self.isdir(_path):
raise errors.FileExpected(path)
with convert_sshfs_errors('openbin', path):
_sftp = self._client.open_sftp()
handle = _sftp.open(
sftp = self._pool.acquire()
handle = sftp.open(
_path,
mode=_mode.to_platform_bin(),
bufsize=buffering
)

# release sftp client on close
def wrap_close(fn):
def close():
fn()
self._pool.release(sftp)
return close
handle.close = wrap_close(handle.close)

handle.set_pipelined(options.get("pipelined", True))
if options.get("prefetch", True):
if _mode.reading and not _mode.writing:
@@ -350,9 +370,9 @@ def remove(self, path): # noqa: D102
if self.getinfo(_path).is_dir:
raise errors.FileExpected(path)

with convert_sshfs_errors('remove', path):
with convert_sshfs_errors('remove', path), self._sftp() as sftp:
with self._lock:
self._sftp.remove(_path)
sftp.remove(_path)

def removedir(self, path): # noqa: D102
self.check()
@@ -364,9 +384,8 @@ def removedir(self, path): # noqa: D102
if not self.isempty(_path):
raise errors.DirectoryNotEmpty(path)

with convert_sshfs_errors('removedir', path):
with self._lock:
self._sftp.rmdir(_path)
with convert_sshfs_errors('removedir', path), self._lock, self._sftp() as sftp:
sftp.rmdir(_path)

def setinfo(self, path, info): # noqa: D102
self.check()
@@ -423,8 +442,8 @@ def download(self, path, file, chunk_size=None, callback=None, **options):
raise errors.ResourceNotFound(path)
elif self.isdir(_path):
raise errors.FileExpected(path)
with convert_sshfs_errors('download', path):
self._sftp.getfo(_path, file, callback=callback)
with convert_sshfs_errors('download', path), self._sftp() as sftp:
sftp.getfo(_path, file, callback=callback)

def upload(self, path, file, chunk_size=None, callback=None, file_size=None, confirm=True, **options):
"""Set a file to the contents of a binary file object.
@@ -466,8 +485,8 @@ def upload(self, path, file, chunk_size=None, callback=None, file_size=None, con
raise errors.ResourceNotFound(path)
elif self.isdir(_path):
raise errors.FileExpected(path)
with convert_sshfs_errors('upload', path):
self._sftp.putfo(
with convert_sshfs_errors('upload', path), self._sftp() as sftp:
sftp.putfo(
file,
_path,
file_size=file_size,
@@ -580,23 +599,27 @@ def entry_name(db, _id):
def _chmod(self, path, mode):
"""Change the *mode* of a resource.
"""
self._sftp.chmod(path, mode)
with self._sftp() as sftp:
sftp.chmod(path, mode)

def _chown(self, path, uid, gid):
"""Change the *owner* of a resource.
"""
if uid is None or gid is None:
info = self.getinfo(path, namespaces=('access',))
uid = uid or info.get('access', 'uid')
gid = gid or info.get('access', 'gid')
self._sftp.chown(path, uid, gid)
with self._sftp() as sftp:
if uid is None or gid is None:
info = self.getinfo(path, namespaces=('access',))
uid = uid or info.get('access', 'uid')
gid = gid or info.get('access', 'gid')

sftp.chown(path, uid, gid)

def _utime(self, path, modified, accessed):
"""Set the *accessed* and *modified* times of a resource.
"""
if accessed is not None or modified is not None:
accessed = float(modified if accessed is None else accessed)
modified = float(accessed if modified is None else modified)
self._sftp.utime(path, (accessed, modified))
else:
self._sftp.utime(path, None)
with self._sftp() as sftp:
if accessed is not None or modified is not None:
accessed = float(modified if accessed is None else accessed)
modified = float(accessed if modified is None else modified)
sftp.utime(path, (accessed, modified))
else:
sftp.utime(path, None)
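A quick sketch of how the new constructor parameters surface to users of `SSHFS`; the host, credentials, worker count, and file paths below are illustrative only, while `max_connections` and `conn_timeout` are the options added by this diff.

```python
from concurrent.futures import ThreadPoolExecutor
from fs.sshfs import SSHFS

# Illustrative host and credentials; not part of this PR.
ssh_fs = SSHFS(
    "example.org",
    user="demo",
    passwd="secret",
    max_connections=4,   # size of the SFTPClientPool
    conn_timeout=30.0,   # seconds to wait for a free SFTP session
)

def read_head(path):
    # Each worker borrows its own SFTP session from the pool,
    # so concurrent reads no longer share a single channel.
    with ssh_fs.openbin(path) as handle:
        return handle.read(1024)

with ThreadPoolExecutor(max_workers=4) as executor:
    heads = list(executor.map(read_head, ["/etc/hostname", "/etc/os-release"]))
```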
3 changes: 2 additions & 1 deletion tests/requirements.txt
@@ -1,5 +1,6 @@
codecov ~=2.1
docker ~=4.4
docker ~=6.1.3
requests ==2.31.0 # https://github.com/docker/docker-py/issues/3256
urllib3 <2
semantic-version ~=2.6
mock ~=3.0.5 ; python_version < '3.3'