Skip to content

Commit 970c060

Browse files
authored
Cherrypick DTensor docstring fix for 2.9 release. (#16434)
* Enable the keras dtensor API in OSS. PiperOrigin-RevId: 438858608 * Switching learning/brain dependency to OSS compatible test_util This is one test file failing, due to the monkey patching happens in the dtensor.init(), and I will need to dig more about the root cause (probably due to patching tf.Variable with DVariable, and cause logic difference for instance type checking.) PiperOrigin-RevId: 439676157 * Update the docstring for keras.dtensor components. 1. Add docstring for LayoutMap. 2. Hide certain methods for keras.dtensor.optimizers. PiperOrigin-RevId: 442651549
1 parent 27e3966 commit 970c060

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

keras/dtensor/layout_map.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,41 @@ def get_current_layout_map():
4646

4747
@keras_export('keras.dtensor.experimental.LayoutMap', v1=[])
4848
class LayoutMap(collections.abc.MutableMapping):
49+
"""A dict-like object that maps string to `Layout` instances.
4950
50-
def __init__(self, mesh=None):
51-
"""A dict like object that maps between string name and dtensor.Layout.
51+
`LayoutMap` uses a string as key and a `Layout` as value. There is a behavior
52+
difference between a normal Python dict and this class. The string key will be
53+
treated as a regex when retrieving the value. See the docstring of
54+
`get` for more details.
5255
53-
Note that this class might behave differently than a normal dict, eg, it
54-
will treat the all the existing keys as a regex to map against input key.
56+
See below for a usage example. You can define the naming schema
57+
of the `Layout`, and then retrieve the corresponding `Layout` instance.
5558
56-
Args:
57-
mesh: An optional dtensor.Mesh that is used to provide all replicated
58-
layout as default when there isn't a layout is found based on the
59-
mapping.
60-
"""
59+
To use the `LayoutMap` with a `Model`, please see the docstring of
60+
`tf.keras.dtensor.experimental.layout_map_scope`.
61+
62+
```python
63+
map = LayoutMap(mesh=None)
64+
map['.*dense.*kernel'] = layout_2d
65+
map['.*dense.*bias'] = layout_1d
66+
map['.*conv2d.*kernel'] = layout_4d
67+
map['.*conv2d.*bias'] = layout_1d
68+
69+
layout_1 = map['dense_1.kernel'] # layout_1 == layout_2d
70+
layout_2 = map['dense_1.bias'] # layout_2 == layout_1d
71+
layout_3 = map['dense_2.kernel'] # layout_3 == layout_2d
72+
layout_4 = map['dense_2.bias'] # layout_4 == layout_1d
73+
layout_5 = map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
74+
layout_6 = map['my_model/conv2d_123/bias'] # layout_6 == layout_1d
75+
```
76+
77+
Args:
78+
mesh: An optional `Mesh` that can be used to create all replicated
79+
layout as default when there isn't a layout found based on the input
80+
string query.
81+
"""
82+
83+
def __init__(self, mesh=None):
6184
self._layout_map = collections.OrderedDict()
6285
self._default_mesh = mesh
6386

@@ -105,9 +128,17 @@ def __iter__(self):
105128
return iter(self._layout_map)
106129

107130
def get_default_mesh(self):
131+
"""Return the default `Mesh` set at instance creation.
132+
133+
The `Mesh` can be used to create default replicated `Layout` when there
134+
isn't a match of the input string query.
135+
"""
108136
return self._default_mesh
109137

110138

139+
LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__
140+
141+
111142
@keras_export('keras.dtensor.experimental.layout_map_scope', v1=[])
112143
@contextlib.contextmanager
113144
def layout_map_scope(layout_map):

keras/dtensor/optimizers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import tensorflow.compat.v2 as tf
2727

2828
from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-direct-tensorflow-import
29+
from tensorflow.tools.docs import doc_controls
2930

3031

3132
# pylint: disable=protected-access,missing-class-docstring
@@ -103,6 +104,7 @@ def add_variable_from_reference(self,
103104
dtype=model_variable.dtype,
104105
trainable=False)
105106

107+
@doc_controls.do_not_generate_docs
106108
def aggregate_gradients(self, grads_and_vars):
107109
# Hide the aggregate_gradients from Optimizer.aggregate_gradients
108110
raise NotImplementedError(

0 commit comments

Comments
 (0)