30
30
31
31
__all__ = [
32
32
"Constraint" ,
33
+ "DIM" ,
33
34
"Dim" ,
34
35
"dims" ,
35
36
"refine_dynamic_shapes_from_suggested_fixes" ,
36
37
]
37
38
38
39
39
- class _DimHint (Enum ):
40
+ class DIM (Enum ):
40
41
"""
41
- Enum for dynamic shape hints.
42
- - AUTO means automatic inference of shape (static or dynamic).
43
- - STATIC means static shape (always specialized).
42
+ Enum for automatic/static dynamic shapes.
44
43
"""
45
44
46
- AUTO = auto ()
47
45
STATIC = auto ()
46
+ AUTO = auto ()
48
47
49
48
50
49
class _Dim (type ):
@@ -215,7 +214,6 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
215
214
Returns:
216
215
A type that can be used in dynamic shape specifications for tensors.
217
216
"""
218
-
219
217
from torch .utils ._sympy .numbers import int_oo
220
218
221
219
_min = 0 if min is None else min
@@ -229,10 +227,6 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
229
227
return dim
230
228
231
229
232
- Dim .AUTO = _DimHint .AUTO # hint for automatic inference of shape (static or dynamic)
233
- Dim .STATIC = _DimHint .STATIC # hint for static shape (always specialized)
234
-
235
-
236
230
def dims (* names : str , min : Optional [int ] = None , max : Optional [int ] = None ):
237
231
"""
238
232
Util to create multiple :func:`Dim` types.
@@ -674,32 +668,32 @@ def check_symbols(path, tensor, shape):
674
668
for i , dim in shape .items ():
675
669
if isinstance (dim , _Dim ):
676
670
check_same_bounds (dim )
677
- elif not (isinstance (dim , (int , _DimHint )) or dim is None ):
671
+ elif not (isinstance (dim , (int , DIM )) or dim is None ):
678
672
raise UserError (
679
673
UserErrorType .INVALID_INPUT ,
680
674
f"Unexpected dimension mapped to index { i } in input tensor shape { shape } "
681
675
f"specified at `dynamic_shapes{ keystr (path )} ` "
682
- f"(expected None, an int, a Dim, Dim .AUTO, or Dim .STATIC, but got { dim } instead)" ,
676
+ f"(expected None, an int, a Dim, DIM .AUTO, or DIM .STATIC, but got { dim } instead)" ,
683
677
case_name = "dynamic_shapes_validation" ,
684
678
)
685
679
elif isinstance (shape , (tuple , list )):
686
680
for i , dim in enumerate (shape ):
687
681
if isinstance (dim , _Dim ):
688
682
check_same_bounds (dim )
689
- elif not (isinstance (dim , (int , _DimHint )) or dim is None ):
683
+ elif not (isinstance (dim , (int , DIM )) or dim is None ):
690
684
raise UserError (
691
685
UserErrorType .INVALID_INPUT ,
692
686
f"Unexpected dimension #{ i } in input tensor shape { shape } "
693
687
f"specified at `dynamic_shapes{ keystr (path )} ` "
694
- f"(expected None, an int, a Dim, Dim .AUTO, or Dim .STATIC, but got { dim } instead)" ,
688
+ f"(expected None, an int, a Dim, DIM .AUTO, or DIM .STATIC, but got { dim } instead)" ,
695
689
case_name = "dynamic_shapes_validation" ,
696
690
)
697
691
elif shape is not None :
698
692
raise UserError (
699
693
UserErrorType .INVALID_INPUT ,
700
694
f"Unexpected input tensor shape { shape } specified at `dynamic_shapes{ keystr (path )} ` "
701
695
f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
702
- f" where each dimension is None, an int, a Dim, Dim .AUTO, or Dim .STATIC)" ,
696
+ f" where each dimension is None, an int, a Dim, DIM .AUTO, or DIM .STATIC)" ,
703
697
case_name = "dynamic_shapes_validation" ,
704
698
)
705
699
@@ -746,18 +740,18 @@ def check_shape(path, t, dynamic_shape):
746
740
747
741
_tree_map_with_path (check_shape , combined_args , dynamic_shapes , tree_name = "inputs" )
748
742
749
- # raise user warning if both Dim .AUTO & Dims are specified in dynamic_shapes
743
+ # raise user warning if both DIM .AUTO & Dims are specified in dynamic_shapes
750
744
flat_dynamic_shapes = _flatten_dynamic_shapes (combined_args , dynamic_shapes )
751
745
flatter_dynamic_shapes , _ = tree_flatten (flat_dynamic_shapes )
752
746
if any (isinstance (s , _Dim ) for s in flatter_dynamic_shapes ) and any (
753
- s == _DimHint .AUTO for s in flatter_dynamic_shapes
747
+ s == DIM .AUTO for s in flatter_dynamic_shapes
754
748
):
755
749
raise UserError (
756
750
UserErrorType .INVALID_INPUT ,
757
- "Specifying both `Dim .AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
751
+ "Specifying both `DIM .AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
758
752
"and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims "
759
- "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim .AUTO`. "
760
- "We suggest using `Dim .AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), "
753
+ "expect all equal or related dimensions to be specified, and does not yet compose well with `DIM .AUTO`. "
754
+ "We suggest using `DIM .AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), "
761
755
"torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` "
762
756
"if you want to assert on the exact specification of your program's dynamic shapes behavior." ,
763
757
case_name = "dynamic_shapes_validation" ,
@@ -779,8 +773,8 @@ def _transform_shapes_for_default_dynamic(
779
773
for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.)
780
774
781
775
For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are:
782
- - Dim .AUTO: dynamic, allocated a symbol
783
- - None/unspecified/Dim .STATIC: static
776
+ - DIM .AUTO: dynamic, allocated a symbol
777
+ - None/unspecified/DIM .STATIC: static
784
778
- Dim/DerivedDims: also a strict assertion
785
779
786
780
To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes
@@ -790,8 +784,8 @@ def _transform_shapes_for_default_dynamic(
790
784
An example conversion might look like, for a 3-d input tensor:
791
785
792
786
input spec: {
793
- 0: Dim .AUTO,
794
- 1: None, # or Dim .STATIC
787
+ 0: DIM .AUTO,
788
+ 1: None, # or DIM .STATIC
795
789
2: Dim("dx"),
796
790
}
797
791
output spec: {
@@ -838,10 +832,10 @@ def _marked_dynamic(tensor, i):
838
832
out = {}
839
833
for i , val in enumerate (tensor .shape ):
840
834
dim = shape .get (i , None )
841
- if _marked_dynamic (tensor , i ) or dim == _DimHint .AUTO :
835
+ if _marked_dynamic (tensor , i ) or dim == DIM .AUTO :
842
836
# don't have to specify anything if dynamic
843
837
# None also works, since assume_static_by_default=False
844
- if dim == _DimHint .AUTO :
838
+ if dim == DIM .AUTO :
845
839
torch ._dynamo .maybe_mark_dynamic (tensor , i ) # avoid duck sizing
846
840
continue
847
841
elif isinstance (dim , _Dim ):
@@ -852,22 +846,22 @@ def _marked_dynamic(tensor, i):
852
846
out [i ] = dim
853
847
else :
854
848
# make explicitly static
855
- assert dim is None or dim == _DimHint .STATIC
849
+ assert dim is None or dim == DIM .STATIC
856
850
out [i ] = val
857
851
elif isinstance (shape , (tuple , list )):
858
852
out = []
859
853
for i , val in enumerate (tensor .shape ):
860
854
dim = shape [i ]
861
- if _marked_dynamic (tensor , i ) or dim == _DimHint .AUTO :
862
- if dim == _DimHint .AUTO :
855
+ if _marked_dynamic (tensor , i ) or dim == DIM .AUTO :
856
+ if dim == DIM .AUTO :
863
857
torch ._dynamo .maybe_mark_dynamic (tensor , i ) # avoid duck sizing
864
858
out .append (None )
865
859
elif isinstance (dim , _Dim ):
866
860
out .append (dim )
867
861
elif isinstance (dim , int ):
868
862
out .append (dim )
869
863
else :
870
- assert dim is None or dim == _DimHint .STATIC
864
+ assert dim is None or dim == DIM .STATIC
871
865
out .append (val )
872
866
out = type (shape )(out ) # type: ignore[assignment]
873
867
else :
@@ -1046,7 +1040,7 @@ def _get_dim_name_mapping(
1046
1040
dynamic_shapes ,
1047
1041
is_leaf = lambda x : isinstance (x , _Dim ),
1048
1042
)[0 ]:
1049
- if isinstance (dim , (int , _DimHint )) or dim is None :
1043
+ if isinstance (dim , (int , DIM )) or dim is None :
1050
1044
continue
1051
1045
name_to_dim [dim .__name__ ] = dim
1052
1046
if isinstance (dim , _DerivedDim ):
@@ -1128,11 +1122,9 @@ def refine_dynamic_shapes_from_suggested_fixes(
1128
1122
name , expr = fix .split (" = " )
1129
1123
expr = sympy .sympify (expr )
1130
1124
if isinstance (expr , sympy .Number ):
1131
- # static, integer
1132
- shape_fixes [name ] = int (expr ) # type: ignore[assignment]
1125
+ shape_fixes [name ] = int (expr ) # static, integer
1133
1126
else :
1134
- # relation or derived dim
1135
- shape_fixes [name ] = expr
1127
+ shape_fixes [name ] = expr # relation or derived dim
1136
1128
1137
1129
name_to_dim = _get_dim_name_mapping (dynamic_shapes )
1138
1130
0 commit comments