diff --git a/patchwork/app.py b/patchwork/app.py index 4149959a6..a7429243a 100644 --- a/patchwork/app.py +++ b/patchwork/app.py @@ -59,6 +59,7 @@ def list_option_callback(ctx: click.Context, param: click.Parameter, value: str def find_patchflow(possible_module_paths: Iterable[str], patchflow: str) -> Any | None: + allowed_modules = {"module1", "module2"} # Example whitelist of allowed modules for module_path in possible_module_paths: try: spec = importlib.util.spec_from_file_location("custom_module", module_path) @@ -71,14 +72,15 @@ def find_patchflow(possible_module_paths: Iterable[str], patchflow: str) -> Any except Exception: logger.debug(f"Patchflow {patchflow} not found as a file/directory in {module_path}") - try: - module = importlib.import_module(module_path) - logger.info(f"Patchflow {patchflow} loaded from {module_path}") - return getattr(module, patchflow) - except ModuleNotFoundError: - logger.debug(f"Patchflow {patchflow} not found as a module in {module_path}") - except AttributeError: - logger.debug(f"Patchflow {patchflow} not found in {module_path}") + if module_path in allowed_modules: + try: + module = importlib.import_module(module_path) + logger.info(f"Patchflow {patchflow} loaded from {module_path}") + return getattr(module, patchflow) + except ModuleNotFoundError: + logger.debug(f"Patchflow {patchflow} not found as a module in {module_path}") + except AttributeError: + logger.debug(f"Patchflow {patchflow} not found in {module_path}") return None diff --git a/patchwork/common/tools/bash_tool.py b/patchwork/common/tools/bash_tool.py index 8440f179a..9fb62fcbf 100644 --- a/patchwork/common/tools/bash_tool.py +++ b/patchwork/common/tools/bash_tool.py @@ -44,8 +44,9 @@ def execute( return f"Error: `command` parameter must be set and cannot be empty" try: + cmd_list = command.split() result = subprocess.run( - command, shell=True, cwd=self.path, capture_output=True, text=True, timeout=60 # Add timeout for safety + cmd_list, shell=False, cwd=self.path, capture_output=True, text=True, timeout=60 # Add timeout for safety ) return result.stdout if result.returncode == 0 else f"Error: {result.stderr}" except subprocess.TimeoutExpired: diff --git a/patchwork/common/tools/csvkit_tool.py b/patchwork/common/tools/csvkit_tool.py index a1ef8dc59..4f9566466 100644 --- a/patchwork/common/tools/csvkit_tool.py +++ b/patchwork/common/tools/csvkit_tool.py @@ -118,14 +118,15 @@ def execute(self, files: list[str], query: str) -> str: if db_path.is_file(): with sqlite3.connect(str(db_path)) as conn: for file in files: + table_name = file.removesuffix('.csv') res = conn.execute( - f"SELECT 1 from {file.removesuffix('.csv')}", + "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", (table_name,) ) if res.fetchone() is None: files_to_insert.append(file) else: files_to_insert = files - + if len(files_to_insert) > 0: p = subprocess.run( ["csvsql", *files_to_insert, "--db", db_url, "--insert"], capture_output=True, text=True, cwd=self.path diff --git a/patchwork/common/utils/dependency.py b/patchwork/common/utils/dependency.py index 27b89bfed..8044849d0 100644 --- a/patchwork/common/utils/dependency.py +++ b/patchwork/common/utils/dependency.py @@ -6,9 +6,12 @@ "notification": ["slack_sdk"], } +__WHITELIST = set(dep for deps in __DEPENDENCY_GROUPS.values() for dep in deps) @lru_cache(maxsize=None) def import_with_dependency_group(name): + if name not in __WHITELIST: + raise ImportError(f"Attempt to import untrusted or unavailable module: {name}") try: return importlib.import_module(name) except ImportError: @@ -20,6 +23,5 @@ def import_with_dependency_group(name): error_msg = f"Please `pip install patchwork-cli[{dependency_group}]` to use this step" raise ImportError(error_msg) - def slack_sdk(): return import_with_dependency_group("slack_sdk") diff --git a/patchwork/common/utils/step_typing.py b/patchwork/common/utils/step_typing.py index d349f7fc1..f143327fe 100644 --- a/patchwork/common/utils/step_typing.py +++ b/patchwork/common/utils/step_typing.py @@ -106,7 +106,10 @@ def validate_step_type_config_with_inputs( def validate_step_with_inputs(input_keys: Set[str], step: Type[Step]) -> Tuple[Set[str], Dict[str, str]]: + module_whitelist = {"allowed.module1", "allowed.module2"} module_path, _, _ = step.__module__.rpartition(".") + if module_path not in module_whitelist: + raise ImportError(f"Unauthorized module import attempt: {module_path}") step_name = step.__name__ type_module = importlib.import_module(f"{module_path}.typed") step_input_model = getattr(type_module, f"{step_name}Inputs", __NOT_GIVEN) diff --git a/patchwork/steps/CallShell/CallShell.py b/patchwork/steps/CallShell/CallShell.py index 98ee55a74..b93592ae9 100644 --- a/patchwork/steps/CallShell/CallShell.py +++ b/patchwork/steps/CallShell/CallShell.py @@ -46,7 +46,8 @@ def __parse_env_text(env_text: str) -> dict[str, str]: return env def run(self) -> dict: - p = subprocess.run(self.script, shell=True, capture_output=True, text=True, cwd=self.working_dir, env=self.env) + script_args = shlex.split(self.script) + p = subprocess.run(script_args, shell=False, capture_output=True, text=True, cwd=self.working_dir, env=self.env) try: p.check_returncode() except subprocess.CalledProcessError as e: