From 9396e1dae547e3eb5ed86515387cb2674522a384 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Jun 2025 20:52:12 +0000 Subject: [PATCH 1/2] [pre-commit.ci] pre-commit suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/PyCQA/docformatter: 06907d0267368b49b9180eed423fae5697c1e909 → v1.7.7](https://github.com/PyCQA/docformatter/compare/06907d0267368b49b9180eed423fae5697c1e909...v1.7.7) - [github.com/executablebooks/mdformat: 0.7.19 → 0.7.22](https://github.com/executablebooks/mdformat/compare/0.7.19...0.7.22) - [github.com/astral-sh/ruff-pre-commit: v0.8.3 → v0.11.13](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.3...v0.11.13) --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb5d4fda..9a738be1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,14 +27,14 @@ repos: - id: detect-private-key - repo: https://github.com/PyCQA/docformatter - rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5 + rev: v1.7.7 # todo: fix for docformatter after last 1.7.5 hooks: - id: docformatter additional_dependencies: [tomli] args: ["--in-place"] - repo: https://github.com/executablebooks/mdformat - rev: 0.7.19 + rev: 0.7.22 hooks: - id: mdformat additional_dependencies: @@ -48,7 +48,7 @@ repos: ) - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.3 + rev: v0.11.13 hooks: - id: ruff args: ["--fix"] From a3f7f0cf305c19cf00f1e56aabb820fa4508f371 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Jun 2025 20:52:21 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_tabular/categorical_encoders.py | 12 ++-- src/pytorch_tabular/config/config.py | 12 ++-- .../models/category_embedding/config.py | 2 +- .../models/common/layers/embeddings.py | 36 +++++----- src/pytorch_tabular/models/gate/config.py | 2 +- src/pytorch_tabular/models/gate/gate_model.py | 12 ++-- .../models/mixture_density/config.py | 6 +- src/pytorch_tabular/models/tabnet/config.py | 2 +- .../models/tabnet/tabnet_model.py | 2 +- src/pytorch_tabular/ssl_models/base_model.py | 6 +- .../ssl_models/common/layers.py | 6 +- src/pytorch_tabular/ssl_models/dae/config.py | 6 +- src/pytorch_tabular/tabular_model.py | 66 ++++++++----------- src/pytorch_tabular/tabular_model_sweep.py | 48 +++++++------- src/pytorch_tabular/tabular_model_tuner.py | 14 ++-- src/pytorch_tabular/utils/nn_utils.py | 4 +- src/pytorch_tabular/utils/python_utils.py | 2 +- 17 files changed, 114 insertions(+), 124 deletions(-) diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py index b3d7a1ee..4376356b 100644 --- a/src/pytorch_tabular/categorical_encoders.py +++ b/src/pytorch_tabular/categorical_encoders.py @@ -58,9 +58,9 @@ def transform(self, X): raise ValueError("`fit` method must be called before `transform`.") assert all(c in X.columns for c in self.cols) if self.handle_missing == "error": - assert ( - not X[self.cols].isnull().any().any() - ), "`handle_missing` = `error` and missing values found in columns to encode." + assert not X[self.cols].isnull().any().any(), ( + "`handle_missing` = `error` and missing values found in columns to encode." 
+            )
         X_encoded = X.copy(deep=True)
         category_cols = X_encoded.select_dtypes(include="category").columns
         X_encoded[category_cols] = X_encoded[category_cols].astype("object")
