Skip to content

Commit 4187bbd

Browse files
authored
Merge pull request #239 from changbiHub/self_training_fix
Self_training_fix
2 parents 4d4dadf + 882db8b commit 4187bbd

File tree

2 files changed: +2 additions, -2 deletions

pytorch_widedeep/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,7 @@ def forward(self, x_true: Tensor, x_pred: Tensor, mask: Tensor) -> Tensor:
         x_true_means[x_true_means == 0] = 1

         x_true_stds = torch.std(x_true, dim=0) ** 2
-        x_true_stds[x_true_stds == 0] = x_true_means[x_true_stds == 0]
+        x_true_stds[x_true_stds == 0] = torch.abs(x_true_means[x_true_stds == 0])

         features_loss = torch.matmul(reconstruction_errors, 1 / x_true_stds)
         nb_reconstructed_variables = torch.sum(mask, dim=1)

pytorch_widedeep/models/tabular/self_supervised/encoder_decoder_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _forward_tabnet(self, X: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
         x_embed_rec = self.decoder(steps_out)
         mask = torch.ones(x_embed.shape).to(X.device)

-        return x_embed_rec, x_embed, mask
+        return x_embed, x_embed_rec, mask

     def _build_decoder(self, encoder: ModelWithoutAttention) -> DecoderWithoutAttention:
         if isinstance(encoder, TabMlp):

0 commit comments

Comments (0)