diff --git a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py index f4f7545d6..2e78994e4 100644 --- a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py +++ b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py @@ -26,6 +26,7 @@ def __init__(self, namespace: Dict[str, Any], ignore: List[str]): self.ignore = ignore self._wrapper_fn_defs: List[ast.FunctionDef] = list() self._local_names = set() + self._wrapped_counter = 0 def auto_wrap(self, tree: ast.Module) -> ast.Module: """ @@ -56,6 +57,14 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: if node.name == "forward": for arg in node.args.args: self._local_names.add(arg.arg) + for arg in node.args.posonlyargs: + self._local_names.add(arg.arg) + for arg in node.args.kwonlyargs: + self._local_names.add(arg.arg) + if node.args.vararg: + self._local_names.add(node.args.vararg.arg) + if node.args.kwarg: + self._local_names.add(node.args.kwarg.arg) return super().generic_visit(node) def visit_Name(self, node: ast.Name): @@ -203,6 +212,11 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign: returns = assigned | conditionally_assigned assert "self" not in args, "Cannot trace self, this should be in the namespace" + # sort arguments for reproducability + args = sorted(args) + kwargs = sorted(kwargs) + returns = sorted(returns) + # build function arguments args_obj = ast.arguments( args=[ast.arg(arg=name) for name in args], @@ -217,14 +231,14 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign: # build body and return statement return_stmt = ast.Return( value=ast.Tuple( - elts=[ast.Name(id=name, ctx=ast.Load()) for name in sorted(returns)], + elts=[ast.Name(id=name, ctx=ast.Load()) for name in returns], ctx=ast.Load(), ) ) body = [node, return_stmt] # build function definition, store in `_wrapper_fn_defs` - fn_name = f"wrapped_{hash(node)}" + fn_name = f"wrapped_{self._wrapped_counter}" fn_def = ast.FunctionDef( name=fn_name, args=args_obj, @@ -232,6 +246,7 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign: decorator_list=[ast.Name(id="torch.fx.wrap", ctx=ast.Load())], ) self._wrapper_fn_defs.append(fn_def) + self._wrapped_counter += 1 # build call and assignment fn_call = ast.Call( @@ -240,13 +255,13 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign: keywords=list(), ) return_tuple = ast.Tuple( - elts=[ast.Name(id=name, ctx=ast.Store()) for name in sorted(returns)], + elts=[ast.Name(id=name, ctx=ast.Store()) for name in returns], ctx=ast.Store(), ) assign_call = ast.Assign(targets=[return_tuple], value=fn_call) # update local names with newly returned values - self._local_names |= returns + self._local_names |= set(returns) # log newly created function definition logger.debug("---- Autowrapper ----") diff --git a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py index 123a334fe..ab2c38161 100644 --- a/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py +++ b/tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py @@ -8,8 +8,7 @@ def check_wrapping( source: str, - output: Optional[str] = None, - num_wrapped: int = 0, + output: str, namespace: Optional[Dict[str, Any]] = None, ignore: Optional[List[str]] = None, ): @@ -20,15 +19,20 @@ def check_wrapping( wrapper = AutoWrapper(namespace, ignore) wrapped = wrapper.auto_wrap(tree) - if output is not None: - wrapped_lines = ast.unparse(wrapped).splitlines() - output_lines = textwrap.dedent(output).splitlines()[1:] - assert wrapped_lines == output_lines + wrapped_lines = ast.unparse(wrapped).splitlines() + output_lines = textwrap.dedent(output).splitlines()[1:] - assert len(wrapper._wrapper_fn_defs) == num_wrapped + assert len(wrapped_lines) == len(output_lines) + for wrapped_line, output_line in zip(wrapped_lines, output_lines): + if "# skip" in output: + continue + + assert wrapped_line == output_line def test_static_if(): + """Checks that resolvable if statements are replaced""" + source = """ def forward(): if 1 + 1 == 2: @@ -39,10 +43,12 @@ def forward(): if True: pass """ - check_wrapping(source, output, 0) + check_wrapping(source, output) def test_static_if_global_vars(): + """Checks that resolvable if statements are replaced""" + source = """ def forward(): if config.is_false: @@ -54,20 +60,35 @@ def forward(): pass """ config = SimpleNamespace(is_false=False) - check_wrapping(source, output, 0, namespace={"config": config}) + check_wrapping(source, output, namespace={"config": config}) def test_dynamic_if(): + """Checks that non-resolvable if statements are ignored""" + source = """ def forward(): test = ... if test: pass """ - check_wrapping(source, None, 1) + output = """ + @torch.fx.wrap + def wrapped_0(test): + if test: + pass + return () + + def forward(): + test = ... + () = wrapped_0(test) + """ + check_wrapping(source, output) def test_ignore_functions(): + """Checks that ignored functions are wrapped""" + def func_one(): pass @@ -79,11 +100,23 @@ def forward(): func_one() func_two() """ + output = """ + @torch.fx.wrap + def wrapped_0(): + return func_one() + return () + + def forward(): + wrapped_0() + func_two() + """ namespace = {"func_one": func_one, "func_two": func_two} - check_wrapping(source, None, 1, namespace=namespace, ignore=["func_one"]) + check_wrapping(source, output, namespace=namespace, ignore=["func_one"]) def test_ignore_methods(): + """Checks that ignored class methods are wrapped""" + class Model: def meth_one(self): pass @@ -96,11 +129,23 @@ def forward(self): self.meth_one() self.meth_two() """ + output = """ + @torch.fx.wrap + def wrapped_0(): + return self.meth_one() + return () + + def forward(self): + wrapped_0() + self.meth_two() + """ namespace = {"self": Model()} - check_wrapping(source, None, 1, namespace=namespace, ignore=["meth_one"]) + check_wrapping(source, output, namespace=namespace, ignore=["meth_one"]) def test_branch_with_self_assignment(): + """Checks that names referenced in self assignment are included in fn args""" + source = """ def forward(x, y): if y > 0: @@ -109,18 +154,38 @@ def forward(x, y): x = x - 1 return x """ + output = """ + @torch.fx.wrap + def wrapped_0(x, y): + if y > 0: + x = x + 1 + else: + x = x - 1 + return (x,) - tree = ast.parse(textwrap.dedent(source)) - wrapper = AutoWrapper(namespace={}, ignore=[]) - wrapper.auto_wrap(tree) + def forward(x, y): + (x,) = wrapped_0(x, y) # skip: some envs use "(x,)" -> "x," + return x + """ + check_wrapping(source, output) - assert len(wrapper._wrapper_fn_defs) == 1 - # Check if both x, y are included in args - wrapped_fn = wrapper._wrapper_fn_defs[0] - arg_names = {arg.arg for arg in wrapped_fn.args.args} +def test_function_variadic(): + """Checks for handling variadic names created via function def""" + + source = """ + def forward(a, *b, c=5, **d): + if a == b and c == d: + pass + """ + output = """ + @torch.fx.wrap + def wrapped_0(a, b, c, d): + if a == b and c == d: + pass + return () - assert arg_names == { - "x", - "y", - }, f"Expected arguments {{'x', 'y'}}, but got {arg_names}" + def forward(a, *b, c=5, **d): + () = wrapped_0(a, b, c, d) + """ + check_wrapping(source, output)