@@ -153,9 +153,9 @@ def fit(self, X, y=None):
         """
         self._before_fit_check(X, y)
         if self.handle_missing == "error":
-            assert (
-                not X[self.cols].isnull().any().any()
-            ), "`handle_missing` = `error` and missing values found in columns to encode."
+            assert not X[self.cols].isnull().any().any(), (
+                "`handle_missing` = `error` and missing values found in columns to encode."
+            )
         for col in self.cols:
             map = Series(unique(X[col].fillna(NAN_CATEGORY)), name=col).reset_index().rename(columns={"index": "value"})
             map["value"] += 1
diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index 999c2c4a..ed8235b2 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -192,9 +192,9 @@ class DataConfig:
     )

     def __post_init__(self):
-        assert (
-            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
-        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
+        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
+            "There should be at-least one feature defined in categorical, continuous, or date columns"
+        )
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:

     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all(
-                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
-            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
+                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            )
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0
diff --git a/src/pytorch_tabular/models/category_embedding/config.py b/src/pytorch_tabular/models/category_embedding/config.py
index 99b77b29..02411a13 100644
--- a/src/pytorch_tabular/models/category_embedding/config.py
+++ b/src/pytorch_tabular/models/category_embedding/config.py
@@ -98,7 +98,7 @@ class CategoryEmbeddingModelConfig(ModelConfig):
     )
     use_batch_norm: bool = field(
         default=False,
-        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut." " Defaults to False")},
+        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False")},
     )
     initialization: str = field(
         default="kaiming",
diff --git a/src/pytorch_tabular/models/common/layers/embeddings.py b/src/pytorch_tabular/models/common/layers/embeddings.py
index bbc7f2ba..a8e44e61 100644
--- a/src/pytorch_tabular/models/common/layers/embeddings.py
+++ b/src/pytorch_tabular/models/common/layers/embeddings.py
@@ -84,12 +84,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert (
-            categorical_data.shape[1] == self.categorical_dim
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == self.categorical_dim, (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
@@ -141,12 +141,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
@@ -273,12 +273,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             cont_idx = torch.arange(self.continuous_dim, device=continuous_data.device).expand(
diff --git a/src/pytorch_tabular/models/gate/config.py b/src/pytorch_tabular/models/gate/config.py
index b8ba9729..03840546 100644
--- a/src/pytorch_tabular/models/gate/config.py
+++ b/src/pytorch_tabular/models/gate/config.py
@@ -173,7 +173,7 @@ def __post_init__(self):
         assert self.tree_depth > 0, "tree_depth should be greater than 0"
         # Either gflu_stages or num_trees should be greater than 0
         assert self.num_trees > 0, (
-            "`num_trees` must be greater than 0." "If you want a lighter model which performs better, use GANDALF."
+            "`num_trees` must be greater than 0. If you want a lighter model which performs better, use GANDALF."
         )
         super().__post_init__()
diff --git a/src/pytorch_tabular/models/gate/gate_model.py b/src/pytorch_tabular/models/gate/gate_model.py
index 7acf7065..07d25bb2 100644
--- a/src/pytorch_tabular/models/gate/gate_model.py
+++ b/src/pytorch_tabular/models/gate/gate_model.py
@@ -51,12 +51,12 @@ def __init__(
         embedding_dropout: float = 0.0,
     ):
         super().__init__()
-        assert (
-            binning_activation in self.BINARY_ACTIVATION_MAP.keys()
-        ), f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
-        assert (
-            feature_mask_function in self.ACTIVATION_MAP.keys()
-        ), f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        assert binning_activation in self.BINARY_ACTIVATION_MAP.keys(), (
+            f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
+        )
+        assert feature_mask_function in self.ACTIVATION_MAP.keys(), (
+            f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        )

         self.gflu_stages = gflu_stages
         self.num_trees = num_trees
diff --git a/src/pytorch_tabular/models/mixture_density/config.py b/src/pytorch_tabular/models/mixture_density/config.py
index 428e5871..409927d6 100644
--- a/src/pytorch_tabular/models/mixture_density/config.py
+++ b/src/pytorch_tabular/models/mixture_density/config.py
@@ -87,9 +87,9 @@ class MDNConfig(ModelConfig):
     _probabilistic: bool = field(default=True)

     def __post_init__(self):
-        assert (
-            self.backbone_config_class not in INCOMPATIBLE_BACKBONES
-        ), f"{self.backbone_config_class} is not a supported backbone for MDN head"
+        assert self.backbone_config_class not in INCOMPATIBLE_BACKBONES, (
+            f"{self.backbone_config_class} is not a supported backbone for MDN head"
+        )
         assert self.head == "MixtureDensityHead"
         return super().__post_init__()
diff --git a/src/pytorch_tabular/models/tabnet/config.py b/src/pytorch_tabular/models/tabnet/config.py
index c1142273..ea55bb5b 100644
--- a/src/pytorch_tabular/models/tabnet/config.py
+++ b/src/pytorch_tabular/models/tabnet/config.py
@@ -90,7 +90,7 @@ class TabNetModelConfig(ModelConfig):
     )
     gamma: float = field(
         default=1.3,
-        metadata={"help": ("Float above 1, scaling factor for attention updates (usually between" " 1.0 to 2.0)")},
+        metadata={"help": ("Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0)")},
     )
     n_independent: int = field(
         default=2,
diff --git a/src/pytorch_tabular/models/tabnet/tabnet_model.py b/src/pytorch_tabular/models/tabnet/tabnet_model.py
index a11a117e..f70d1d05 100644
--- a/src/pytorch_tabular/models/tabnet/tabnet_model.py
+++ b/src/pytorch_tabular/models/tabnet/tabnet_model.py
@@ -104,4 +104,4 @@ def _build_network(self):
         self._head = nn.Identity()

     def extract_embedding(self):
-        raise ValueError("Extracting Embeddings is not supported by Tabnet. Please use another" " compatible model")
+        raise ValueError("Extracting Embeddings is not supported by Tabnet. Please use another compatible model")
diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py
index 6b9150a7..75773e75 100644
--- a/src/pytorch_tabular/ssl_models/base_model.py
+++ b/src/pytorch_tabular/ssl_models/base_model.py
@@ -85,9 +85,9 @@ def __init__(
         self._setup_metrics()

     def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
-        assert (encoder is not None) or (
-            encoder_config is not None
-        ), "Either encoder or encoder_config must be provided"
+        assert (encoder is not None) or (encoder_config is not None), (
+            "Either encoder or encoder_config must be provided"
+        )
         # assert (decoder is not None) or (decoder_config is not None),
         # "Either decoder or decoder_config must be provided"
         if encoder is not None:
diff --git a/src/pytorch_tabular/ssl_models/common/layers.py b/src/pytorch_tabular/ssl_models/common/layers.py
index a3e8f97e..e03aeb17 100644
--- a/src/pytorch_tabular/ssl_models/common/layers.py
+++ b/src/pytorch_tabular/ssl_models/common/layers.py
@@ -79,9 +79,9 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
         assert categorical_data.shape[1] == len(
             self._onehot_feat_idx + self._binary_feat_idx + self._embedding_feat_idx
         ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         # embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
diff --git a/src/pytorch_tabular/ssl_models/dae/config.py b/src/pytorch_tabular/ssl_models/dae/config.py
index b1f74885..00f007d6 100644
--- a/src/pytorch_tabular/ssl_models/dae/config.py
+++ b/src/pytorch_tabular/ssl_models/dae/config.py
@@ -125,9 +125,9 @@ class DenoisingAutoEncoderConfig(SSLModelConfig):

     def __post_init__(self):
         assert hasattr(self.encoder_config, "_backbone_name"), "encoder_config should have a _backbone_name attribute"
         if self.decoder_config is not None:
-            assert hasattr(
-                self.decoder_config, "_backbone_name"
-            ), "decoder_config should have a _backbone_name attribute"
+            assert hasattr(self.decoder_config, "_backbone_name"), (
+                "decoder_config should have a _backbone_name attribute"
+            )
         super().__post_init__()
diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py
index 0b34adf4..08d95aa2 100644
--- a/src/pytorch_tabular/tabular_model.py
+++ b/src/pytorch_tabular/tabular_model.py
@@ -140,7 +140,7 @@ def __init__(
             optimizer_config = self._read_parse_config(optimizer_config, OptimizerConfig)
             if model_config.task != "ssl":
                 assert data_config.target is not None, (
-                    "`target` in data_config should not be None for" f" {model_config.task} task"
+                    f"`target` in data_config should not be None for {model_config.task} task"
                 )
             if experiment_config is None:
                 if self.verbose:
@@ -284,9 +284,7 @@ def _setup_experiment_tracking(self):
                 offline=False,
             )
         else:
-            raise NotImplementedError(
-                f"{self.config.log_target} is not implemented. Try one of [wandb," " tensorboard]"
-            )
+            raise NotImplementedError(f"{self.config.log_target} is not implemented. Try one of [wandb, tensorboard]")

     def _prepare_callbacks(self, callbacks=None) -> List:
         """Prepares the necesary callbacks to the Trainer based on the configuration.
