Skip to content

[Autowrapper] Fix local names, increase reproducability #1672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, namespace: Dict[str, Any], ignore: List[str]):
self.ignore = ignore
self._wrapper_fn_defs: List[ast.FunctionDef] = list()
self._local_names = set()
self._wrapped_counter = 0

def auto_wrap(self, tree: ast.Module) -> ast.Module:
"""
Expand Down Expand Up @@ -56,6 +57,14 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
if node.name == "forward":
for arg in node.args.args:
self._local_names.add(arg.arg)
for arg in node.args.posonlyargs:
self._local_names.add(arg.arg)
for arg in node.args.kwonlyargs:
self._local_names.add(arg.arg)
if node.args.vararg:
self._local_names.add(node.args.vararg.arg)
if node.args.kwarg:
self._local_names.add(node.args.kwarg.arg)
return super().generic_visit(node)

def visit_Name(self, node: ast.Name):
Expand Down Expand Up @@ -203,6 +212,11 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
returns = assigned | conditionally_assigned
assert "self" not in args, "Cannot trace self, this should be in the namespace"

# sort arguments for reproducability
args = sorted(args)
kwargs = sorted(kwargs)
returns = sorted(returns)

# build function arguments
args_obj = ast.arguments(
args=[ast.arg(arg=name) for name in args],
Expand All @@ -217,21 +231,22 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
# build body and return statement
return_stmt = ast.Return(
value=ast.Tuple(
elts=[ast.Name(id=name, ctx=ast.Load()) for name in sorted(returns)],
elts=[ast.Name(id=name, ctx=ast.Load()) for name in returns],
ctx=ast.Load(),
)
)
body = [node, return_stmt]

# build function definition, store in `_wrapper_fn_defs`
fn_name = f"wrapped_{hash(node)}"
fn_name = f"wrapped_{self._wrapped_counter}"
fn_def = ast.FunctionDef(
name=fn_name,
args=args_obj,
body=body,
decorator_list=[ast.Name(id="torch.fx.wrap", ctx=ast.Load())],
)
self._wrapper_fn_defs.append(fn_def)
self._wrapped_counter += 1

# build call and assignment
fn_call = ast.Call(
Expand All @@ -240,13 +255,13 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
keywords=list(),
)
return_tuple = ast.Tuple(
elts=[ast.Name(id=name, ctx=ast.Store()) for name in sorted(returns)],
elts=[ast.Name(id=name, ctx=ast.Store()) for name in returns],
ctx=ast.Store(),
)
assign_call = ast.Assign(targets=[return_tuple], value=fn_call)

# update local names with newly returned values
self._local_names |= returns
self._local_names |= set(returns)

# log newly created function definition
logger.debug("---- Autowrapper ----")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

def check_wrapping(
source: str,
output: Optional[str] = None,
num_wrapped: int = 0,
output: str,
namespace: Optional[Dict[str, Any]] = None,
ignore: Optional[List[str]] = None,
):
Expand All @@ -20,15 +19,20 @@ def check_wrapping(
wrapper = AutoWrapper(namespace, ignore)
wrapped = wrapper.auto_wrap(tree)

if output is not None:
wrapped_lines = ast.unparse(wrapped).splitlines()
output_lines = textwrap.dedent(output).splitlines()[1:]
assert wrapped_lines == output_lines
wrapped_lines = ast.unparse(wrapped).splitlines()
output_lines = textwrap.dedent(output).splitlines()[1:]

assert len(wrapper._wrapper_fn_defs) == num_wrapped
assert len(wrapped_lines) == len(output_lines)
for wrapped_line, output_line in zip(wrapped_lines, output_lines):
if "# skip" in output:
continue

assert wrapped_line == output_line


def test_static_if():
"""Checks that resolvable if statements are replaced"""

source = """
def forward():
if 1 + 1 == 2:
Expand All @@ -39,10 +43,12 @@ def forward():
if True:
pass
"""
check_wrapping(source, output, 0)
check_wrapping(source, output)


def test_static_if_global_vars():
"""Checks that resolvable if statements are replaced"""

source = """
def forward():
if config.is_false:
Expand All @@ -54,20 +60,35 @@ def forward():
pass
"""
config = SimpleNamespace(is_false=False)
check_wrapping(source, output, 0, namespace={"config": config})
check_wrapping(source, output, namespace={"config": config})


def test_dynamic_if():
"""Checks that non-resolvable if statements are ignored"""

source = """
def forward():
test = ...
if test:
pass
"""
check_wrapping(source, None, 1)
output = """
@torch.fx.wrap
def wrapped_0(test):
if test:
pass
return ()

def forward():
test = ...
() = wrapped_0(test)
"""
check_wrapping(source, output)


def test_ignore_functions():
"""Checks that ignored functions are wrapped"""

def func_one():
pass

Expand All @@ -79,11 +100,23 @@ def forward():
func_one()
func_two()
"""
output = """
@torch.fx.wrap
def wrapped_0():
return func_one()
return ()

def forward():
wrapped_0()
func_two()
"""
namespace = {"func_one": func_one, "func_two": func_two}
check_wrapping(source, None, 1, namespace=namespace, ignore=["func_one"])
check_wrapping(source, output, namespace=namespace, ignore=["func_one"])


def test_ignore_methods():
"""Checks that ignored class methods are wrapped"""

class Model:
def meth_one(self):
pass
Expand All @@ -96,11 +129,23 @@ def forward(self):
self.meth_one()
self.meth_two()
"""
output = """
@torch.fx.wrap
def wrapped_0():
return self.meth_one()
return ()

def forward(self):
wrapped_0()
self.meth_two()
"""
namespace = {"self": Model()}
check_wrapping(source, None, 1, namespace=namespace, ignore=["meth_one"])
check_wrapping(source, output, namespace=namespace, ignore=["meth_one"])


def test_branch_with_self_assignment():
"""Checks that names referenced in self assignment are included in fn args"""

source = """
def forward(x, y):
if y > 0:
Expand All @@ -109,18 +154,38 @@ def forward(x, y):
x = x - 1
return x
"""
output = """
@torch.fx.wrap
def wrapped_0(x, y):
if y > 0:
x = x + 1
else:
x = x - 1
return (x,)

tree = ast.parse(textwrap.dedent(source))
wrapper = AutoWrapper(namespace={}, ignore=[])
wrapper.auto_wrap(tree)
def forward(x, y):
(x,) = wrapped_0(x, y) # skip: some envs use "(x,)" -> "x,"
return x
"""
check_wrapping(source, output)

assert len(wrapper._wrapper_fn_defs) == 1

# Check if both x, y are included in args
wrapped_fn = wrapper._wrapper_fn_defs[0]
arg_names = {arg.arg for arg in wrapped_fn.args.args}
def test_function_variadic():
"""Checks for handling variadic names created via function def"""

source = """
def forward(a, *b, c=5, **d):
if a == b and c == d:
pass
"""
output = """
@torch.fx.wrap
def wrapped_0(a, b, c, d):
if a == b and c == d:
pass
return ()

assert arg_names == {
"x",
"y",
}, f"Expected arguments {{'x', 'y'}}, but got {arg_names}"
def forward(a, *b, c=5, **d):
() = wrapped_0(a, b, c, d)
"""
check_wrapping(source, output)