Skip to content

Commit 3d978f9

Browse files
authored
fix for ssl finetuning bug (#510)
1 parent 495803c commit 3d978f9

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def setup(self, stage: Optional[str] = None) -> None:
526526
else:
527527
self.validation = self.validation.copy()
528528
# Preprocessing Train, Validation
529-
self.train, _ = self.preprocess_data(self.train, stage="fit" if not is_ssl else "inference")
529+
self.train, _ = self.preprocess_data(self.train, stage="inference" if is_ssl else "fit")
530530
self.validation, _ = self.preprocess_data(self.validation, stage="inference")
531531
self._fitted = True
532532
self._cache_dataset()

src/pytorch_tabular/tabular_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,13 +1001,15 @@ def create_finetune_model(
10011001
logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}")
10021002
config["run_name"] = config["run_name"] + "_finetuned"
10031003

1004+
config_override = {"target": target} if target is not None else {}
1005+
config_override["task"] = task
10041006
datamodule = self.datamodule.copy(
10051007
train=train,
10061008
validation=validation,
10071009
target_transform=target_transform,
10081010
train_sampler=train_sampler,
10091011
seed=seed,
1010-
config_override={"target": target} if target is not None else {},
1012+
config_override=config_override,
10111013
)
10121014
model_callable = _GenericModel
10131015
inferred_config = OmegaConf.structured(datamodule._inferred_config)

0 commit comments

Comments
 (0)