Skip to content

Commit 13d40f6

Browse files
Revert "hang dim hint constants off Dim (pytorch#134484)"
This reverts commit c142af7. Reverted pytorch#134484 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#134484 (comment)))
1 parent 2c88a92 commit 13d40f6

File tree

4 files changed

+53
-46
lines changed

4 files changed

+53
-46
lines changed

docs/source/export.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,7 @@ API Reference
676676
.. autofunction:: save
677677
.. autofunction:: load
678678
.. autofunction:: register_dataclass
679+
.. autoclass:: torch.export.dynamic_shapes.DIM
679680
.. autofunction:: torch.export.dynamic_shapes.Dim
680681
.. autofunction:: dims
681682
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

test/export/test_export.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,8 @@ def forward(self, x, y, y1, z):
17731773
)
17741774

17751775
def test_static_dim_constraints(self):
1776+
from torch.export.dynamic_shapes import DIM
1777+
17761778
class Foo(torch.nn.Module):
17771779
def __init__(self) -> None:
17781780
super().__init__()
@@ -1801,7 +1803,7 @@ def forward(self, x, y, z):
18011803
((dx, None), (dy, 4), (dz, 3)),
18021804
((None, 6), (5, None), (None, None)),
18031805
((4, 6), {0: None, 1: 4}, {0: None, 1: 3}),
1804-
(None, None, (Dim.STATIC, Dim.STATIC)),
1806+
(None, None, (DIM.STATIC, DIM.STATIC)),
18051807
]:
18061808
ep = export(foo, inputs, dynamic_shapes=dynamic_shapes)
18071809
self.assertEqual(foo(*inputs), ep.module()(*inputs))
@@ -1948,7 +1950,9 @@ def forward(self, inp: Inp):
19481950
self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
19491951

19501952
def test_mismatched_dynamic_shapes(self):
1951-
AUTO, STATIC = Dim.AUTO, Dim.STATIC
1953+
from torch.export.dynamic_shapes import DIM
1954+
1955+
AUTO, STATIC = DIM.AUTO, DIM.STATIC
19521956

19531957
class M(torch.nn.Module):
19541958
def forward(self, x):
@@ -1980,7 +1984,7 @@ def forward(self, x):
19801984
+ re.escape(
19811985
"specified at `dynamic_shapes[0]['k']['k'][0]` "
19821986
"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
1983-
" where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)"
1987+
" where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)"
19841988
),
19851989
):
19861990
export(M(), inputs, dynamic_shapes=dynamic_shapes)
@@ -2057,7 +2061,7 @@ def forward(self, x):
20572061
with self.assertRaisesRegex(
20582062
torch._dynamo.exc.UserError,
20592063
re.escape(
2060-
"Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
2064+
"Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
20612065
"and can easily lead to constraint violation errors or obscure errors in torch.export."
20622066
),
20632067
):
@@ -2461,7 +2465,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
24612465
def test_mark_and_auto_dynamic(self):
24622466
# for this use case, mark_dynamic() and AUTO should have same effect.
24632467
# check that same symbol gets allocated to both dims without raising constraint violation.
2464-
AUTO, STATIC = Dim.AUTO, Dim.STATIC
2468+
from torch.export.dynamic_shapes import DIM
2469+
2470+
AUTO, STATIC = DIM.AUTO, DIM.STATIC
24652471

24662472
class Foo(torch.nn.Module):
24672473
def forward(self, x, y):
@@ -2489,7 +2495,9 @@ def forward(self, x, y):
24892495
def test_dont_duck_size_for_auto_dynamic(self):
24902496
# for this use case, mark_dynamic() and AUTO should have same effect.
24912497
# check that same symbol gets allocated to both dims without raising constraint violation.
2492-
AUTO, STATIC = Dim.AUTO, Dim.STATIC
2498+
from torch.export.dynamic_shapes import DIM
2499+
2500+
AUTO, STATIC = DIM.AUTO, DIM.STATIC
24932501

24942502
class Foo(torch.nn.Module):
24952503
def forward(self, x, y):
@@ -6868,7 +6876,9 @@ def test_automatic_dynamic_shapes_simple_equality(self):
68686876
# The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism
68696877
# leads to replacement symbols being set for equalities, and inferred relationships being checked
68706878
# with runtime asserts. Check that we specialize to static values when the program says so.
6871-
AUTO, STATIC = Dim.AUTO, Dim.STATIC
6879+
from torch.export.dynamic_shapes import DIM
6880+
6881+
AUTO, STATIC = DIM.AUTO, DIM.STATIC
68726882

68736883
# case 1: direct equality between symbols
68746884
class SimpleEquality(torch.nn.Module):
@@ -6933,7 +6943,9 @@ def forward(self, x, y, z):
69336943
)
69346944

