File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -54,7 +54,7 @@ def check_net_device(
54
54
55
55
if isinstance (net , nn .Identity ):
56
56
return net
57
- if str (next (net .parameters ()).device ) != device :
57
+ if str (next (net .parameters ()).device ) != str ( device ) :
58
58
warn (
59
59
message or f"Network is not on the correct device. Moving it to { device } ." ,
60
60
stacklevel = 2 ,
Original file line number Diff line number Diff line change @@ -715,7 +715,7 @@ def validate_theta_and_x(
715
715
assert theta .dtype == float32 , "Type of parameters must be float32."
716
716
assert x .dtype == float32 , "Type of simulator outputs must be float32."
717
717
718
- if str (x .device ) != data_device :
718
+ if str (x .device ) != str ( data_device ) :
719
719
warnings .warn (
720
720
f"Data x has device '{ x .device } '. "
721
721
f"Moving x to the data_device '{ data_device } '. "
@@ -724,7 +724,7 @@ def validate_theta_and_x(
724
724
)
725
725
x = x .to (data_device )
726
726
727
- if str (theta .device ) != data_device :
727
+ if str (theta .device ) != str ( data_device ) :
728
728
warnings .warn (
729
729
f"Parameters theta has device '{ theta .device } '. "
730
730
f"Moving theta to the data_device '{ data_device } '. "
You can’t perform that action at this time.
0 commit comments