# -----------------------------------------------------------------------------
# AWS regions
# -----------------------------------------------------------------------------

# The closed set of region codes accepted for the `aws_region` config value.
AwsRegion = Literal[
    "us-east-1",
    "us-east-2",
    "us-west-1",
    "us-west-2",
    "ca-central-1",
    "eu-west-1",
    "eu-west-2",
    "eu-west-3",
    "eu-north-1",
    "eu-south-1",
    "eu-central-1",
    "ap-southeast-1",
    "ap-southeast-2",
    "ap-northeast-1",
    "ap-northeast-2",
    "ap-northeast-3",
    "ap-south-1",
    "ap-east-1",
    "sa-east-1",
    "il-central-1",
    "me-south-1",
    "af-south-1",
    "cn-north-1",
    "cn-northwest-1",
    "us-gov-east-1",
    "us-gov-west-1",
]


def get_aws_regions_list() -> List[str]:
    """Return every region code permitted by ``AwsRegion``.

    The codes are read from the ``AwsRegion`` literal itself, so the
    returned list always mirrors the type, in declaration order.
    """
    region_codes = get_args(AwsRegion)
    return [*region_codes]
"local_filesystem"]], + "connection_method": Optional[connection_methods], "central_host_id": Optional[str], "central_host_username": Optional[str], + "gdrive_client_id": Optional[str], + "gdrive_root_folder_id": Optional[str], + "aws_access_key_id": Optional[str], + "aws_region": Optional[AwsRegion], } return canonical_configs @@ -101,7 +114,8 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: check_config_types(config_dict) - raise_on_bad_local_only_project_configs(config_dict) + if config_dict["connection_method"] not in ["aws", "gdrive"]: + raise_on_bad_local_only_project_configs(config_dict) if list(config_dict.keys()) != list(canonical_dict.keys()): utils.log_and_raise_error( @@ -130,6 +144,29 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ConfigError, ) + # Check gdrive settings + elif config_dict["connection_method"] == "gdrive": + if not config_dict["gdrive_root_folder_id"]: + utils.log_and_raise_error( + "'gdrive_root_folder_id' is required if 'connection_method' " + "is 'gdrive'.", + ConfigError, + ) + + if not config_dict["gdrive_client_id"]: + utils.log_and_message( + "`gdrive_client_id` not found in config. default rlcone client will be used (slower)." + ) + + # Check AWS settings + elif config_dict["connection_method"] == "aws" and ( + not config_dict["aws_access_key_id"] or not config_dict["aws_region"] + ): + utils.log_and_raise_error( + "Both aws_access_key_id and aws_region must be present for AWS connection.", + ConfigError, + ) + # Initialise the local project folder utils.print_message_to_user( f"Making project folder at: {config_dict['local_path']}" diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 4d216576..5bf203af 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -120,16 +120,48 @@ def dump_to_file(self) -> None: def load_from_file(self) -> None: """Load a config dict saved at .yaml file. 
def update_config_for_backward_compatability_if_required(
    self, config_dict: Dict
):
    """Add keys introduced in later Datashuttle versions if they are missing.

    A config file written before the Google Drive / AWS connection
    methods existed contains none of the keys below. In that case every
    key is added in place with the value ``None``. A config that already
    holds all of the keys is left untouched.

    Parameters
    ----------
    config_dict
        The config dictionary loaded from file. It is mutated in place.

    Raises
    ------
    AssertionError
        If only some of the keys are present. That is not a state an
        older Datashuttle version can produce, so the config is
        refused rather than silently overwritten.

    """
    canonical_config_keys_to_add = [
        "gdrive_client_id",
        "gdrive_root_folder_id",
        "aws_access_key_id",
        "aws_region",
    ]

    keys_already_present = [
        key for key in canonical_config_keys_to_add if key in config_dict
    ]

    # Nothing to do: the config was written by a current version.
    if len(keys_already_present) == len(canonical_config_keys_to_add):
        return

    # All keys shall be missing for a backwards compatibility update.
    # This is raised explicitly rather than with an `assert` statement
    # because `python -O` strips assert statements, and the loop below
    # would then reset the user's existing values to `None`.
    if keys_already_present:
        raise AssertionError(
            "Config contains only some of the keys "
            f"{canonical_config_keys_to_add}: found {keys_already_present}."
        )

    for key in canonical_config_keys_to_add:
        config_dict[key] = None
@check_configs_set
def setup_google_drive_connection(self) -> None:
    """Set up a connection to Google Drive using the provided credentials.

    Assumes `gdrive_root_folder_id` is set in configs.

    First, the user is asked whether their machine has access to a
    browser.

    If a browser is available and `gdrive_client_id` is set in the
    configs, the user is prompted for their Google Drive client secret.
    If no browser is available, the user is instead prompted for their
    service account file path.

    Next, the rclone Google Drive config is set up with whichever
    credential was collected (both may be `None`), the returned rclone
    process is awaited, and the connection is checked.
    """
    self._start_log(
        "setup-google-drive-connection-to-central-server",
        local_vars=locals(),
    )

    browser_available = gdrive.ask_user_for_browser(log=True)

    # At most one of these is filled in below; both stay `None` when a
    # browser is available but no `gdrive_client_id` is configured.
    service_account_filepath = None
    gdrive_client_secret = None

    if browser_available and self.cfg["gdrive_client_id"]:
        gdrive_client_secret = gdrive.get_client_secret()

    elif not browser_available:
        service_account_filepath = (
            gdrive.prompt_and_get_service_account_filepath(
                log=True,
            )
        )

    process = self._setup_rclone_gdrive_config(
        gdrive_client_secret, service_account_filepath
    )

    rclone.await_call_rclone_with_popen_raise_on_fail(process, log=True)

    rclone.check_successful_connection_and_raise_error_on_fail(self.cfg)

    utils.log_and_message("Google Drive Connection Successful.")

    # NOTE(review): this is only reached when nothing above raised -
    # confirm the log file handler is closed elsewhere on failure.
    ds_logger.close_log_filehandler()
