
Commit f04a05c

Enable Support for Multi-GPU Training (#517)
* Add `pickle_protocol` to `DataConfig` for passing it down to `torch.save()` when caching datasets to disk
* Save the index of the original dataframe in `TabularDataset` so that it can be restored when accessing `TabularDataset.data`
* Add `sync_dist=True` to all calls to `self.log()` in `validation_step()` and `test_step()` to support distributed training
* Fix `TrainerConfig.precision` to be a string and remove integer choices; add a pointer to the docs with the possible options
* Add `sync_dist` to SSL models
* Address `PerformanceWarning` related to `frame.insert()`
* Only load the best checkpoint on rank zero in distributed training
* If logging with wandb, `unwatch` the model after training
* Address `FutureWarning` re `inplace=True`
1 parent b504132 commit f04a05c

File tree: 8 files changed (+66, -31 lines)


src/pytorch_tabular/categorical_encoders.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def transform(self, X):
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

             if self.handle_unseen == "impute":
-                X_encoded[col].fillna(self._imputed, inplace=True)
+                X_encoded[col] = X_encoded[col].fillna(self._imputed)
             elif self.handle_unseen == "error":
                 if np.unique(X_encoded[col]).shape[0] > mapping.shape[0]:
                     raise ValueError(f"Unseen categories found in `{col}` column.")

src/pytorch_tabular/config/config.py

Lines changed: 12 additions & 6 deletions
@@ -96,6 +96,8 @@ class DataConfig:
         handle_missing_values (bool): Whether to handle missing values in categorical columns as
             unknown

+        pickle_protocol (int): pickle protocol version passed to `torch.save` for dataset caching to disk
+
         dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
             https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

@@ -179,6 +181,11 @@ class DataConfig:
         metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
     )
+    pickle_protocol: int = field(
+        default=2,
+        metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
+    )
     dataloader_kwargs: Dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},

@@ -351,8 +358,8 @@ class TrainerConfig:
         progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`.

-        precision (int): Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`..
-            Choices are: [`32`,`16`,`64`].
+        precision (str): Precision of the model. Defaults to `32`. See
+            https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision

         seed (int): Seed for random number generators. Defaults to 42

@@ -536,11 +543,10 @@ class TrainerConfig:
         default="rich",
         metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
     )
-    precision: int = field(
-        default=32,
+    precision: str = field(
+        default="32",
         metadata={
-            "help": "Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.",
-            "choices": [32, 16, 64],
+            "help": "Precision of the model. Defaults to `32`.",
         },
     )
     seed: int = field(
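
For reference, a minimal sketch of how the two changed config fields are used downstream. `pickle_protocol` and `precision` are the real fields from this diff; the column names and the chosen precision value are illustrative assumptions, not part of the commit:

```python
from pytorch_tabular.config import DataConfig, TrainerConfig

# New DataConfig field: forwarded to torch.save() when datasets are
# cached to disk. Protocol 4 supports objects larger than 4 GB.
data_config = DataConfig(
    target=["target"],              # assumed column name
    continuous_cols=["feature_1"],  # assumed column name
    pickle_protocol=4,
)

# precision is now a string, so any value the Lightning Trainer
# accepts (e.g. "16-mixed", "bf16-mixed", "64") can pass through.
trainer_config = TrainerConfig(
    precision="16-mixed",
)
```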

src/pytorch_tabular/feature_extractor.py

Lines changed: 11 additions & 5 deletions
@@ -79,15 +79,21 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
                 if k in ret_value.keys():
                     logits_predictions[k].append(ret_value[k].detach().cpu())

+        logits_dfs = []
         for k, v in logits_predictions.items():
             v = torch.cat(v, dim=0).numpy()
             if v.ndim == 1:
                 v = v.reshape(-1, 1)
-            for i in range(v.shape[-1]):
-                if v.shape[-1] > 1:
-                    X_encoded[f"{k}_{i}"] = v[:, i]
-                else:
-                    X_encoded[f"{k}"] = v[:, i]
+            if v.shape[-1] > 1:
+                temp_df = pd.DataFrame({f"{k}_{i}": v[:, i] for i in range(v.shape[-1])})
+            else:
+                temp_df = pd.DataFrame({f"{k}": v[:, 0]})
+
+            # Append the temp DataFrame to the list
+            logits_dfs.append(temp_df)
+
+        preds = pd.concat(logits_dfs, axis=1)
+        X_encoded = pd.concat([X_encoded, preds], axis=1)

         if self.drop_original:
             X_encoded.drop(columns=orig_features, inplace=True)
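
The rewrite above addresses the pandas `PerformanceWarning` about fragmented frames: assigning many columns one at a time goes through `frame.insert()`, while building the columns first and concatenating once does not. A self-contained sketch of the difference, with made-up data:

```python
import numpy as np
import pandas as pd

base = pd.DataFrame({"a": np.arange(1_000)})
cols = {f"feat_{i}": np.random.rand(1_000) for i in range(200)}

# Old pattern: each assignment calls frame.insert() and, past ~100
# columns, pandas emits a PerformanceWarning about fragmentation.
slow = base.copy()
for name, values in cols.items():
    slow[name] = values

# New pattern (mirrors transform() above): build once, concat once.
fast = pd.concat([base, pd.DataFrame(cols)], axis=1)
```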

src/pytorch_tabular/models/base_model.py

Lines changed: 19 additions & 7 deletions
@@ -244,13 +244,14 @@ def _setup_metrics(self):
         else:
             self.metrics = self.custom_metrics

-    def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
+    def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str, sync_dist: bool = False) -> torch.Tensor:
         """Calculates the loss for the model.

         Args:
             output (Dict): The output dictionary from the model
             y (torch.Tensor): The target tensor
             tag (str): The tag to use for logging
+            sync_dist (bool): enable distributed sync of logs

         Returns:
             torch.Tensor: The loss value

@@ -270,6 +271,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
             if self.hparams.task == "regression":
                 computed_loss = reg_loss

@@ -284,6 +286,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
         else:
             # TODO loss fails with batch size of 1?

@@ -301,6 +304,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 start_index = end_index
         self.log(

@@ -311,10 +315,13 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
             # on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return computed_loss

-    def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
+    def calculate_metrics(
+        self, y: torch.Tensor, y_hat: torch.Tensor, tag: str, sync_dist: bool = False
+    ) -> List[torch.Tensor]:
         """Calculates the metrics for the model.

         Args:

@@ -324,6 +331,8 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
             tag (str): The tag to use for logging

+            sync_dist (bool): enable distributed sync of logs
+
         Returns:
             List[torch.Tensor]: The list of metric values

@@ -356,6 +365,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 _metrics.append(_metric)
             avg_metric = torch.stack(_metrics, dim=0).sum()

@@ -379,6 +389,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 _metrics.append(_metric)
                 start_index = end_index

@@ -391,6 +402,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
             on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return metrics

@@ -523,19 +535,19 @@ def validation_step(self, batch, batch_idx):
             # fetched from the batch
             y = batch["target"] if y is None else y
             y_hat = output["logits"]
-            self.calculate_loss(output, y, tag="valid")
-            self.calculate_metrics(y, y_hat, tag="valid")
+            self.calculate_loss(output, y, tag="valid", sync_dist=True)
+            self.calculate_metrics(y, y_hat, tag="valid", sync_dist=True)
         return y_hat, y

     def test_step(self, batch, batch_idx):
         with torch.no_grad():
             output, y = self.forward_pass(batch)
-            # y is not None for SSL task.Rest of the tasks target is
+            # y is not None for SSL task. Rest of the tasks target is
             # fetched from the batch
             y = batch["target"] if y is None else y
             y_hat = output["logits"]
-            self.calculate_loss(output, y, tag="test")
-            self.calculate_metrics(y, y_hat, tag="test")
+            self.calculate_loss(output, y, tag="test", sync_dist=True)
+            self.calculate_metrics(y, y_hat, tag="test", sync_dist=True)
         return y_hat, y

     def configure_optimizers(self):
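
For context on the new flag: `sync_dist=True` makes Lightning all-reduce a logged value across processes before recording it, so each DDP rank (which sees only its shard of the validation data) contributes to one aggregate metric instead of logging a rank-local value. A minimal sketch of the mechanism, using an illustrative module rather than the actual `BaseModel`:

```python
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # Without sync_dist=True each rank logs its local loss;
        # with it, the value is reduced (mean) across ranks first.
        self.log("valid_loss", loss, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
```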

src/pytorch_tabular/ssl_models/base_model.py

Lines changed: 6 additions & 6 deletions
@@ -136,11 +136,11 @@ def _setup_metrics(self):
         pass

     @abstractmethod
-    def calculate_loss(self, output, tag):
+    def calculate_loss(self, output, tag, sync_dist):
         pass

     @abstractmethod
-    def calculate_metrics(self, output, tag):
+    def calculate_metrics(self, output, tag, sync_dist):
         pass

     @abstractmethod

@@ -167,15 +167,15 @@ def training_step(self, batch, batch_idx):
     def validation_step(self, batch, batch_idx):
         with torch.no_grad():
             output = self.forward(batch)
-            self.calculate_loss(output, tag="valid")
-            self.calculate_metrics(output, tag="valid")
+            self.calculate_loss(output, tag="valid", sync_dist=True)
+            self.calculate_metrics(output, tag="valid", sync_dist=True)
         return output

     def test_step(self, batch, batch_idx):
         with torch.no_grad():
             output = self.forward(batch)
-            self.calculate_loss(output, tag="test")
-            self.calculate_metrics(output, tag="test")
+            self.calculate_loss(output, tag="test", sync_dist=True)
+            self.calculate_metrics(output, tag="test", sync_dist=True)
         return output

     def on_validation_epoch_end(self) -> None:

src/pytorch_tabular/ssl_models/dae/dae.py

Lines changed: 4 additions & 2 deletions
@@ -200,7 +200,7 @@ def forward(self, x: Dict):
         else:
             return z.features

-    def calculate_loss(self, output, tag):
+    def calculate_loss(self, output, tag, sync_dist=False):
         total_loss = 0
         for type_, out in output.items():
             if type_ == "categorical":

@@ -220,6 +220,7 @@ def calculate_loss(self, output, tag):
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
             total_loss += loss
         self.log(

@@ -230,10 +231,11 @@ def calculate_loss(self, output, tag):
             # on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return total_loss

-    def calculate_metrics(self, output, tag):
+    def calculate_metrics(self, output, tag, sync_dist=False):
         pass

     def featurize(self, x: Dict):

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 9 additions & 4 deletions
@@ -61,6 +61,7 @@ def __init__(
         self.task = task
         self.n = data.shape[0]
         self.target = target
+        self.index = data.index
         if target:
             self.y = data[target].astype(np.float32).values
             if isinstance(target, str):

@@ -87,11 +88,12 @@ def data(self):
             data = pd.DataFrame(
                 np.concatenate([self.categorical_X, self.continuous_X], axis=1),
                 columns=self.categorical_cols + self.continuous_cols,
+                index=self.index,
             )
         elif self.continuous_cols:
-            data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols)
+            data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols, index=self.index)
         elif self.categorical_cols:
-            data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols)
+            data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols, index=self.index)
         else:
             data = pd.DataFrame()
         for i, t in enumerate(self.target):

@@ -474,6 +476,7 @@ def _cache_dataset(self):
             target=self.target,
         )
         self.train = None
+
         validation_dataset = TabularDataset(
             task=self.config.task,
             data=self.validation,

@@ -484,8 +487,10 @@ def _cache_dataset(self):
         self.validation = None

         if self.cache_mode is self.CACHE_MODES.DISK:
-            torch.save(train_dataset, self.cache_dir / "train_dataset")
-            torch.save(validation_dataset, self.cache_dir / "validation_dataset")
+            torch.save(train_dataset, self.cache_dir / "train_dataset", pickle_protocol=self.config.pickle_protocol)
+            torch.save(
+                validation_dataset, self.cache_dir / "validation_dataset", pickle_protocol=self.config.pickle_protocol
+            )
         elif self.cache_mode is self.CACHE_MODES.MEMORY:
             self.train_dataset = train_dataset
             self.validation_dataset = validation_dataset
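
`pickle_protocol` here is the standard keyword of `torch.save`, forwarded to Python's pickler. A hedged sketch of the call in isolation (placeholder object and paths):

```python
import pickle

import torch

dataset = {"features": torch.randn(10, 3)}  # stand-in for a TabularDataset

# The new DataConfig field defaults to 2; protocol 4 (or
# pickle.HIGHEST_PROTOCOL) supports objects larger than 4 GB,
# which matters when caching big datasets to disk.
torch.save(dataset, "train_dataset", pickle_protocol=4)
torch.save(dataset, "validation_dataset", pickle_protocol=pickle.HIGHEST_PROTOCOL)
```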

src/pytorch_tabular/tabular_model.py

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,7 @@
 )
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities.model_summary import summarize
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from rich import print as rich_print
 from rich.pretty import pprint
 from sklearn.base import TransformerMixin

@@ -685,6 +686,8 @@ def train(
                 "/n" + "Original Error: " + oom_handler.oom_msg
             )
         self._is_fitted = True
+        if self.track_experiment and self.config.log_target == "wandb":
+            self.logger.experiment.unwatch(self.model)
         if self.verbose:
             logger.info("Training the model completed")
         if self.config.load_best:

@@ -1522,6 +1525,7 @@ def add_noise(module, input, output):
         )
         return pred_df

+    @rank_zero_only
     def load_best_model(self) -> None:
         """Loads the best model after training is done."""
         if self.trainer.checkpoint_callback is not None:
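
`rank_zero_only` wraps a function so it runs only on the process with global rank 0 and becomes a no-op on every other rank, which is what makes `load_best_model` safe under distributed training. A small illustration with a stand-in class (not the actual `TabularModel`):

```python
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class Restorer:  # hypothetical stand-in for TabularModel
    @rank_zero_only
    def load_best_model(self) -> None:
        print("loading best checkpoint on rank 0 only")


Restorer().load_best_model()  # prints on rank 0; no-op elsewhere
```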