@@ -374,11 +372,9 @@ def _check_and_set_target_transform(self, target_transform):
         elif isinstance(target_transform, TransformerMixin):
             pass
         else:
-            raise ValueError(
-                "`target_transform` should wither be an sklearn Transformer or a" " tuple of callables."
-            )
+            raise ValueError("`target_transform` should either be an sklearn Transformer or a tuple of callables.")
         if self.config.task == "classification" and target_transform is not None:
-            logger.warning("For classification task, target transform is not used. Ignoring the" " parameter")
+            logger.warning("For classification task, target transform is not used. Ignoring the parameter")
             target_transform = None
         return target_transform
@@ -772,12 +768,12 @@ def fit(
         """
         assert self.config.task != "ssl", (
-            "`fit` is not valid for SSL task. Please use `pretrain` for" " semi-supervised learning"
+            "`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning"
         )
         if metrics is not None:
-            assert len(metrics) == len(
-                metrics_prob_inputs or []
-            ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
+            assert len(metrics) == len(metrics_prob_inputs or []), (
+                "The length of `metrics` and `metrics_prob_inputs` should be equal"
+            )
         seed = seed or self.config.seed
         if seed:
             seed_everything(seed)
@@ -855,7 +851,7 @@ def pretrain(
         """
         assert self.config.task == "ssl", (
-            f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" " instead."
+            f"`pretrain` is not valid for {self.config.task} task. Please use `fit` instead."
         )
         seed = seed or self.config.seed
         if seed:
@@ -976,9 +972,9 @@ def create_finetune_model(
         config = self.config
         optimizer_params = optimizer_params or {}
         if target is None:
-            assert (
-                hasattr(config, "target") and config.target is not None
-            ), "`target` cannot be None if it was not set in the initial `DataConfig`"
+            assert hasattr(config, "target") and config.target is not None, (
+                "`target` cannot be None if it was not set in the initial `DataConfig`"
+            )
         else:
             assert isinstance(target, list), "`target` should be a list of strings"
             config.target = target
@@ -1001,7 +997,7 @@ def create_finetune_model(
         if self.track_experiment:
             # Renaming the experiment run so that a different log is created for finetuning
             if self.verbose:
-                logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}")
+                logger.info(f"Renaming the experiment run for finetuning as {config['run_name'] + '_finetuned'}")
             config["run_name"] = config["run_name"] + "_finetuned"
         config_override = {"target": target} if target is not None else {}
@@ -1106,7 +1102,7 @@ def finetune(
         """
         assert self._is_finetune_model, (
-            "finetune() can only be called on a finetune model created using" " `TabularModel.create_finetune_model()`"
+            "finetune() can only be called on a finetune model created using `TabularModel.create_finetune_model()`"
         )
         seed_everything(self.config.seed)
         if freeze_backbone:
@@ -1294,7 +1290,7 @@
                     )
                     if is_probabilistic:
                         for j, q in enumerate(quantiles):
-                            col_ = f"{target_col}_q{int(q*100)}"
+                            col_ = f"{target_col}_q{int(q * 100)}"
                             pred_df[col_] = self.datamodule.target_transforms[i].inverse_transform(
                                 quantile_predictions[:, j, i].reshape(-1, 1)
                             )
@@ -1302,7 +1298,7 @@
                     pred_df[f"{target_col}_prediction"] = point_predictions[:, i]
                     if is_probabilistic:
                         for j, q in enumerate(quantiles):
-                            pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)
+                            pred_df[f"{target_col}_q{int(q * 100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)
         elif self.config.task == "classification":
             start_index = 0
@@ -1483,7 +1479,7 @@ def predict(
             "min",
             "max",
             "hard_voting",
-        ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
+        ], "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
         if self.config.task == "regression":
             assert aggregate_tta != "hard_voting", "hard_voting is only available for classification"
@@ -1538,11 +1534,9 @@ def load_best_model(self) -> None:
                 ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
                 self.model.load_state_dict(ckpt["state_dict"])
             else:
-                logger.warning("No best model available to load. Did you run it more than 1" " epoch?...")
+                logger.warning("No best model available to load. Did you run it more than 1 epoch?...")
         else:
-            logger.warning(
-                "No best model available to load. Checkpoint Callback needs to be" " enabled for this to work"
-            )
+            logger.warning("No best model available to load. Checkpoint Callback needs to be enabled for this to work")

     def save_datamodule(self, dir: str, inference_only: bool = False) -> None:
         """Saves the datamodule in the specified directory.
@@ -1707,7 +1701,7 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
         summary_str += "Config\n"
         summary_str += "-" * 100 + "\n"
         summary_str += pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True)
-        summary_str += "\nFull Model Summary once model has been " "initialized or passed in as an argument"
+        summary_str += "\nFull Model Summary once model has been initialized or passed in as an argument"
         return summary_str

     def __str__(self) -> str:
@@ -1936,9 +1930,7 @@ def _prepare_baselines_captum(
             else:
                 baselines = baselines.mean(dim=0, keepdim=True)
         else:
-            raise ValueError(
-                "Invalid value for `baselines`. Please refer to the documentation" " for more details."
-            )
+            raise ValueError("Invalid value for `baselines`. Please refer to the documentation for more details.")
         return baselines

     def _handle_categorical_embeddings_attributions(
@@ -2061,9 +2053,7 @@ def explain(
             hasattr(self.model.hparams, "embedding_dims") and self.model.hparams.embedding_dims is not None
         )
         if (not is_embedding1d) and (not is_embedding2d):
-            raise NotImplementedError(
-                "Attributions are not implemented for models with this type of" " embedding layer"
-            )
+            raise NotImplementedError("Attributions are not implemented for models with this type of embedding layer")
         test_dl = self.datamodule.prepare_inference_dataloader(data)
         self.model.eval()
         # prepare import for Captum
@@ -2095,7 +2085,7 @@ def explain(
                 "Something went wrong. The number of features in the attributions"
                 f" ({attributions.shape[1]}) does not match the number of features in"
                 " the model"
-                f" ({self.model.hparams.continuous_dim+self.model.hparams.categorical_dim})"
+                f" ({self.model.hparams.continuous_dim + self.model.hparams.categorical_dim})"
             )
             return pd.DataFrame(
                 attributions.detach().cpu().numpy(),
@@ -2215,7 +2205,7 @@ def cross_validate(
         oof_preds = []
         for fold, (train_idx, val_idx) in it:
             if verbose:
-                logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
+                logger.info(f"Running Fold {fold + 1}/{cv.get_n_splits()}")
             # train_fold = train.iloc[train_idx]
             # val_fold = train.iloc[val_idx]
             if reset_datamodule:
@@ -2247,7 +2237,7 @@ def cross_validate(
                 result = self.evaluate(train.iloc[val_idx], verbose=False)
                 cv_metrics.append(result[0][metric])
             if verbose:
-                logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
+                logger.info(f"Fold {fold + 1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
             self.model.reset_weights()
         return cv_metrics, oof_preds
@@ -2376,7 +2366,7 @@ def bagging_predict(
         ], "Bagging is only available for classification and regression"
         if not callable(aggregate):
             assert aggregate in ["mean", "median", "min", "max", "hard_voting"], (
-                "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
+                "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
             )
         if self.config.task == "regression":
             assert aggregate != "hard_voting", "hard_voting is only available for classification"
@@ -2387,7 +2377,7 @@ def bagging_predict(
         model = None
         for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)):
             if verbose:
-                logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
+                logger.info(f"Running Fold {fold + 1}/{cv.get_n_splits()}")
             train_fold = train.iloc[train_idx]
             val_fold = train.iloc[val_idx]
             if reset_datamodule:
@@ -2412,7 +2402,7 @@ def bagging_predict(
             elif self.config.task == "regression":
                 pred_prob_l.append(fold_preds.values)
             if verbose:
-                logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
+                logger.info(f"Fold {fold + 1}/{cv.get_n_splits()} prediction done")
             self.model.reset_weights()
         pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
         if return_raw_predictions:
diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py
index b6e749f2..3a8896e5 100644
--- a/src/pytorch_tabular/tabular_model_sweep.py
+++ b/src/pytorch_tabular/tabular_model_sweep.py
@@ -100,55 +100,55 @@ def _validate_args(
     assert isinstance(train, pd.DataFrame), f"train must be a pandas DataFrame, but got {type(train)}"
     assert isinstance(test, pd.DataFrame), f"test must be a pandas DataFrame, but got {type(test)}"
     assert model_list is not None, "models cannot be None"
-    assert isinstance(
-        model_list, (str, list)
-    ), f"models must be a string or list of strings, but got {type(model_list)}"
+    assert isinstance(model_list, (str, list)), (
+        f"models must be a string or list of strings, but got {type(model_list)}"
+    )
     if isinstance(model_list, str):
-        assert (
-            model_list in MODEL_SWEEP_PRESETS.keys()
-        ), f"models must be one of {MODEL_SWEEP_PRESETS.keys()}, but got {model_list}"
+        assert model_list in MODEL_SWEEP_PRESETS.keys(), (
+            f"models must be one of {MODEL_SWEEP_PRESETS.keys()}, but got {model_list}"
+        )
     else:  # isinstance(models, list):
-        assert all(
-            isinstance(m, (str, ModelConfig)) for m in model_list
-        ), f"models must be a list of strings or ModelConfigs, but got {model_list}"
+        assert all(isinstance(m, (str, ModelConfig)) for m in model_list), (
+            f"models must be a list of strings or ModelConfigs, but got {model_list}"
+        )
         assert all(task == m.task for m in model_list if isinstance(m, ModelConfig)), (
            f"task must be the same as the task in ModelConfig, but got {task} and"
            f" {[m.task for m in model_list if isinstance(m, ModelConfig)]}"
         )
     if metrics is not None:
         assert isinstance(metrics, list), f"metrics must be a list of strings or callables, but got {type(metrics)}"
-        assert all(
-            isinstance(m, (str, Callable)) for m in metrics
-        ), f"metrics must be a list of strings or callables, but got {metrics}"
+        assert all(isinstance(m, (str, Callable)) for m in metrics), (
+            f"metrics must be a list of strings or callables, but got {metrics}"
+        )
         assert metrics_params is not None, "metric_params cannot be None when metrics is not None"
         assert metrics_prob_input is not None, "metrics_prob_inputs cannot be None when metrics is not None"
-        assert isinstance(
-            metrics_params, list
-        ), f"metric_params must be a list of dicts, but got {type(metrics_params)}"
+        assert isinstance(metrics_params, list), (
+            f"metric_params must be a list of dicts, but got {type(metrics_params)}"
+        )
         assert isinstance(metrics_prob_input, list), (
-            "metrics_prob_inputs must be a list of bools, but got" f" {type(metrics_prob_input)}"
+            f"metrics_prob_inputs must be a list of bools, but got {type(metrics_prob_input)}"
         )
         assert len(metrics) == len(metrics_params), (
-            "metrics and metric_params must be of the same length, but got" f" {len(metrics)} and {len(metrics_params)}"
+            f"metrics and metric_params must be of the same length, but got {len(metrics)} and {len(metrics_params)}"
         )
         assert len(metrics) == len(metrics_prob_input), (
             "metrics and metrics_prob_inputs must be of the same length, but got"
             f" {len(metrics)} and {len(metrics_prob_input)}"
         )
-        assert all(
-            isinstance(m, dict) for m in metrics_params
-        ), f"metric_params must be a list of dicts, but got {metrics_params}"
+        assert all(isinstance(m, dict) for m in metrics_params), (
+            f"metric_params must be a list of dicts, but got {metrics_params}"
+        )
     if common_model_args is not None:
         # all args should be members of ModelConfig
         assert all(k in ModelConfig.__dataclass_fields__.keys() for k in common_model_args.keys()), (
-            "common_model_args must be a subset of ModelConfig, but got" f" {common_model_args.keys()}"
+            f"common_model_args must be a subset of ModelConfig, but got {common_model_args.keys()}"
         )
     if rank_metric[0] not in ["loss", "accuracy", "mean_squared_error"]:
         assert rank_metric[0] in metrics, f"rank_metric must be one of {metrics}, but got {rank_metric}"
     assert rank_metric[1] in [
         "lower_is_better",
         "higher_is_better",
-    ], "rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but" f" got {rank_metric[1]}"
+    ], f"rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but got {rank_metric[1]}"


 def model_sweep(
@@ -318,7 +318,7 @@ def _init_tabular_model(m):
                     setattr(model_config, key, val)
                 else:
                     raise ValueError(
-                        f"ModelConfig {model_config.name} does not have an" f" attribute {key} in common_model_args"
+                        f"ModelConfig {model_config.name} does not have an attribute {key} in common_model_args"
                     )
             params = model_config.__dict__
             start_time = time.time()
@@ -380,7 +380,7 @@ def _init_tabular_model(m):
         if verbose:
             logger.info(f"Finished Training {name}")
-            logger.info("Results:" f" {', '.join([f'{k}: {v}' for k, v in res_dict.items()])}")
+            logger.info(f"Results: {', '.join([f'{k}: {v}' for k, v in res_dict.items()])}")
         res_dict["params"] = params
         results.append(res_dict)
diff --git a/src/pytorch_tabular/tabular_model_tuner.py b/src/pytorch_tabular/tabular_model_tuner.py
index d199d1fb..c62a2c76 100644
--- a/src/pytorch_tabular/tabular_model_tuner.py
+++ b/src/pytorch_tabular/tabular_model_tuner.py
@@ -236,9 +236,9 @@ def tune(
         assert strategy in self.ALLOWABLE_STRATEGIES, f"tuner must be one of {self.ALLOWABLE_STRATEGIES}"
         assert mode in ["max", "min"], "mode must be one of ['max', 'min']"
         assert metric is not None, "metric must be specified"
-        assert (isinstance(search_space, dict) or (isinstance(search_space, list))) and len(
-            search_space
-        ) > 0, "search_space must be a non-empty dict"
+        assert (isinstance(search_space, dict) or (isinstance(search_space, list))) and len(search_space) > 0, (
+            "search_space must be a non-empty dict"
+        )
         if self.suppress_lightning_logger:
             suppress_lightning_logs()
         if cv is not None and validation is not None:
@@ -273,9 +273,9 @@ def tune(
             }

             if strategy == "grid_search":
-                assert all(
-                    isinstance(v, list) for v in search_space_temp.values()
-                ), "For grid search, all values in search_space must be a list of values to try"
+                assert all(isinstance(v, list) for v in search_space_temp.values()), (
+                    "For grid search, all values in search_space must be a list of values to try"
+                )
                 search_space_iterator = list(ParameterGrid(search_space_temp))
                 if n_trials is not None:
                     warnings.warn(
@@ -409,7 +409,7 @@ def tune(
                     params.update({"trial_id": i})
                     trials.append(params)
                     if verbose:
-                        logger.info(f"Trial {i+1}/{n_trials}: {params} | Score: {params[metric_str]}")
+                        logger.info(f"Trial {i + 1}/{n_trials}: {params} | Score: {params[metric_str]}")
             trials_df = pd.DataFrame(trials)
             trials = trials_df.pop("trial_id")
diff --git a/src/pytorch_tabular/utils/nn_utils.py b/src/pytorch_tabular/utils/nn_utils.py
index c4506b8d..e2f248d4 100644
--- a/src/pytorch_tabular/utils/nn_utils.py
+++ b/src/pytorch_tabular/utils/nn_utils.py
@@ -18,7 +18,7 @@ def _initialize_layers(activation, initialization, layers):
         nonlinearity = "leaky_relu"
     else:
         if initialization == "kaiming":
-            logger.warning("Kaiming initialization is only recommended for ReLU and" " LeakyReLU.")
+            logger.warning("Kaiming initialization is only recommended for ReLU and LeakyReLU.")
             nonlinearity = "leaky_relu"
         else:
             nonlinearity = "relu"
@@ -108,7 +108,7 @@ def _initialize_kaiming(x, initialization, d_sqrt_inv):
     elif initialization is None:
         pass
     else:
-        raise NotImplementedError("initialization should be either of `kaiming_normal`, `kaiming_uniform`," " `None`")
+        raise NotImplementedError("initialization should be either of `kaiming_normal`, `kaiming_uniform`, `None`")


 class OutOfMemoryHandler:
diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py
index e08503ed..e1221aa5 100644
--- a/src/pytorch_tabular/utils/python_utils.py
+++ b/src/pytorch_tabular/utils/python_utils.py
@@ -47,7 +47,7 @@ def generate_doc_dataclass(dataclass, desc=None, width=100):
         type = str(atr.type).replace("<class '", "").replace("'>", "").replace("typing.", "")
         help_str = atr.metadata.get("help", "")
         if "choices" in atr.metadata.keys():
-            help_str += ". Choices are:" f" [{','.join(['`'+str(ch)+'`' for ch in atr.metadata['choices']])}]."
+            help_str += f". Choices are: [{','.join(['`' + str(ch) + '`' for ch in atr.metadata['choices']])}]."
             # help_str += f'. Defaults to {atr.default}'
         h_str = textwrap.fill(
             f"{key} ({type}): {help_str}",