+ """ + self._start_log( + "setup-aws-connection-to-central-server", + local_vars=locals(), + ) + + aws_secret_access_key = aws.get_aws_secret_access_key() + + self._setup_rclone_aws_config(aws_secret_access_key, log=True) + + rclone.check_successful_connection_and_raise_error_on_fail(self.cfg) + aws.raise_if_bucket_absent(self.cfg) + + utils.log_and_message("AWS Connection Successful.") + + ds_logger.close_log_filehandler() + # ------------------------------------------------------------------------- # Configs # ------------------------------------------------------------------------- @@ -837,6 +924,10 @@ def make_config_file( connection_method: str | None = None, central_host_id: Optional[str] = None, central_host_username: Optional[str] = None, + gdrive_client_id: Optional[str] = None, + gdrive_root_folder_id: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_region: Optional[str] = None, ) -> None: """Initialize the configurations for datashuttle on the local machine. @@ -877,6 +968,26 @@ def make_config_file( username for which to log in to central host. e.g. ``"jziminski"`` + gdrive_client_id + The client ID used to authenticate with the Google Drive API via OAuth 2.0. + This is obtained from the Google Cloud Console when setting up API credentials. + e.g. "1234567890-abc123def456.apps.googleusercontent.com" + + gdrive_root_folder_id + The folder ID for the Google Drive folder to connect to. This can be copied + directly from your browser when on the folder in Google Drive. + e.g. 1eoAnopd2ZHOd87LgiPtgViFE7u3R9sSw + + aws_access_key_id + The AWS access key ID used to authenticate requests to AWS services. + This is part of your AWS credentials and can be generated via the AWS IAM console. + e.g. "AKIAIOSFODNN7EXAMPLE" + + aws_region + The AWS region in which your resources are located. + This determines the data center your requests are routed to. + e.g. 
"us-west-2" + """ self._start_log( "make-config-file", @@ -900,6 +1011,10 @@ def make_config_file( "connection_method": connection_method, "central_host_id": central_host_id, "central_host_username": central_host_username, + "gdrive_client_id": gdrive_client_id, + "gdrive_root_folder_id": gdrive_root_folder_id, + "aws_access_key_id": aws_access_key_id, + "aws_region": aws_region, }, ) @@ -1409,6 +1524,28 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: self.cfg.get_rclone_config_name("local_filesystem"), ) + def _setup_rclone_gdrive_config( + self, + gdrive_client_secret: str | None, + service_account_filepath: str | None, + ) -> subprocess.Popen: + return rclone.setup_rclone_config_for_gdrive( + self.cfg, + self.cfg.get_rclone_config_name("gdrive"), + gdrive_client_secret, + service_account_filepath, + ) + + def _setup_rclone_aws_config( + self, aws_secret_access_key: str, log: bool + ) -> None: + rclone.setup_rclone_config_for_aws( + self.cfg, + self.cfg.get_rclone_config_name("aws"), + aws_secret_access_key, + log=log, + ) + # Persistent settings # ------------------------------------------------------------------------- diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index de52087b..a7052547 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -64,6 +64,12 @@ SettingsScreen { GetHelpScreen { align: center middle; } +SetupGdriveScreen { + align: center middle; +} +SetupAwsScreen { + align: center middle; +} #get_help_label { align: center middle; text-align: center; @@ -114,6 +120,74 @@ MessageBox:light > #messagebox_top_container { align: center middle; } +#setup_gdrive_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} + +#gdrive_setup_messagebox_message_container { + height: 60%; + align: center middle; + overflow: hidden auto; + margin: 0 1; +} + +#gdrive_setup_messagebox_message { + text-align: center; + 
padding: 0 2; +} + +#setup_gdrive_ok_button { + align: center bottom; + height: 3; +} + +#setup_gdrive_cancel_button { + align: center bottom; +} + +#setup_gdrive_buttons_horizontal { + align: center middle; +} + +#setup_aws_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} + +#setup_aws_messagebox_message_container { + height: 70%; + align: center middle; + overflow: hidden auto; + margin: 0 1; +} + +#setup_aws_messagebox_message { + text-align: center; + padding: 0 2; +} + +#setup_aws_secret_access_key_input { + dock: bottom; +} + +#setup_aws_ok_button { + align: center bottom; + height: 3; +} + +#setup_aws_cancel_button { + align: center bottom; +} + +#setup_aws_buttons_horizontal { + align: center middle; +} + /* Configs Content ----------------------------------------------------------------- */ #configs_container { @@ -161,11 +235,19 @@ MessageBox:light > #messagebox_top_container { padding: 0 4 0 4; width: 26; color: $success; /* unsure about this */ + dock: left; +} + +#setup_buttons_container { + height: 100%; } -#configs_setup_ssh_connection_button { +#setup_buttons_container > Button { margin: 2 1 0 0; + dock: left; } + + #configs_go_to_project_screen_button { margin: 2 1 0 0; } @@ -204,6 +286,10 @@ MessageBox:light > #messagebox_top_container { padding: 0 0 2 0; } +#configs_aws_region_select { + width: 70%; +} + /* This Horizontal wrapper container is necessary to make the config label and button align center */ diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index cc7c91ef..05dd8403 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional if TYPE_CHECKING: + import subprocess + import paramiko from datashuttle.configs.config_class import Configs @@ -11,7 +13,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import 
ssh +from datashuttle.utils import aws, rclone, ssh, utils class Interface: @@ -36,6 +38,9 @@ def __init__(self) -> None: self.name_templates: Dict = {} self.tui_settings: Dict = {} + self.google_drive_rclone_setup_process: subprocess.Popen | None = None + self.gdrive_setup_process_killed: bool = False + def select_existing_project(self, project_name: str) -> InterfaceOutput: """Load an existing project into `self.project`. @@ -344,7 +349,7 @@ def transfer_custom_selection( except BaseException as e: return False, str(e) - # Setup SSH + # Name templates # ---------------------------------------------------------------------------------- def get_name_templates(self) -> Dict: @@ -499,3 +504,85 @@ def setup_key_pair_and_rclone_config( except BaseException as e: return False, str(e) + + # Setup Google Drive + # ---------------------------------------------------------------------------------- + + def setup_google_drive_connection( + self, + gdrive_client_secret: Optional[str] = None, + service_account_filepath: Optional[str] = None, + ) -> InterfaceOutput: + """Try to set up and validate connection to Google Drive. + + This is done by running the rclone setup function which returns a + subprocess.Popen object. The process object is stored in + `self.google_drive_rclone_setup_process` to allow for termination + of the process if needed. The `self.gdrive_setup_process_killed` + flag is set to false to signal normal operation. The process is then + awaited to ensure it completes successfully. If the process is killed + manually, the `self.gdrive_setup_process_killed` flag is set to True + to prevent raising an error when the process is killed. 
def terminate_google_drive_setup(self) -> None:
    """Kill the stored rclone Google Drive setup process if it is alive.

    `self.gdrive_setup_process_killed` is set to True before the kill,
    and only when the process had not already exited.
    """
    rclone_process = self.google_drive_rclone_setup_process
    assert rclone_process is not None

    # `poll()` returns None only while the process is still running
    has_exited = rclone_process.poll() is not None
    if has_exited:
        return

    self.gdrive_setup_process_killed = True
    rclone_process.kill()
def setup_aws_connection(
    self, aws_secret_access_key: str
) -> InterfaceOutput:
    """Set up the Amazon Web Service connection.

    Sets up the rclone AWS config with the given secret access key
    (without logging), then runs the connection check and the bucket
    check on the project configs.

    Returns `(True, None)` if every step completes, otherwise
    `(False, <message of the raised exception>)`.
    """
    try:
        self.project._setup_rclone_aws_config(
            aws_secret_access_key, log=False
        )
        rclone.check_successful_connection_and_raise_error_on_fail(
            self.project.cfg
        )
        aws.raise_if_bucket_absent(self.project.cfg)
    except BaseException as error:
        outcome = (False, str(error))
    else:
        outcome = (True, None)

    return outcome
def on_button_pressed(self, event: Button.Pressed) -> None:
    """Handle button press on the screen.

    The cancel button dismisses the dialog. The OK button runs the
    action belonging to the current `self.stage`:

    0) prompt the user for the AWS secret access key
    1) use the entered key to set up the AWS connection
    2) dismiss the dialog

    Any other button, or any other stage, does nothing.
    """
    pressed_id = event.button.id

    if pressed_id == "setup_aws_cancel_button":
        self.dismiss()
        return

    if pressed_id != "setup_aws_ok_button":
        return

    stage_actions = {
        0: self.prompt_user_for_aws_secret_access_key,
        1: self.use_secret_access_key_to_setup_aws_connection,
        2: self.dismiss,
    }
    action = stage_actions.get(self.stage)

    if action is not None:
        action()
+ self.query_one( + "#setup_aws_secret_access_key_input" + ).visible = False + + else: + message = ( + f"AWS setup failed. Please check your configs and secret access key" + f"\n\n Traceback: {output}" + ) + self.query_one( + "#setup_aws_secret_access_key_input" + ).disabled = True + + self.query_one("#setup_aws_ok_button").label = "Finish" + self.query_one("#setup_aws_messagebox_message").update(message) + self.query_one("#setup_aws_cancel_button").disabled = True + self.stage += 1 diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py new file mode 100644 index 00000000..3f04e0c3 --- /dev/null +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from textual.app import ComposeResult + from textual.worker import Worker + + from datashuttle.tui.interface import Interface + from datashuttle.utils.custom_types import InterfaceOutput + +from textual import work +from textual.containers import Container, Horizontal, Vertical +from textual.screen import ModalScreen +from textual.widgets import ( + Button, + Input, + Static, +) + + +class SetupGdriveScreen(ModalScreen): + """Dialog window that sets up a Google Drive connection. + + If the config contains a "gdrive_client_id", the user is prompted + to enter a client secret. If the user has access to a browser, a Google Drive + authentication page will open. Otherwise, the user is asked to enter path to a + service account file. 
def on_button_pressed(self, event: Button.Pressed) -> None:
    """Handle a button press on the screen.

    This dialog window is driven by six buttons:

    1) "ok" : starts the connection setup by asking whether a browser
       is available.

    2) "yes" : a browser is available. If "gdrive_client_id" is present
       in the configs the user is asked for the client secret, otherwise
       the browser authentication is started straight away.

    3) "no" : no browser is available. The user is prompted for the
       path to their service account file.

    4) "enter" : submits the client secret or the service account file
       path (whitespace is stripped, an empty entry is passed as None).

    5) "finish" : closes the dialog after setup.

    6) "cancel" : cancels the setup at any step before completion.
    """
    button_id = event.button.id

    if button_id in (
        "setup_gdrive_cancel_button",
        "setup_gdrive_finish_button",
    ):
        # The worker and the rclone process are stopped separately,
        # see setup_gdrive_connection_and_update_ui()
        worker = self.setup_worker
        if worker and worker.is_running:
            worker.cancel()
            self.interface.terminate_google_drive_setup()
        self.dismiss()
        return

    if button_id == "setup_gdrive_ok_button":
        self.ask_user_for_browser()
        return

    if button_id == "setup_gdrive_yes_button":
        self.remove_yes_no_buttons()

        if self.interface.project.cfg["gdrive_client_id"]:
            self.ask_user_for_gdrive_client_secret()
        else:
            self.open_browser_and_setup_gdrive_connection()
        return

    if button_id == "setup_gdrive_no_button":
        self.is_browser_available = False
        self.remove_yes_no_buttons()
        self.prompt_user_for_service_account_filepath()
        return

    if button_id == "setup_gdrive_enter_button":
        # The same input box serves both credentials; which one was
        # typed depends on the earlier yes / no answer.
        entered_value = self.input_box.value.strip() or None

        if self.is_browser_available:
            self.open_browser_and_setup_gdrive_connection(entered_value)
        else:
            self.setup_gdrive_connection_using_service_account(
                entered_value
            )
def open_browser_and_setup_gdrive_connection(
    self, gdrive_client_secret: Optional[str] = None
) -> None:
    """Set up Google Drive when the user has a browser.

    Updates the message shown to the user and starts an asyncio task
    that sets up the Google Drive connection and then updates the UI
    with success / failure.

    The connection setup is asynchronous so that the user is able to
    cancel the setup if anything goes wrong without quitting
    datashuttle altogether.

    Parameters
    ----------
    gdrive_client_secret
        The Google Drive client secret entered by the user, or `None`
        if no client secret is used.

    """
    message = "Please authenticate through browser."
    self.update_message_box_message(message)

    # The event loop keeps only a weak reference to a task, so a task
    # whose result is discarded may be garbage collected before it
    # finishes. Keep a strong reference on the screen for as long as
    # the setup runs.
    self._gdrive_setup_task = asyncio.create_task(
        self.setup_gdrive_connection_and_update_ui(
            gdrive_client_secret=gdrive_client_secret
        )
    )
async def setup_gdrive_connection_and_update_ui(
    self,
    gdrive_client_secret: Optional[str] = None,
    service_account_filepath: Optional[str] = None,
) -> None:
    """Start the Google Drive connection setup in a separate thread.

    The setup is run in a worker thread to avoid blocking the UI so that
    the user can cancel the setup if needed. This function starts the
    worker thread for Google Drive setup, sets `self.setup_worker` to the
    worker and awaits the worker to finish. After completion, it displays
    a success / failure screen.

    The setup on the lower level is a bit complicated. The worker thread
    runs the `setup_google_drive_connection` method of the `Interface`
    class, which spawns an rclone process to set up the connection. The
    rclone process object is stored on the `Interface` class so that it
    can be closed, as the thread does not kill the process itself upon
    cancellation. The process is awaited to ensure that it finishes and
    that any raised errors are caught. Therefore, the worker thread and
    the rclone process are cancelled separately when the user presses
    the cancel button (see `on_button_pressed`).

    If the worker was cancelled there is no result to show, so this
    function returns without touching the UI.
    """
    self.input_box.disabled = True
    self.enter_button.disabled = True

    worker = self.setup_gdrive_connection(
        gdrive_client_secret, service_account_filepath
    )
    self.setup_worker = worker

    if worker.is_running:
        try:
            await worker.wait()
        except Exception:
            # Waiting on a cancelled worker raises. That is the expected
            # outcome of the user pressing cancel, not a failure, so
            # only genuine errors are re-raised.
            if not worker.is_cancelled:
                raise

    if worker.is_cancelled:
        # The dialog was dismissed by the user and the worker has no
        # (success, output) result to unpack.
        return

    success, output = worker.result
    if success:
        self.show_finish_screen()
    else:
        self.input_box.disabled = False
        self.enter_button.disabled = False
        self.display_failed(output)
Please check your credentials" + f"\n\n Traceback: {output}" + ) + self.update_message_box_message(message) + + def update_message_box_message(self, message: str) -> None: + """Update the text message displayed to the user.""" + self.query_one("#gdrive_setup_messagebox_message").update(message) + + def mount_input_box_before_buttons( + self, is_password: bool = False + ) -> None: + """Add the Input box to the screen. + + This Input may be used for entering connection details or a password. + """ + self.input_box.password = is_password + self.input_box.styles.dock = "bottom" + self.query_one("#gdrive_setup_messagebox_message_container").mount( + self.input_box, after="#gdrive_setup_messagebox_message" + ) + + def remove_yes_no_buttons(self) -> None: + """Remove yes and no buttons.""" + self.query_one("#setup_gdrive_yes_button").remove() + self.query_one("#setup_gdrive_no_button").remove() diff --git a/datashuttle/tui/screens/setup_ssh.py b/datashuttle/tui/screens/setup_ssh.py index e0291569..cd944e95 100644 --- a/datashuttle/tui/screens/setup_ssh.py +++ b/datashuttle/tui/screens/setup_ssh.py @@ -22,10 +22,6 @@ class SetupSshScreen(ModalScreen): This asks to confirm the central hostkey, and takes password to setup SSH key pair. Under the hood uses `project.setup_ssh_connection()`. - - This is the one instance in which it is not possible for - the TUI to nearly wrap the API, because the logic flow is - broken up requiring user input (accept hostkey and input password). 
""" def __init__(self, interface: Interface) -> None: diff --git a/datashuttle/tui/shared/configs_content.py b/datashuttle/tui/shared/configs_content.py index be3b3444..88ab752b 100644 --- a/datashuttle/tui/shared/configs_content.py +++ b/datashuttle/tui/shared/configs_content.py @@ -20,12 +20,20 @@ Label, RadioButton, RadioSet, + Select, Static, ) +from datashuttle.configs.aws_regions import get_aws_regions_list +from datashuttle.configs.canonical_configs import get_connection_methods_list from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface -from datashuttle.tui.screens import modal_dialogs, setup_ssh +from datashuttle.tui.screens import ( + modal_dialogs, + setup_aws, + setup_gdrive, + setup_ssh, +) from datashuttle.tui.tooltips import get_tooltip @@ -116,12 +124,43 @@ def compose(self) -> ComposeResult: ), ] + self.config_gdrive_widgets = [ + Label("Root Folder ID", id="configs_gdrive_root_folder_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="Google Drive Root Folder ID", + id="configs_gdrive_root_folder_id", + ), + Label("Client ID (Optional)", id="configs_gdrive_client_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="Google Drive Client ID (Optional)", + id="configs_gdrive_client_id_input", + ), + ] + + self.config_aws_widgets = [ + Label("AWS Access Key ID", id="configs_aws_access_key_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="AWS Access Key ID eg. EJIBCLSIP2K2PQK3CDON", + id="configs_aws_access_key_id_input", + ), + Label("AWS S3 Region", id="configs_aws_region_label"), + Select( + ((region, region) for region in get_aws_regions_list()), + id="configs_aws_region_select", + ), + ] + config_screen_widgets = [ Label("Local Path", id="configs_local_path_label"), Horizontal( ClickableInput( self.parent_class.mainwindow, - placeholder=f"e.g. 
{self.get_platform_dependent_example_paths('local')}", + placeholder=self.get_platform_dependent_example_paths( + "local", "local_filesystem" + ), id="configs_local_path_input", ), Button("Select", id="configs_local_path_select_button"), @@ -129,23 +168,39 @@ def compose(self) -> ComposeResult: ), Label("Connection Method", id="configs_connect_method_label"), RadioSet( + RadioButton( + "No connection (local only)", + id="configs_local_only_radiobutton", + ), RadioButton( "Local Filesystem", - id="configs_local_filesystem_radiobutton", + id=self.radiobutton_id_from_connection_method( + "local_filesystem" + ), ), - RadioButton("SSH", id="configs_ssh_radiobutton"), RadioButton( - "No connection (local only)", - id="configs_local_only_radiobutton", + "SSH", id=self.radiobutton_id_from_connection_method("ssh") + ), + RadioButton( + "Google Drive", + id=self.radiobutton_id_from_connection_method("gdrive"), + ), + RadioButton( + "AWS S3", + id=self.radiobutton_id_from_connection_method("aws"), ), id="configs_connect_method_radioset", ), *self.config_ssh_widgets, + *self.config_gdrive_widgets, + *self.config_aws_widgets, Label("Central Path", id="configs_central_path_label"), Horizontal( ClickableInput( self.parent_class.mainwindow, - placeholder=f"e.g. 
{self.get_platform_dependent_example_paths('central', ssh=False)}", + placeholder=self.get_platform_dependent_example_paths( + "central", "local_filesystem" + ), id="configs_central_path_input", ), Button("Select", id="configs_central_path_select_button"), @@ -153,9 +208,12 @@ def compose(self) -> ComposeResult: ), Horizontal( Button("Save", id="configs_save_configs_button"), - Button( - "Setup SSH Connection", - id="configs_setup_ssh_connection_button", + Horizontal( + Button( + "Setup Button", + id="configs_setup_connection_button", + ), + id="setup_buttons_container", ), # Below button is always hidden when accessing # configs from project manager screen @@ -211,14 +269,17 @@ def on_mount(self) -> None: self.query_one("#configs_go_to_project_screen_button").visible = False if self.interface: self.fill_widgets_with_project_configs() + self.setup_widgets_to_display( + connection_method=self.interface.get_configs()[ + "connection_method" + ] + ) else: self.query_one( "#configs_local_filesystem_radiobutton" ).value = True - self.switch_ssh_widgets_display(display_ssh=False) - self.query_one( - "#configs_setup_ssh_connection_button" - ).visible = False + + self.setup_widgets_to_display(connection_method="local_filesystem") # Setup tooltips if not self.interface: @@ -230,12 +291,6 @@ def on_mount(self) -> None: self.query_one("#configs_local_filesystem_radiobutton").value is True ) - self.set_central_path_input_tooltip(display_ssh=False) - else: - display_ssh = ( - self.interface.project.cfg["connection_method"] == "ssh" - ) - self.set_central_path_input_tooltip(display_ssh) for id in [ "#configs_local_path_input", @@ -245,6 +300,8 @@ def on_mount(self) -> None: "#configs_local_only_radiobutton", "#configs_central_host_username_input", "#configs_central_host_id_input", + "#configs_gdrive_client_id_input", + "#configs_gdrive_root_folder_id", ]: self.query_one(id).tooltip = get_tooltip(id) @@ -258,43 +315,64 @@ def on_radio_set_changed(self, event: RadioSet.Changed) -> 
None: disabled. """ label = str(event.pressed.label) + radiobutton_id = event.pressed.id + assert label in [ "SSH", "Local Filesystem", "No connection (local only)", + "Google Drive", + "AWS S3", ], "Unexpected label." - if label == "No connection (local only)": - self.query_one("#configs_central_path_input").value = "" - self.query_one("#configs_central_path_input").disabled = True - self.query_one( - "#configs_central_path_select_button" - ).disabled = True - display_ssh = False - else: - self.query_one("#configs_central_path_input").disabled = False - self.query_one( - "#configs_central_path_select_button" - ).disabled = False - display_ssh = True if label == "SSH" else False + connection_method = self.connection_method_from_radiobutton_id( + radiobutton_id + ) - self.switch_ssh_widgets_display(display_ssh) - self.set_central_path_input_tooltip(display_ssh) + self.setup_widgets_to_display(connection_method) - def set_central_path_input_tooltip(self, display_ssh: bool) -> None: + self.set_central_path_input_tooltip(connection_method) + + def radiobutton_id_from_connection_method( + self, connection_method: str + ) -> str: + """Create a canonical radiobutton textual ID from the connection method.""" + return f"configs_{connection_method}_radiobutton" + + def connection_method_from_radiobutton_id( + self, radiobutton_id: str + ) -> str | None: + """Convert back from radiobutton Textual ID to connection method.""" + assert radiobutton_id.startswith("configs_") + assert radiobutton_id.endswith("_radiobutton") + + connection_string = radiobutton_id[ + len("configs_") : -len("_radiobutton") + ] + return ( + connection_string + if connection_string in get_connection_methods_list() + else None + ) + + def set_central_path_input_tooltip( + self, connection_method: str | None + ) -> None: """Set tooltip depending on whether connection method is SSH or local filesystem.""" - id = "#configs_central_path_input" - if display_ssh: - self.query_one(id).tooltip = get_tooltip( - 
"config_central_path_input_mode-ssh" + if connection_method is None: + tooltip = get_tooltip( + "config_central_path_input_mode-local_filesystem" ) else: - self.query_one(id).tooltip = get_tooltip( - "config_central_path_input_mode-local_filesystem" + tooltip = get_tooltip( + f"config_central_path_input_mode-{connection_method}" ) + self.query_one("#configs_central_path_input").tooltip = tooltip def get_platform_dependent_example_paths( - self, local_or_central: Literal["local", "central"], ssh: bool = False + self, + local_or_central: Literal["local", "central"], + connection_method: str, ) -> str: """Get example paths for the local or central Inputs depending on operating system. @@ -303,63 +381,31 @@ def get_platform_dependent_example_paths( local_or_central The "local" or "central" input to fill. - ssh - If the user has selected SSH (which changes the central input). + connection_method + Connection method e.g. "local_filesystem" """ assert local_or_central in ["local", "central"] # Handle the ssh central case separately # because it is always the same - if local_or_central == "central" and ssh: - example_path = "/nfs/path_on_server/myprojects/central" + if ( + local_or_central == "central" + and connection_method != "local_filesystem" + ): + if connection_method == "ssh": + example_path = "e.g. /nfs/path_on_server/myprojects/central" + elif connection_method in ["aws", "gdrive"]: + example_path = "" + else: if platform.system() == "Windows": - example_path = rf"C:\path\to\{local_or_central}\my_projects\my_first_project" + example_path = rf"e.g. C:\path\to\{local_or_central}\my_projects\my_first_project" else: - example_path = ( - f"/path/to/{local_or_central}/my_projects/my_first_project" - ) + example_path = f"e.g. /path/to/{local_or_central}/my_projects/my_first_project" return example_path - def switch_ssh_widgets_display(self, display_ssh: bool) -> None: - """Show or hide SSH-related configs. 
- - This is based on whether the current `connection_method` - widget is "ssh" or "local_filesystem". - - Parameters - ---------- - display_ssh - If `True`, display the SSH-related widgets. - - """ - for widget in self.config_ssh_widgets: - widget.display = display_ssh - - self.query_one( - "#configs_central_path_select_button" - ).display = not display_ssh - - if self.interface is None: - self.query_one( - "#configs_setup_ssh_connection_button" - ).visible = False - else: - self.query_one( - "#configs_setup_ssh_connection_button" - ).visible = display_ssh - - if not self.query_one("#configs_central_path_input").value: - if display_ssh: - placeholder = f"e.g. {self.get_platform_dependent_example_paths('central', ssh=True)}" - else: - placeholder = f"e.g. {self.get_platform_dependent_example_paths('central', ssh=False)}" - self.query_one( - "#configs_central_path_input" - ).placeholder = placeholder - def on_button_pressed(self, event: Button.Pressed) -> None: """Handle a button press event. @@ -372,8 +418,28 @@ def on_button_pressed(self, event: Button.Pressed) -> None: else: self.setup_configs_for_an_existing_project() - elif event.button.id == "configs_setup_ssh_connection_button": - self.setup_ssh_connection() + elif event.button.id == "configs_setup_connection_button": + assert self.interface is not None, ( + "type narrow flexible `interface`" + ) + + connection_method = self.interface.get_configs()[ + "connection_method" + ] + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." 
+ ) + return + + if connection_method == "ssh": + self.setup_ssh_connection() + elif connection_method == "gdrive": + self.setup_gdrive_connection() + elif connection_method == "aws": + self.setup_aws_connection() elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -425,20 +491,29 @@ def handle_input_fill_from_select_directory( ).value = path_.as_posix() def setup_ssh_connection(self) -> None: - """Set up the `SetupSshScreen` screen.""" + """Run the SSH set up in a new screen.""" assert self.interface is not None, "type narrow flexible `interface`" - if not self.widget_configs_match_saved_configs(): - self.parent_class.mainwindow.show_modal_error_dialog( - "The values set above must equal the datashuttle settings. " - "Either press 'Save' or reload this page." - ) - return - self.parent_class.mainwindow.push_screen( setup_ssh.SetupSshScreen(self.interface) ) + def setup_gdrive_connection(self) -> None: + """Run the Google Drive set up in a new screen.""" + assert self.interface is not None, "type narrow flexible `interface`" + + self.parent_class.mainwindow.push_screen( + setup_gdrive.SetupGdriveScreen(self.interface) + ) + + def setup_aws_connection(self) -> None: + """Run the AWS set up in a new screen.""" + assert self.interface is not None, "type narrow flexible `interface`" + + self.parent_class.mainwindow.push_screen( + setup_aws.SetupAwsScreen(self.interface) + ) + def widget_configs_match_saved_configs(self): """Ensure configs as set on screen match those stored in the project object. @@ -489,22 +564,29 @@ def setup_configs_for_a_new_project(self) -> None: "#configs_go_to_project_screen_button" ).visible = True + # A message template to display custom message to user according to the chosen connection method + message_template = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the {method_name} connection. 
Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." + ) + # Could not find a neater way to combine the push screen # while initiating the callback in one case but not the other. - if cfg_kwargs["connection_method"] == "ssh": - self.query_one( - "#configs_setup_ssh_connection_button" - ).visible = True - self.query_one( - "#configs_setup_ssh_connection_button" - ).disabled = False + connection_method = cfg_kwargs["connection_method"] - message = ( - "A datashuttle project has now been created.\n\n " - "Next, setup the SSH connection. Once complete, navigate to the " - "'Main Menu' and proceed to the project page, where you will be " - "able to create and transfer project folders." - ) + # To trigger the appearance of "Setup connection" button + self.setup_widgets_to_display(connection_method) + + if connection_method == "ssh": + message = message_template.format(method_name="SSH") + + elif connection_method == "gdrive": + message = message_template.format(method_name="Google Drive") + + elif connection_method == "aws": + message = message_template.format(method_name="AWS") else: message = ( @@ -534,7 +616,6 @@ def setup_configs_for_an_existing_project(self) -> None: # Handle the edge case where connection method is changed after # saving on the 'Make New Project' screen. - self.query_one("#configs_setup_ssh_connection_button").visible = True cfg_kwargs = self.get_datashuttle_inputs_from_widgets() @@ -549,6 +630,8 @@ def setup_configs_for_an_existing_project(self) -> None: ), lambda unused: self.post_message(self.ConfigsSaved()), ) + # To trigger the appearance of "Setup connection" button + self.setup_widgets_to_display(cfg_kwargs["connection_method"]) else: self.parent_class.mainwindow.show_modal_error_dialog(output) @@ -559,24 +642,15 @@ def fill_widgets_with_project_configs(self) -> None: widgets with the current project configs. 
This in some instances requires recasting to a new type of changing the value. - In the case of the `connection_method` widget, the associated - "ssh" widgets are hidden / displayed based on the current setting, - in `self.switch_ssh_widgets_display()`. + In the case of the `connection_method` widget, the associated connection + method radio button is hidden / displayed based on the current settings. + This change of radio button triggers `on_radio_set_changed` which displays + the appropriate connection method widgets. """ assert self.interface is not None, "type narrow flexible `interface`" cfg_to_load = self.interface.get_textual_compatible_project_configs() - # Local Path - input = self.query_one("#configs_local_path_input") - input.value = cfg_to_load["local_path"] - - # Central Path - input = self.query_one("#configs_central_path_input") - input.value = ( - cfg_to_load["central_path"] if cfg_to_load["central_path"] else "" - ) - # Connection Method # Make a dict of radiobutton: is on bool to easily find # how to set radiobuttons and associated configs @@ -586,6 +660,10 @@ def fill_widgets_with_project_configs(self) -> None: cfg_to_load["connection_method"] == "ssh", "configs_local_filesystem_radiobutton": cfg_to_load["connection_method"] == "local_filesystem", + "configs_gdrive_radiobutton": + cfg_to_load["connection_method"] == "gdrive", + "configs_aws_radiobutton": + cfg_to_load["connection_method"] == "aws", "configs_local_only_radiobutton": cfg_to_load["connection_method"] is None, } @@ -594,8 +672,25 @@ def fill_widgets_with_project_configs(self) -> None: for id, value in what_radiobuton_is_on.items(): self.query_one(f"#{id}").value = value - self.switch_ssh_widgets_display( - display_ssh=what_radiobuton_is_on["configs_ssh_radiobutton"] + self.fill_inputs_with_project_configs() + + def fill_inputs_with_project_configs(self) -> None: + """Fill the input widgets with the current project configs. 
+ + It is used while setting up widgets for the project while mounting the current tab. + """ + assert self.interface is not None, "type narrow flexible `interface`" + + cfg_to_load = self.interface.get_textual_compatible_project_configs() + + # Local Path + input = self.query_one("#configs_local_path_input") + input.value = cfg_to_load["local_path"] + + # Central Path + input = self.query_one("#configs_central_path_input") + input.value = ( + cfg_to_load["central_path"] if cfg_to_load["central_path"] else "" ) # Central Host ID @@ -616,6 +711,137 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value + # Google Drive Client ID + input = self.query_one("#configs_gdrive_client_id_input") + value = ( + "" + if cfg_to_load["gdrive_client_id"] is None + else cfg_to_load["gdrive_client_id"] + ) + input.value = value + + # Google Drive Root Folder ID + input = self.query_one("#configs_gdrive_root_folder_id") + value = ( + "" + if cfg_to_load["gdrive_root_folder_id"] is None + else cfg_to_load["gdrive_root_folder_id"] + ) + input.value = value + + # AWS Access Key ID + input = self.query_one("#configs_aws_access_key_id_input") + value = ( + "" + if cfg_to_load["aws_access_key_id"] is None + else cfg_to_load["aws_access_key_id"] + ) + input.value = value + + # AWS S3 Region + select = self.query_one("#configs_aws_region_select") + value = ( + Select.BLANK + if cfg_to_load["aws_region"] is None + else cfg_to_load["aws_region"] + ) + select.value = value + + def setup_widgets_to_display(self, connection_method: str | None) -> None: + """Set up widgets to display based on the chosen `connection_method` on the radiobutton. + + The widgets pertaining to the chosen connection method will be displayed. + This is done by dedicated functions for each connection method + which display widgets on receiving a `True` flag. 
+ + Also, this function handles other TUI changes like displaying "setup connection" + button, disabling central path input in a local only project, etc. + + Called on mount, on radiobuttons' switch and upon saving project configs. + """ + if connection_method: + assert connection_method in get_connection_methods_list(), ( + "Unexpected Connection Method" + ) + + # Connection specific widgets + connection_widget_display_functions = { + "ssh": self.config_ssh_widgets, + "gdrive": self.config_gdrive_widgets, + "aws": self.config_aws_widgets, + } + + for ( + name, + connection_widgets, + ) in connection_widget_display_functions.items(): + for widget in connection_widgets: + widget.display = connection_method == name + + has_connection_method = connection_method is not None + + # Central Path Input + self.query_one( + "#configs_central_path_input" + ).disabled = not has_connection_method + self.query_one( + "#configs_central_path_select_button" + ).disabled = not has_connection_method + + # Central Path Input Placeholder + if connection_method is None: + self.query_one("#configs_central_path_input").value = "" + self.query_one("#configs_central_path_input").placeholder = "" + else: + placeholder = self.get_platform_dependent_example_paths( + "central", + connection_method, + ) + self.query_one( + "#configs_central_path_input" + ).placeholder = placeholder + + # Central Path Label + central_path_label = self.query_one("#configs_central_path_label") + if connection_method in ["gdrive", "aws"]: + central_path_label.update(content="Central Path (Optional)") + else: + central_path_label.update(content="Central Path") + + # Central Path Select Button + show_central_path_select = connection_method not in [ + "ssh", + "aws", + "gdrive", + ] + self.query_one( + "#configs_central_path_select_button" + ).display = show_central_path_select + + # fmt: off + # Setup connection button + setup_connection_button = self.query_one( + "#configs_setup_connection_button" + ) + + if ( + not 
connection_method + or connection_method == "local_filesystem" + or not self.interface + or connection_method != self.interface.get_configs()["connection_method"] + ): + setup_connection_button.visible = False + # fmt: on + else: + setup_connection_button.visible = True + + if connection_method == "ssh": + setup_connection_button.label = "Setup SSH Connection" + elif connection_method == "gdrive": + setup_connection_button.label = "Setup Google Drive Connection" + elif connection_method == "aws": + setup_connection_button.label = "Setup AWS Connection" + def get_datashuttle_inputs_from_widgets(self) -> Dict: """Get the configs to pass to `make_config_file()` from the current TUI settings.""" cfg_kwargs: Dict[str, Any] = {} @@ -632,30 +858,68 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: else: cfg_kwargs["central_path"] = Path(central_path_value) - if self.query_one("#configs_ssh_radiobutton").value: - connection_method = "ssh" + for id in [ + "configs_local_filesystem_radiobutton", + "configs_ssh_radiobutton", + "configs_gdrive_radiobutton", + "configs_aws_radiobutton", + "configs_local_only_radiobutton", + ]: + if self.query_one("#" + id).value: + connection_method = self.connection_method_from_radiobutton_id( + id + ) + break - elif self.query_one("#configs_local_filesystem_radiobutton").value: - connection_method = "local_filesystem" + cfg_kwargs["connection_method"] = connection_method - elif self.query_one("#configs_local_only_radiobutton").value: - connection_method = None + # SSH specific + if connection_method == "ssh": + cfg_kwargs["central_host_id"] = ( + self.get_config_value_from_input_value( + "#configs_central_host_id_input" + ) + ) - cfg_kwargs["connection_method"] = connection_method + cfg_kwargs["central_host_username"] = ( + self.get_config_value_from_input_value( + "#configs_central_host_username_input" + ) + ) - central_host_id = self.query_one( - "#configs_central_host_id_input" - ).value - cfg_kwargs["central_host_id"] = ( - None 
if central_host_id == "" else central_host_id - ) + # Google Drive specific + elif connection_method == "gdrive": + cfg_kwargs["gdrive_client_id"] = ( + self.get_config_value_from_input_value( + "#configs_gdrive_client_id_input" + ) + ) - central_host_username = self.query_one( - "#configs_central_host_username_input" - ).value + cfg_kwargs["gdrive_root_folder_id"] = ( + self.get_config_value_from_input_value( + "#configs_gdrive_root_folder_id" + ) + ) - cfg_kwargs["central_host_username"] = ( - None if central_host_username == "" else central_host_username - ) + # AWS specific + elif connection_method == "aws": + cfg_kwargs["aws_access_key_id"] = ( + self.get_config_value_from_input_value( + "#configs_aws_access_key_id_input" + ) + ) + + aws_region = self.query_one("#configs_aws_region_select").value + cfg_kwargs["aws_region"] = ( + None if aws_region == Select.BLANK else aws_region + ) return cfg_kwargs + + def get_config_value_from_input_value( + self, input_box_selector: str + ) -> str | None: + """Format the Input value from string to string or `None`.""" + input_value = self.query_one(input_box_selector).value + + return None if input_value == "" else input_value diff --git a/datashuttle/tui/tabs/create_folders.py b/datashuttle/tui/tabs/create_folders.py index 2ec76e72..767e070a 100644 --- a/datashuttle/tui/tabs/create_folders.py +++ b/datashuttle/tui/tabs/create_folders.py @@ -20,6 +20,7 @@ Label, ) +from datashuttle.configs import canonical_configs from datashuttle.tui.custom_widgets import ( ClickableInput, CustomDirectoryTree, @@ -325,11 +326,10 @@ def suggest_next_sub_ses( If `True`, search central project as well to generate the suggestion. 
""" - assert self.interface.project.cfg["connection_method"] in [ - None, - "local_filesystem", - "ssh", - ] + assert ( + self.interface.project.cfg["connection_method"] + in canonical_configs.get_connection_methods_list() + ) if ( include_central diff --git a/datashuttle/tui/tooltips.py b/datashuttle/tui/tooltips.py index f39bb9a4..46d9dcae 100644 --- a/datashuttle/tui/tooltips.py +++ b/datashuttle/tui/tooltips.py @@ -62,6 +62,34 @@ def get_tooltip(id: str) -> str: "to a project folder, possibly on a mounted drive.\n\n" ) + elif id == "config_central_path_input_mode-aws": + tooltip = ( + "The path to the project folder within the aws bucket.\n" + "Leave blank if the aws bucket is the project folder." + ) + + elif id == "config_central_path_input_mode-gdrive": + tooltip = ( + "The path to the project folder within the google drive folder.\n" + "Leave blank if the google drive folder is the project folder." + ) + # Google Drive configs + # ------------------------------------------------------------------------- + + # Google Drive Client ID + elif id == "#configs_gdrive_client_id_input": + tooltip = ( + "The Google Drive Client ID to use for authentication.\n\n" + "It can be obtained by creating an OAuth 2.0 client in the Google Cloud Console.\n\n" + "Can be left empty to use rclone's default client (slower)" + ) + + elif id == "#configs_gdrive_root_folder_id": + tooltip = ( + "The Google Drive root folder ID to use for transfer.\n\n" + "It can be obtained by navigating to the folder in Google Drive and copying the ID from the URL.\n\n" + ) + # Settings # ------------------------------------------------------------------------- diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py new file mode 100644 index 00000000..519a2b6d --- /dev/null +++ b/datashuttle/utils/aws.py @@ -0,0 +1,60 @@ +import json + +from datashuttle.configs.config_class import Configs +from datashuttle.utils import rclone, utils +from datashuttle.utils.custom_exceptions import 
ConfigError + + +def check_if_aws_bucket_exists(cfg: Configs) -> bool: + """Determine whether the AWS bucket actually exists on the server. + + The first part of `cfg["central_path"]` should be an existing bucket name. + """ + output = rclone.call_rclone( + f"lsjson {cfg.get_rclone_config_name()}:", pipe_std=True + ) + + files_and_folders = json.loads(output.stdout) + + names = list(map(lambda x: x.get("Name", None), files_and_folders)) + + bucket_name = get_aws_bucket_name(cfg) + + if bucket_name not in names: + return False + + return True + + +def raise_if_bucket_absent(cfg: Configs) -> None: + """Raise and log error if the AWS bucket is not found on the server.""" + if not check_if_aws_bucket_exists(cfg): + bucket_name = get_aws_bucket_name(cfg) + utils.log_and_raise_error( + f'The bucket "{bucket_name}" does not exist.\n' + f"For data transfer to happen, the bucket must exist.\n" + f"Please change the bucket name in the `central_path`.", + ConfigError, + ) + + +def get_aws_bucket_name(cfg: Configs) -> str: + """Return the formatted AWS bucket name from the `central_path`.""" + return cfg["central_path"].as_posix().strip("/").split("/")[0] + + +# ----------------------------------------------------------------------------- +# For Python API +# ----------------------------------------------------------------------------- + + +def get_aws_secret_access_key(log: bool = True) -> str: + """Return the user-input AWS secret access key.""" + aws_secret_access_key = utils.get_connection_secret_from_user( + connection_method_name="AWS", + key_name_full="AWS secret access key", + key_name_short="secret key", + log_status=log, + ) + + return aws_secret_access_key.strip() diff --git a/datashuttle/utils/decorators.py b/datashuttle/utils/decorators.py index 6fecb6ce..17375884 100644 --- a/datashuttle/utils/decorators.py +++ b/datashuttle/utils/decorators.py @@ -29,6 +29,28 @@ def wrapper(*args, **kwargs): return wrapper +def requires_aws_configs(func): + """Check Amazon Web
Service configs have been set.""" + + @wraps(func) + def wrapper(*args, **kwargs): + if ( + not args[0].cfg["aws_access_key_id"] + or not args[0].cfg["aws_region"] + ): + log_and_raise_error( + "Cannot setup AWS connection, 'aws_access_key_id' " + "or 'aws_region' is not set in the " + "configuration file", + ConfigError, + ) + + else: + return func(*args, **kwargs) + + return wrapper + + def check_configs_set(func): """Check configs have been set.""" diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index a8ff8827..c5abf696 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import ( TYPE_CHECKING, Any, @@ -20,7 +21,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import ssh, utils, validation +from datashuttle.utils import rclone, ssh, utils, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -598,14 +599,27 @@ def search_for_folders( Discovered folders (`all_folder_names`) and files (`all_filenames`). 
""" - if local_or_central == "central" and cfg["connection_method"] == "ssh": - all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, - ) + if local_or_central == "central" and cfg["connection_method"] in [ + "ssh", + "gdrive", + "aws", + ]: + if cfg["connection_method"] == "ssh": + all_folder_names, all_filenames = ( + ssh.search_ssh_central_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + ) + + else: + all_folder_names, all_filenames = search_gdrive_or_aws_for_folders( + search_path, search_prefix, cfg, return_full_path + ) + else: if not search_path.exists(): if verbose: @@ -620,6 +634,57 @@ def search_for_folders( return all_folder_names, all_filenames +def search_gdrive_or_aws_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """Search for files and folders in central path using `rclone lsjson` command. + + This command lists all the files and folders in the central path in a json format. + The json contains file/folder info about each file/folder like name, type, etc. 
+ """ + output = rclone.call_rclone( + "lsjson " + f"{cfg.get_rclone_config_name()}:{search_path.as_posix()} " + f'--include "{search_prefix}"', + pipe_std=True, + ) + + all_folder_names: List[str] = [] + all_filenames: List[str] = [] + + if output.returncode != 0: + utils.log_and_message( + f"Error searching files at {search_path.as_posix()} \n {output.stderr.decode('utf-8') if output.stderr else ''}" + ) + return all_folder_names, all_filenames + + files_and_folders = json.loads(output.stdout) + + try: + for file_or_folder in files_and_folders: + name = file_or_folder["Name"] + is_dir = file_or_folder.get("IsDir", False) + + to_append = ( + (search_path / name).as_posix() if return_full_path else name + ) + + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except Exception: + utils.log_and_message( + f"Error searching files at {search_path.as_posix()}" + ) + + return all_folder_names, all_filenames + + # Actual function implementation def search_filesystem_path_for_folders( search_path_with_prefix: Path, return_full_path: bool = False diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py new file mode 100644 index 00000000..d0247946 --- /dev/null +++ b/datashuttle/utils/gdrive.py @@ -0,0 +1,51 @@ +from datashuttle.utils import utils + +# ----------------------------------------------------------------------------- +# Helper Functions +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# Python API +# ----------------------------------------------------------------------------- + + +def ask_user_for_browser(log: bool = True) -> bool: + """Ask the user if they have access to an internet browser, for Google Drive set up.""" + message = "Are you running Datashuttle on a machine with access to a web browser? 
(y/n): " + input_ = utils.get_user_input(message).lower() + + while input_ not in ["y", "n"]: + utils.print_message_to_user("Invalid input. Press either 'y' or 'n'.") + input_ = utils.get_user_input(message).lower() + + answer = input_ == "y" + + if log: + utils.log(message) + utils.log(f"User answer: {answer}") + + return answer + + +def prompt_and_get_service_account_filepath(log: bool = True): + """Get service account filepath from user.""" + message = "Please enter your service account file path: " + input_ = utils.get_user_input(message).strip() + + if log: + utils.log(message) + utils.log(f"Service account file at: {input_}") + + return input_ + + +def get_client_secret(log: bool = True) -> str: + """Prompt the user for their Google Drive client secret key.""" + gdrive_client_secret = utils.get_connection_secret_from_user( + connection_method_name="Google Drive", + key_name_full="Google Drive client secret", + key_name_short="secret key", + log_status=log, + ) + + return gdrive_client_secret.strip() diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index c6b57b24..59ed05ee 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -1,15 +1,22 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, List, Literal, Optional + +if TYPE_CHECKING: + from pathlib import Path + + from datashuttle.configs.config_class import Configs + from datashuttle.utils.custom_types import TopLevelFolder + import os import platform +import shlex import subprocess import tempfile -from pathlib import Path from subprocess import CompletedProcess -from typing import Dict, List, Literal from datashuttle.configs import canonical_configs -from datashuttle.configs.config_class import Configs from datashuttle.utils import utils -from datashuttle.utils.custom_types import TopLevelFolder def call_rclone(command: str, pipe_std: bool = False) -> CompletedProcess: @@ -88,6 +95,39 @@ def call_rclone_through_script(command: str) 
-> CompletedProcess: return output +def call_rclone_with_popen(command: str) -> subprocess.Popen: + """Call rclone using `subprocess.Popen` for control over process termination. + + It is not possible to kill a process while running it using `subprocess.run`. + Killing a process might be required when running rclone setup in a thread worker + to allow the user to cancel the setup process. In such a case, cancelling the + thread worker alone will not kill the rclone process, so we need to kill the + process explicitly. + """ + command = "rclone " + command + process = subprocess.Popen( + shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + return process + + +def await_call_rclone_with_popen_raise_on_fail( + process: subprocess.Popen, log: bool = True +): + """Await the rclone `subprocess.Popen` call. + + Calling `process.communicate()` waits for the process to complete and returns + the stdout and stderr. + """ + stdout, stderr = process.communicate() + + if process.returncode != 0: + utils.log_and_raise_error(stderr.decode("utf-8"), ConnectionError) + + if log: + log_rclone_config_output() + + # ----------------------------------------------------------------------------- # Setup # ----------------------------------------------------------------------------- @@ -171,6 +211,138 @@ def setup_rclone_config_for_ssh( log_rclone_config_output() +def setup_rclone_config_for_gdrive( + cfg: Configs, + rclone_config_name: str, + gdrive_client_secret: str | None, + service_account_filepath: Optional[str] = None, +) -> subprocess.Popen: + """Set up rclone config for connections to Google Drive. + + This function uses `call_rclone_with_popen` instead of `call_rclone`. This + is done to have more control over the setup process in case the user wishes to + cancel the setup. Since the rclone setup for google drive uses a local web server + for authentication to google drive, the running process must be killed before the + setup can be started again.
+ + Parameters + ---------- + cfg + datashuttle configs UserDict. This must contain the `gdrive_root_folder_id` + and optionally a `gdrive_client_id`, which in turn requires + a Google Drive client secret. + + rclone_config_name + Canonical config name, generated by + datashuttle.cfg.get_rclone_config_name() + + gdrive_client_secret + Google Drive client secret, mandatory when using a Google Drive client. + + service_account_filepath + Path to a service account file, for connecting without a browser. + + """ + client_id_key_value = ( + f"client_id {cfg['gdrive_client_id']} " + if cfg["gdrive_client_id"] + else " " + ) + client_secret_key_value = ( + f"client_secret {gdrive_client_secret} " + if gdrive_client_secret + else "" + ) + + service_account_filepath_arg = ( + "" + if service_account_filepath is None + else f"service_account_file {service_account_filepath}" + ) + + process = call_rclone_with_popen( + f"config create " + f"{rclone_config_name} " + f"drive " + f"{client_id_key_value}" + f"{client_secret_key_value}" + f"scope drive " + f"root_folder_id {cfg['gdrive_root_folder_id']} " + f"{service_account_filepath_arg}", + ) + + return process + + +def setup_rclone_config_for_aws( + cfg: Configs, + rclone_config_name: str, + aws_secret_access_key: str, + log: bool = True, +): + """Set up rclone config for connections to AWS S3 buckets. + + Parameters + ---------- + cfg + datashuttle configs UserDict. + Must contain the `aws_access_key_id` and `aws_region`. + + rclone_config_name + Canonical RClone config name, generated by + datashuttle.cfg.get_rclone_config_name() + + aws_secret_access_key + The aws secret access key provided by the user. + + log + Whether to log, if True logger must already be initialised.
+ + """ + output = call_rclone( + "config create " + f"{rclone_config_name} " + "s3 provider AWS " + f"access_key_id {cfg['aws_access_key_id']} " + f"secret_access_key {aws_secret_access_key} " + f"region {cfg['aws_region']} " + f"location_constraint {cfg['aws_region']}", + pipe_std=True, + ) + + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + if log: + log_rclone_config_output() + + +def check_successful_connection_and_raise_error_on_fail(cfg: Configs) -> None: + """Check for a successful connection by creating a file on the remote. + + If the command fails, it raises a ConnectionError. The created file is + deleted thereafter. + """ + tempfile_path = (cfg["central_path"] / "temp.txt").as_posix() + output = call_rclone( + f"touch {cfg.get_rclone_config_name()}:{tempfile_path}", pipe_std=True + ) + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + output = call_rclone( + f"delete {cfg.get_rclone_config_name()}:{tempfile_path}", pipe_std=True + ) + if output.returncode != 0: + utils.log_and_raise_error( + output.stderr.decode("utf-8"), ConnectionError + ) + + def log_rclone_config_output() -> None: """Log the output from creating Rclone config.""" output = call_rclone("config file", pipe_std=True) diff --git a/datashuttle/utils/ssh.py b/datashuttle/utils/ssh.py index 3cbf1c49..40f43ef5 100644 --- a/datashuttle/utils/ssh.py +++ b/datashuttle/utils/ssh.py @@ -6,9 +6,7 @@ from datashuttle.configs.config_class import Configs import fnmatch -import getpass import stat -import sys from pathlib import Path from typing import Any, List, Optional, Tuple @@ -198,30 +196,16 @@ def setup_ssh_key( log if True, logger must already be initialised. 
""" - if not sys.stdin.isatty(): - proceed = input( - "\nWARNING!\nThe next step is to enter a password, but it is not possible\n" - "to hide your password while entering it in the current terminal.\n" - "This can occur if running the command in an IDE.\n\n" - "Press 'y' to proceed to password entry. " - "The characters will not be hidden!\n" - "Alternatively, run ssh setup after starting Python in your " - "system terminal \nrather than through an IDE: " - ) - if proceed != "y": - utils.print_message_to_user( - "Quitting SSH setup as 'y' not pressed." - ) - return - - password = input( - "Please enter your password. Characters will not be hidden: " - ) - else: - password = getpass.getpass( - "Please enter password to your central host to add the public key. " + password = utils.get_connection_secret_from_user( + connection_method_name="SSH", + key_name_full="password", + key_name_short="password", + key_info=( + "You are required to enter the password to your central host to add the public key. " "You will not have to enter your password again." - ) + ), + log_status=log, + ) add_public_key_to_central_authorized_keys(cfg, password) diff --git a/datashuttle/utils/utils.py b/datashuttle/utils/utils.py index 7cc9df54..b90e2cec 100644 --- a/datashuttle/utils/utils.py +++ b/datashuttle/utils/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations +import getpass import re +import sys import traceback import warnings from typing import TYPE_CHECKING, Any, List, Literal, Union, overload @@ -104,6 +106,73 @@ def get_user_input(message: str) -> str: return input_ +def get_connection_secret_from_user( + connection_method_name: str, + key_name_full: str, + key_name_short: str, + key_info: str | None = None, + log_status: bool = True, +) -> str: + """Get sensitive information input from the user via their terminal. + + This is a centralised function shared across connection methods. + It checks whether the standard input (stdin) is connected to a + terminal or not. 
If not, the user is shown a warning and asked + if they would like to continue. + + Parameters + ---------- + connection_method_name + A string identifying the connection method being used. + + key_name_full + Full name of the connection secret requested from the user. + + key_name_short + Short name of the connection secret to avoid repeatedly writing the full name. + + key_info + Extra info about the connection secret that needs to be communicated to the user. + + log_status + Log if `True`, logger must already be initialised. + + """ + if key_info: + print_message_to_user(key_info) + + if not sys.stdin.isatty(): + proceed = input( + f"\nWARNING!\nThe next step is to enter a {key_name_full}, but it is not possible\n" + f"to hide your {key_name_short} while entering it in the current terminal.\n" + f"This can occur if running the command in an IDE.\n\n" + f"Press 'y' to proceed to {key_name_short} entry. " + f"The characters will not be hidden!\n" + f"Alternatively, run {connection_method_name} setup after starting Python in your " + f"system terminal \nrather than through an IDE: " + ) + if proceed != "y": + print_message_to_user( + f"Quitting {connection_method_name} setup as 'y' not pressed." + ) + log_and_raise_error( + f"{connection_method_name} setup aborted by user.", + ConnectionAbortedError, + ) + + input_ = input( + f"Please enter your {key_name_full}.
Characters will not be hidden: " + ) + + else: + input_ = getpass.getpass(f"Please enter your {key_name_full}: ") + + if log_status: + log(f"{key_name_full} entered by user.") + + return input_ + + # ----------------------------------------------------------------------------- # Paths # ----------------------------------------------------------------------------- diff --git a/tests/test_utils.py b/tests/test_utils.py index 48a7617d..a75fef0d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -659,3 +659,21 @@ def make_project(project_name): project = DataShuttle(project_name) warnings.filterwarnings("default") return project + + +def monkeypatch_get_datashuttle_path(tmp_config_path, _monkeypatch): + """Monkeypatch the function that creates a hidden datashuttle folder. + + By default, datashuttle saves project folders to + Path.home() / .datashuttle. In order to not mess with + the home directory during this test the `get_datashuttle_path()` + function is monkeypatched in order to point to a tmp_path. + """ + + def mock_get_datashuttle_path(): + return tmp_config_path + + _monkeypatch.setattr( + "datashuttle.configs.canonical_folders.get_datashuttle_path", + mock_get_datashuttle_path, + ) diff --git a/tests/tests_integration/test_configs.py b/tests/tests_integration/test_configs.py index 0c8f91a4..2208c2a2 100644 --- a/tests/tests_integration/test_configs.py +++ b/tests/tests_integration/test_configs.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import pytest @@ -222,16 +223,14 @@ def test_existing_projects(self, monkeypatch, tmp_path): """Test existing projects are correctly found based on whether they exist in the home directory and contain a config.yaml. - By default, datashuttle saves project folders to - Path.home() / .datashuttle. In order to not mess with - the home directory during this test the `get_datashuttle_path()` - function is monkeypatched in order to point to a tmp_path. 
- The tmp_path / "projects" is filled with a mix of project folders with and without config, and tested against accordingly. The `local_path` and `central_path` specified in the DataShuttle config are arbitrarily put in `tmp_path`. """ + test_utils.monkeypatch_get_datashuttle_path( + tmp_path / "projects", monkeypatch + ) def patch_get_datashuttle_path(): return tmp_path / "projects" @@ -266,6 +265,69 @@ def patch_get_datashuttle_path(): (tmp_path / "projects" / "project_3"), ] + # Test Connection Method + # ------------------------------------------------------------- + + @pytest.mark.parametrize( + "connection_method", [None, "local_filesystem", "ssh", "gdrive", "aws"] + ) + def test_connection_method_required_args( + self, tmp_path, no_cfg_project, connection_method + ): + """Test that error is not raised when `central_path` is `None` for `aws` and `gdrive`.""" + if connection_method in ["local_filesystem", "ssh"]: + with pytest.raises(ConfigError): + no_cfg_project.make_config_file( + local_path=tmp_path / "local", + connection_method=connection_method, + central_path=None, + ) + else: + no_cfg_project.make_config_file( + local_path=tmp_path / "local", + connection_method=connection_method, + central_path=None, + gdrive_root_folder_id="placeholder", + aws_access_key_id="placeholder", + aws_region="us-east-1", + ) + assert no_cfg_project.cfg["central_path"] is None + + @pytest.mark.parametrize( + "connection_method", ["local_filesystem", "ssh", "gdrive", "aws"] + ) + def test_get_base_folder( + self, tmp_path, no_cfg_project, connection_method + ): + """Test central `get_base_folder()` which can depend on the connection method. + + For `aws` and `gdrive`, it is possible for `central_path` to be `None` as + the folder may be the project folder itself. In this case, check that + `get_base_folder()` returns the expected path.
+ """ + project_name = no_cfg_project.project_name + central_path = ( + None if connection_method in ["gdrive", "aws"] else tmp_path + ) + + no_cfg_project.make_config_file( + local_path=tmp_path / "local", + connection_method=connection_method, + central_path=central_path, + central_host_id="placeholder", + central_host_username="placeholder", + gdrive_root_folder_id="placeholder", + aws_access_key_id="placeholder", + aws_region="us-east-1", + ) + + folder = no_cfg_project.cfg.get_base_folder("central", "rawdata") + + if connection_method in ["ssh", "local_filesystem"]: + assert folder == central_path / project_name / "rawdata" + else: + assert folder == Path("rawdata") + # ------------------------------------------------------------------------- # Utils # ------------------------------------------------------------------------- diff --git a/tests/tests_transfers/ssh/ssh_test_utils.py b/tests/tests_transfers/ssh/ssh_test_utils.py index 5cbf141b..e785ed11 100644 --- a/tests/tests_transfers/ssh/ssh_test_utils.py +++ b/tests/tests_transfers/ssh/ssh_test_utils.py @@ -6,7 +6,7 @@ import paramiko -from datashuttle.utils import rclone, ssh +from datashuttle.utils import rclone, ssh, utils def setup_project_for_ssh( @@ -43,8 +43,8 @@ def setup_ssh_connection(project, setup_ssh_key_pair=True): orig_builtin = copy.deepcopy(builtins.input) builtins.input = lambda _: "y" # type: ignore - orig_getpass = copy.deepcopy(ssh.getpass.getpass) - ssh.getpass.getpass = lambda _: "password" # type: ignore + orig_get_secret = copy.deepcopy(utils.get_connection_secret_from_user) + utils.get_connection_secret_from_user = lambda *args, **kwargs: "password" # type: ignore orig_isatty = copy.deepcopy(sys.stdin.isatty) sys.stdin.isatty = lambda: True @@ -59,7 +59,7 @@ def setup_ssh_connection(project, setup_ssh_key_pair=True): # Restore functions builtins.input = orig_builtin - ssh.getpass.getpass = orig_getpass + utils.get_connection_secret_from_user = orig_get_secret sys.stdin.isatty = 
orig_isatty rclone.setup_rclone_config_for_ssh( diff --git a/tests/tests_tui/test_local_only_project.py b/tests/tests_tui/test_local_only_project.py index a9d47594..f4394499 100644 --- a/tests/tests_tui/test_local_only_project.py +++ b/tests/tests_tui/test_local_only_project.py @@ -35,6 +35,10 @@ async def test_local_only_make_project( "connection_method": None, "central_host_id": None, "central_host_username": None, + "gdrive_client_id": None, + "gdrive_root_folder_id": None, + "aws_access_key_id": None, + "aws_region": None, } assert pilot.app.screen.query_one( "#placeholder_transfer_tab" @@ -112,6 +116,10 @@ async def test_local_project_to_full( "connection_method": "ssh", "central_host_id": "some_host", "central_host_username": "some_username", + "gdrive_client_id": None, + "gdrive_root_folder_id": None, + "aws_access_key_id": None, + "aws_region": None, } await pilot.pause() @@ -174,6 +182,10 @@ async def test_full_project_to_local( "connection_method": None, "central_host_id": None, "central_host_username": None, + "gdrive_client_id": None, + "gdrive_root_folder_id": None, + "aws_access_key_id": None, + "aws_region": None, } await pilot.pause() diff --git a/tests/tests_tui/test_tui_configs.py b/tests/tests_tui/test_tui_configs.py index 5a335314..aa63554e 100644 --- a/tests/tests_tui/test_tui_configs.py +++ b/tests/tests_tui/test_tui_configs.py @@ -107,9 +107,9 @@ async def test_make_new_project_configs( await self.close_messagebox(pilot) assert ( pilot.app.screen.query_one( - "#configs_setup_ssh_connection_button" - ).visible - is True + "#configs_setup_connection_button" + ).label + == "Setup SSH Connection" ) else: assert ( diff --git a/tests/tests_tui/test_tui_widgets_and_defaults.py b/tests/tests_tui/test_tui_widgets_and_defaults.py index 13d9511a..92a04a83 100644 --- a/tests/tests_tui/test_tui_widgets_and_defaults.py +++ b/tests/tests_tui/test_tui_widgets_and_defaults.py @@ -210,7 +210,7 @@ async def check_new_project_ssh_widgets( self, configs_content, 
ssh_on, save_pressed=False ): assert configs_content.query_one( - "#configs_setup_ssh_connection_button" + "#configs_setup_connection_button" ).visible is ( ssh_on and save_pressed ) # Only enabled after project creation. diff --git a/tests/tests_tui/tui_base.py b/tests/tests_tui/tui_base.py index 4452b52a..05bb2556 100644 --- a/tests/tests_tui/tui_base.py +++ b/tests/tests_tui/tui_base.py @@ -27,12 +27,23 @@ def tui_size(self): async def empty_project_paths(self, tmp_path_factory, monkeypatch): """Get the paths and project name for a non-existent (i.e. not yet setup) project. + + For these tests, store the datashuttle configs (usually stored in + Path.home()) in the `tmp_path` provided by pytest, as it simplifies + testing here. + + This is not done for general tests because + 1) It is further from the actual datashuttle behaviour + 2) It fails for testing CLI, because CLI spawns a new process in + which `get_datashuttle_path()` is not monkeypatched. """ project_name = "my-test-project" tmp_path = tmp_path_factory.mktemp("test") tmp_config_path = tmp_path / "config" - self.monkeypatch_get_datashuttle_path(tmp_config_path, monkeypatch) + test_utils.monkeypatch_get_datashuttle_path( + tmp_config_path, monkeypatch + ) self.monkeypatch_print(monkeypatch) assert not any(list(tmp_config_path.glob("**"))) @@ -53,25 +64,6 @@ async def setup_project_paths(self, empty_project_paths): return empty_project_paths - def monkeypatch_get_datashuttle_path(self, tmp_config_path, _monkeypatch): - """For these tests, store the datashuttle configs (usually stored in - Path.home()) in the `tmp_path` provided by pytest, as it simplifies - testing here. - - This is not done for general tests because - 1) It is further from the actual datashuttle behaviour - 2) It fails for testing CLI, because CLI spawns a new process in - which `get_datashuttle_path()` is not monkeypatched. 
- """ - - def mock_get_datashuttle_path(): - return tmp_config_path - - _monkeypatch.setattr( - "datashuttle.configs.canonical_folders.get_datashuttle_path", - mock_get_datashuttle_path, - ) - def monkeypatch_print(self, _monkeypatch): """Calls to `print` in datashuttle crash the TUI in the test environment. I am not sure why. Get around this