Skip to content

Commit 653d0b4

Browse files
Fix type checks for EnsemblePosterior weights (#1299)
* minor fix for EnsemblePosterior weights.setter * Update sbi/inference/posteriors/ensemble_posterior.py Co-authored-by: Jan <janfb@users.noreply.github.com> --------- Co-authored-by: Jan <janfb@users.noreply.github.com>
1 parent 18b1141 commit 653d0b4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

sbi/inference/posteriors/ensemble_posterior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def weights(self, weights: Optional[Union[List[float], Tensor]]) -> None:
137137
self._weights = torch.tensor([
138138
1.0 / self.num_components for _ in range(self.num_components)
139139
])
140-
elif weights is Tensor or weights is List:
140+
elif isinstance(weights, (Tensor, List)):
141141
self._weights = torch.tensor(weights) / sum(weights)
142142
else:
143143
raise TypeError

0 commit comments

Comments
 (0)