Skip to content

Commit b3254ed

Browse files
Ensure device comparison always between string representations (#1225)
1 parent dcfdf35 commit b3254ed

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

sbi/utils/nn_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def check_net_device(
5454

5555
if isinstance(net, nn.Identity):
5656
return net
57-
if str(next(net.parameters()).device) != device:
57+
if str(next(net.parameters()).device) != str(device):
5858
warn(
5959
message or f"Network is not on the correct device. Moving it to {device}.",
6060
stacklevel=2,

sbi/utils/user_input_checks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ def validate_theta_and_x(
715715
assert theta.dtype == float32, "Type of parameters must be float32."
716716
assert x.dtype == float32, "Type of simulator outputs must be float32."
717717

718-
if str(x.device) != data_device:
718+
if str(x.device) != str(data_device):
719719
warnings.warn(
720720
f"Data x has device '{x.device}'. "
721721
f"Moving x to the data_device '{data_device}'. "
@@ -724,7 +724,7 @@ def validate_theta_and_x(
724724
)
725725
x = x.to(data_device)
726726

727-
if str(theta.device) != data_device:
727+
if str(theta.device) != str(data_device):
728728
warnings.warn(
729729
f"Parameters theta has device '{theta.device}'. "
730730
f"Moving theta to the data_device '{data_device}'. "

0 commit comments

Comments
 (0)