Commit ddb317e

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2315bd0 commit ddb317e
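
Most of this diff applies one mechanical pattern: multi-line asserts that parenthesized the condition are rewritten so the condition stays on the `assert` line and the message is parenthesized instead. A minimal before/after sketch of that pattern, assuming a formatter such as black or ruff-format among the hooks (the hook list itself is not part of this diff); both forms are equivalent at runtime:

```python
values = [1, 2, 3]  # hypothetical data, for illustration only

# Before: the condition is wrapped in parentheses to satisfy line length.
assert (
    len(values) > 0
), "values must not be empty"

# After: the condition stays on the assert line; the message is
# parenthesized instead, so long messages wrap without obscuring the check.
assert len(values) > 0, (
    "values must not be empty"
)
```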

17 files changed (+114, -124 lines)


src/pytorch_tabular/categorical_encoders.py

Lines changed: 6 additions & 6 deletions
@@ -58,9 +58,9 @@ def transform(self, X):
             raise ValueError("`fit` method must be called before `transform`.")
         assert all(c in X.columns for c in self.cols)
         if self.handle_missing == "error":
-            assert (
-                not X[self.cols].isnull().any().any()
-            ), "`handle_missing` = `error` and missing values found in columns to encode."
+            assert not X[self.cols].isnull().any().any(), (
+                "`handle_missing` = `error` and missing values found in columns to encode."
+            )
         X_encoded = X.copy(deep=True)
         category_cols = X_encoded.select_dtypes(include="category").columns
         X_encoded[category_cols] = X_encoded[category_cols].astype("object")
@@ -153,9 +153,9 @@ def fit(self, X, y=None):
         """
         self._before_fit_check(X, y)
         if self.handle_missing == "error":
-            assert (
-                not X[self.cols].isnull().any().any()
-            ), "`handle_missing` = `error` and missing values found in columns to encode."
+            assert not X[self.cols].isnull().any().any(), (
+                "`handle_missing` = `error` and missing values found in columns to encode."
+            )
         for col in self.cols:
             map = Series(unique(X[col].fillna(NAN_CATEGORY)), name=col).reset_index().rename(columns={"index": "value"})
             map["value"] += 1

src/pytorch_tabular/config/config.py

Lines changed: 6 additions & 6 deletions
@@ -192,9 +192,9 @@ class DataConfig:
     )

     def __post_init__(self):
-        assert (
-            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
-        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
+        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
+            "There should be at-least one feature defined in categorical, continuous, or date columns"
+        )
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:

     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all(
-                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
-            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
+                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            )
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0

src/pytorch_tabular/models/category_embedding/config.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ class CategoryEmbeddingModelConfig(ModelConfig):
     )
     use_batch_norm: bool = field(
         default=False,
-        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut." " Defaults to False")},
+        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False")},
     )
     initialization: str = field(
         default="kaiming",

src/pytorch_tabular/models/common/layers/embeddings.py

Lines changed: 18 additions & 18 deletions
@@ -84,12 +84,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert (
-            categorical_data.shape[1] == self.categorical_dim
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == self.categorical_dim, (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
@@ -141,12 +141,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
@@ -273,12 +273,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             cont_idx = torch.arange(self.continuous_dim, device=continuous_data.device).expand(
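
The asserts reformatted in all three hunks enforce the same input contract on `forward`: the batch dict must carry tensors whose trailing dimension matches what the layer was configured with. A minimal sketch of that contract, with hypothetical dimensions not taken from this diff:

```python
import torch

# Hypothetical sizes, for illustration only.
batch_size, continuous_dim, categorical_dim = 32, 4, 2

x = {
    "continuous": torch.randn(batch_size, continuous_dim),
    "categorical": torch.randint(0, 10, (batch_size, categorical_dim)),
}

# The reformatted asserts check exactly these two conditions before embedding.
assert x["categorical"].shape[1] == categorical_dim
assert x["continuous"].shape[1] == continuous_dim
```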

src/pytorch_tabular/models/gate/config.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def __post_init__(self):
         assert self.tree_depth > 0, "tree_depth should be greater than 0"
         # Either gflu_stages or num_trees should be greater than 0
         assert self.num_trees > 0, (
-            "`num_trees` must be greater than 0." "If you want a lighter model which performs better, use GANDALF."
+            "`num_trees` must be greater than 0.If you want a lighter model which performs better, use GANDALF."
         )
         super().__post_init__()

src/pytorch_tabular/models/gate/gate_model.py

Lines changed: 6 additions & 6 deletions
@@ -51,12 +51,12 @@ def __init__(
         embedding_dropout: float = 0.0,
     ):
         super().__init__()
-        assert (
-            binning_activation in self.BINARY_ACTIVATION_MAP.keys()
-        ), f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
-        assert (
-            feature_mask_function in self.ACTIVATION_MAP.keys()
-        ), f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        assert binning_activation in self.BINARY_ACTIVATION_MAP.keys(), (
+            f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
+        )
+        assert feature_mask_function in self.ACTIVATION_MAP.keys(), (
+            f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        )

         self.gflu_stages = gflu_stages
         self.num_trees = num_trees

src/pytorch_tabular/models/mixture_density/config.py

Lines changed: 3 additions & 3 deletions
@@ -87,9 +87,9 @@ class MDNConfig(ModelConfig):
     _probabilistic: bool = field(default=True)

     def __post_init__(self):
-        assert (
-            self.backbone_config_class not in INCOMPATIBLE_BACKBONES
-        ), f"{self.backbone_config_class} is not a supported backbone for MDN head"
+        assert self.backbone_config_class not in INCOMPATIBLE_BACKBONES, (
+            f"{self.backbone_config_class} is not a supported backbone for MDN head"
+        )
         assert self.head == "MixtureDensityHead"
         return super().__post_init__()

src/pytorch_tabular/models/tabnet/config.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ class TabNetModelConfig(ModelConfig):
     )
     gamma: float = field(
         default=1.3,
-        metadata={"help": ("Float above 1, scaling factor for attention updates (usually between" " 1.0 to 2.0)")},
+        metadata={"help": ("Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0)")},
     )
     n_independent: int = field(
         default=2,

src/pytorch_tabular/models/tabnet/tabnet_model.py

Lines changed: 1 addition & 1 deletion
@@ -104,4 +104,4 @@ def _build_network(self):
         self._head = nn.Identity()

     def extract_embedding(self):
-        raise ValueError("Extracting Embeddings is not supported by Tabnet. Please use another" " compatible model")
+        raise ValueError("Extracting Embeddings is not supported by Tabnet. Please use another compatible model")

src/pytorch_tabular/ssl_models/base_model.py

Lines changed: 3 additions & 3 deletions
@@ -85,9 +85,9 @@ def __init__(
         self._setup_metrics()

     def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
-        assert (encoder is not None) or (
-            encoder_config is not None
-        ), "Either encoder or encoder_config must be provided"
+        assert (encoder is not None) or (encoder_config is not None), (
+            "Either encoder or encoder_config must be provided"
+        )
         # assert (decoder is not None) or (decoder_config is not None),
         # "Either decoder or decoder_config must be provided"
         if encoder is not None:

src/pytorch_tabular/ssl_models/common/layers.py

Lines changed: 3 additions & 3 deletions
@@ -79,9 +79,9 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
         assert categorical_data.shape[1] == len(
             self._onehot_feat_idx + self._binary_feat_idx + self._embedding_feat_idx
         ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         # embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:

src/pytorch_tabular/ssl_models/dae/config.py

Lines changed: 3 additions & 3 deletions
@@ -125,9 +125,9 @@ class DenoisingAutoEncoderConfig(SSLModelConfig):
     def __post_init__(self):
         assert hasattr(self.encoder_config, "_backbone_name"), "encoder_config should have a _backbone_name attribute"
         if self.decoder_config is not None:
-            assert hasattr(
-                self.decoder_config, "_backbone_name"
-            ), "decoder_config should have a _backbone_name attribute"
+            assert hasattr(self.decoder_config, "_backbone_name"), (
+                "decoder_config should have a _backbone_name attribute"
+            )
         super().__post_init__()
