Skip to content

Commit 9a89cd0

Browse files
Revert "[Fx Importer] fix mutation importer with non persistent buffer (#3798)"
This reverts commit 8f52f5a.
1 parent e44ea22 commit 9a89cd0

File tree

2 files changed

+4
-32
lines changed

2 files changed

+4
-32
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -723,17 +723,10 @@ def import_program(
723723
# on a symbolic or other non-SSA association. As such, they
724724
# are not modeled with mutable IR but will trigger an output
725725
# store hook when the final value is produced.
726-
if input_spec.persistent:
727-
value = prog.state_dict.get(input_spec.target)
728-
assert (
729-
value is not None
730-
), "Expected state_dict value for persistent buffer"
731-
else:
732-
value = prog.constants.get(input_spec.target)
733-
assert (
734-
value is not None
735-
), "Expected constants value for non-persistent buffer"
736-
726+
value = prog.state_dict.get(input_spec.target)
727+
assert (
728+
not input_spec.persistent or value is not None
729+
), "Expected state_dict value for persistent value"
737730
node = placeholder_nodes[arg.name]
738731
mutable_producer_node_name = mutable_buffer_target_producers.get(
739732
input_spec.target

test/python/fx_importer/v2.3/mutation_import.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -107,27 +107,6 @@ def forward(self, x):
107107
m.operation.verify()
108108

109109

110-
@run
111-
# CHECK-LABEL: test_frozen_buffer_non_persistent
112-
# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal
113-
# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0
114-
# CHECK: return %[[mul]]
115-
def test_frozen_buffer_non_persistent():
116-
class Basic(nn.Module):
117-
def __init__(self):
118-
super().__init__()
119-
self.register_buffer("buffer", torch.randn(3, 4), persistent=False)
120-
121-
def forward(self, x):
122-
return x * self.buffer
123-
124-
m = fx.export_and_import(
125-
Basic(), torch.randn(3, 4), experimental_support_mutation=True
126-
)
127-
print(m)
128-
m.operation.verify()
129-
130-
131110
class ExternalBufferHooks(fx.FxImporterHooks):
132111
def prepare_module(self, module_op: Operation):
133112
module_op.context.allow_unregistered_dialects = True

0 commit comments

Comments
 (0)