Skip to content

Commit 612ea27

Browse files
committed
bug in classifier.py
1 parent 12aa51a commit 612ea27

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pyreason/scripts/learning/classification/classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def forward(self, x, t1: int = 0, t2: int = 0) -> Tuple[torch.Tensor, torch.Tens
6464
device=probabilities.device)
6565
else:
6666
# If no snap_value is provided, keep original probabilities for those passing threshold.
67-
lower_val = probabilities
68-
upper_val = probabilities
67+
lower_val = probabilities if opts.set_lower_bound else torch.zeros_like(probabilities)
68+
upper_val = probabilities if opts.set_upper_bound else torch.ones_like(probabilities)
6969

7070
# For probabilities that pass the threshold, apply the above; else, bounds are fixed to [0,1].
7171
lower_bounds = torch.where(condition, lower_val, torch.zeros_like(probabilities))

0 commit comments

Comments
 (0)