69356945
def test_automatic_dynamic_shapes_constant_relation(self):
6936-
AUTO, STATIC = Dim.AUTO, Dim.STATIC
6946+
from torch.export.dynamic_shapes import DIM
6947+
6948+
AUTO, STATIC = DIM.AUTO, DIM.STATIC
69376949

69386950
# case 2: related by constant: s0 + 4 = s1
69396951
class OffBy4(torch.nn.Module):
@@ -6976,7 +6988,9 @@ def forward(self, x, y):
69766988
)
69776989

69786990
def test_automatic_dynamic_shapes_linear_relation(self):
6979-
AUTO, STATIC = Dim.AUTO, Dim.STATIC
6991+
from torch.export.dynamic_shapes import DIM
6992+
6993+
AUTO, STATIC = DIM.AUTO, DIM.STATIC
69806994

69816995
# case 3: linear relation
69826996
class LinearRel(torch.nn.Module):

torch/_export/non_strict_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
from torch.export.dynamic_shapes import (
2525
_check_dynamic_shapes,
2626
_combine_args,
27-
_DimHint,
2827
_process_dynamic_shapes,
2928
_transform_shapes_for_default_dynamic,
3029
_tree_map_with_path,
30+
DIM,
3131
)
3232
from torch.export.graph_signature import CustomObjArgument
3333
from torch.fx.experimental import _config as config
@@ -351,7 +351,7 @@ def make_constraints(
351351
# we want the symbol, not its replacement, which could be an expression. Maybe
352352
# there's a better way to do this, e.g., by (re)computing value ranges for expressions?
353353
dim = shape_spec[i] if shape_spec else None
354-
if dim is None or isinstance(dim, _DimHint):
354+
if dim is None or isinstance(dim, DIM):
355355
range_constraints[d.node.expr] = shape_env.var_to_range[
356356
d.node._expr
357357
]

torch/export/dynamic_shapes.py

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,20 @@
3030

3131
__all__ = [
3232
"Constraint",
33+
"DIM",
3334
"Dim",
3435
"dims",
3536
"refine_dynamic_shapes_from_suggested_fixes",
3637
]
3738

3839

39-
class _DimHint(Enum):
40+
class DIM(Enum):
4041
"""
41-
Enum for dynamic shape hints.
42-
- AUTO means automatic inference of shape (static or dynamic).
43-
- STATIC means static shape (always specialized).
42+
Enum for automatic/static dynamic shapes.
4443
"""
4544

46-
AUTO = auto()
4745
STATIC = auto()
46+
AUTO = auto()
4847

4948

5049
class _Dim(type):
@@ -215,7 +214,6 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
215214
Returns:
216215
A type that can be used in dynamic shape specifications for tensors.
217216
"""
218-
219217
from torch.utils._sympy.numbers import int_oo
220218

221219
_min = 0 if min is None else min
@@ -229,10 +227,6 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
229227
return dim
230228

231229

232-
Dim.AUTO = _DimHint.AUTO # hint for automatic inference of shape (static or dynamic)
233-
Dim.STATIC = _DimHint.STATIC # hint for static shape (always specialized)
234-
235-
236230
def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
237231
"""
238232
Util to create multiple :func:`Dim` types.
@@ -674,32 +668,32 @@ def check_symbols(path, tensor, shape):
674668
for i, dim in shape.items():
675669
if isinstance(dim, _Dim):
676670
check_same_bounds(dim)
677-
elif not (isinstance(dim, (int, _DimHint)) or dim is None):
671+
elif not (isinstance(dim, (int, DIM)) or dim is None):
678672
raise UserError(
679673
UserErrorType.INVALID_INPUT,
680674
f"Unexpected dimension mapped to index {i} in input tensor shape {shape} "
681675
f"specified at `dynamic_shapes{keystr(path)}` "
682-
f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)",
676+
f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)",
683677
case_name="dynamic_shapes_validation",
684678
)
685679
elif isinstance(shape, (tuple, list)):
686680
for i, dim in enumerate(shape):
687681
if isinstance(dim, _Dim):
688682
check_same_bounds(dim)
689-
elif not (isinstance(dim, (int, _DimHint)) or dim is None):
683+
elif not (isinstance(dim, (int, DIM)) or dim is None):
690684
raise UserError(
691685
UserErrorType.INVALID_INPUT,
692686
f"Unexpected dimension #{i} in input tensor shape {shape} "
693687
f"specified at `dynamic_shapes{keystr(path)}` "
694-
f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)",
688+
f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)",
695689
case_name="dynamic_shapes_validation",
696690
)
697691
elif shape is not None:
698692
raise UserError(
699693
UserErrorType.INVALID_INPUT,
700694
f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` "
701695
f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
702-
f" where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)",
696+
f" where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)",
703697
case_name="dynamic_shapes_validation",
704698
)
705699

@@ -746,18 +740,18 @@ def check_shape(path, t, dynamic_shape):
746740

747741
_tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")
748742

749-
# raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes
743+
# raise user warning if both DIM.AUTO & Dims are specified in dynamic_shapes
750744
flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)
751745
flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes)
752746
if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any(
753-
s == _DimHint.AUTO for s in flatter_dynamic_shapes
747+
s == DIM.AUTO for s in flatter_dynamic_shapes
754748
):
755749
raise UserError(
756750
UserErrorType.INVALID_INPUT,
757-
"Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
751+
"Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
758752
"and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims "
759-
"expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. "
760-
"We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), "
753+
"expect all equal or related dimensions to be specified, and does not yet compose well with `DIM.AUTO`. "
754+
"We suggest using `DIM.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), "
761755
"torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` "
762756
"if you want to assert on the exact specification of your program's dynamic shapes behavior.",
763757
case_name="dynamic_shapes_validation",
@@ -779,8 +773,8 @@ def _transform_shapes_for_default_dynamic(
779773
for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.)
780774
781775
For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are:
782-
- Dim.AUTO: dynamic, allocated a symbol
783-
- None/unspecified/Dim.STATIC: static
776+
- DIM.AUTO: dynamic, allocated a symbol
777+
- None/unspecified/DIM.STATIC: static
784778
- Dim/DerivedDims: also a strict assertion
785779
786780
To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes
@@ -790,8 +784,8 @@ def _transform_shapes_for_default_dynamic(
790784
An example conversion might look like, for a 3-d input tensor:
791785
792786
input spec: {
793-
0: Dim.AUTO,
794-
1: None, # or Dim.STATIC
787+
0: DIM.AUTO,
788+
1: None, # or DIM.STATIC
795789
2: Dim("dx"),
796790
}
797791
output spec: {
@@ -838,10 +832,10 @@ def _marked_dynamic(tensor, i):
838832
out = {}
839833
for i, val in enumerate(tensor.shape):
840834
dim = shape.get(i, None)
841-
if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO:
835+
if _marked_dynamic(tensor, i) or dim == DIM.AUTO:
842836
# don't have to specify anything if dynamic
843837
# None also works, since assume_static_by_default=False
844-
if dim == _DimHint.AUTO:
838+
if dim == DIM.AUTO:
845839
torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing
846840
continue
847841
elif isinstance(dim, _Dim):
@@ -852,22 +846,22 @@ def _marked_dynamic(tensor, i):
852846
out[i] = dim
853847
else:
854848
# make explicitly static
855-
assert dim is None or dim == _DimHint.STATIC
849+
assert dim is None or dim == DIM.STATIC
856850
out[i] = val
857851
elif isinstance(shape, (tuple, list)):
858852
out = []
859853
for i, val in enumerate(tensor.shape):
860854
dim = shape[i]
861-
if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO:
862-
if dim == _DimHint.AUTO:
855+
if _marked_dynamic(tensor, i) or dim == DIM.AUTO:
856+
if dim == DIM.AUTO:
863857
torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing
864858
out.append(None)
865859
elif isinstance(dim, _Dim):
866860
out.append(dim)
867861
elif isinstance(dim, int):
868862
out.append(dim)
869863
else:
870-
assert dim is None or dim == _DimHint.STATIC
864+
assert dim is None or dim == DIM.STATIC
871865
out.append(val)
872866
out = type(shape)(out) # type: ignore[assignment]
873867
else:
@@ -1046,7 +1040,7 @@ def _get_dim_name_mapping(
10461040
dynamic_shapes,
10471041
is_leaf=lambda x: isinstance(x, _Dim),
10481042
)[0]:
1049-
if isinstance(dim, (int, _DimHint)) or dim is None:
1043+
if isinstance(dim, (int, DIM)) or dim is None:
10501044
continue
10511045
name_to_dim[dim.__name__] = dim
10521046
if isinstance(dim, _DerivedDim):
@@ -1128,11 +1122,9 @@ def refine_dynamic_shapes_from_suggested_fixes(
11281122
name, expr = fix.split(" = ")
11291123
expr = sympy.sympify(expr)
11301124
if isinstance(expr, sympy.Number):
1131-
# static, integer
1132-
shape_fixes[name] = int(expr) # type: ignore[assignment]
1125+
shape_fixes[name] = int(expr) # static, integer
11331126
else:
1134-
# relation or derived dim
1135-
shape_fixes[name] = expr
1127+
shape_fixes[name] = expr # relation or derived dim
11361128

11371129
name_to_dim = _get_dim_name_mapping(dynamic_shapes)
11381130

0 commit comments

Comments (0)