File tree Expand file tree Collapse file tree 2 files changed +4
-32
lines changed
test/python/fx_importer/v2.3 Expand file tree Collapse file tree 2 files changed +4
-32
lines changed Original file line number Diff line number Diff line change @@ -723,17 +723,10 @@ def import_program(
723
723
# on a symbolic or other non-SSA association. As such, they
724
724
# are not modeled with mutable IR but will trigger an output
725
725
# store hook when the final value is produced.
726
- if input_spec .persistent :
727
- value = prog .state_dict .get (input_spec .target )
728
- assert (
729
- value is not None
730
- ), "Expected state_dict value for persistent buffer"
731
- else :
732
- value = prog .constants .get (input_spec .target )
733
- assert (
734
- value is not None
735
- ), "Expected constants value for non-persistent buffer"
736
-
726
+ value = prog .state_dict .get (input_spec .target )
727
+ assert (
728
+ not input_spec .persistent or value is not None
729
+ ), "Expected state_dict value for persistent value"
737
730
node = placeholder_nodes [arg .name ]
738
731
mutable_producer_node_name = mutable_buffer_target_producers .get (
739
732
input_spec .target
Original file line number Diff line number Diff line change @@ -107,27 +107,6 @@ def forward(self, x):
107
107
m .operation .verify ()
108
108
109
109
110
- @run
111
- # CHECK-LABEL: test_frozen_buffer_non_persistent
112
- # CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal
113
- # CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0
114
- # CHECK: return %[[mul]]
115
- def test_frozen_buffer_non_persistent ():
116
- class Basic (nn .Module ):
117
- def __init__ (self ):
118
- super ().__init__ ()
119
- self .register_buffer ("buffer" , torch .randn (3 , 4 ), persistent = False )
120
-
121
- def forward (self , x ):
122
- return x * self .buffer
123
-
124
- m = fx .export_and_import (
125
- Basic (), torch .randn (3 , 4 ), experimental_support_mutation = True
126
- )
127
- print (m )
128
- m .operation .verify ()
129
-
130
-
131
110
class ExternalBufferHooks (fx .FxImporterHooks ):
132
111
def prepare_module (self , module_op : Operation ):
133
112
module_op .context .allow_unregistered_dialects = True
You can’t perform that action at this time.
0 commit comments