Skip to content

Commit 8a7d8c4

Browse files
committed
autoencoder with mask. documentation added.
1 parent edcd47c commit 8a7d8c4

File tree

1 file changed

+165
-4
lines changed

1 file changed

+165
-4
lines changed

src/lasdi/latent_space.py

Lines changed: 165 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,23 @@
2828
}
2929

3030
def initial_condition_latent(param_grid, physics, autoencoder):
31-
3231
'''
32+
Outputs the initial condition in the latent space: Z0 = encoder(U0)
3333
34-
Outputs the initial condition in the latent space: Z0 = encoder(U0)
34+
Arguments
35+
---------
36+
param_grid : :obj:`numpy.array`
37+
A 2d array of shape `(n_param, param_dim)` for parameter points to obtain initial condition.
38+
physics : :obj:`lasdi.physics.Physics`
39+
Physics class to generate initial condition.
40+
autoencoder : :obj:`lasdi.latent_space.Autoencoder`
41+
Autoencoder class to encode initial conditions into latent variables.
3542
43+
Returns
44+
-------
45+
Z0 : :obj:`torch.Tensor`
46+
a torch tensor of size `(n_param, n_z)`, where `n_z` is the latent variable dimension
47+
defined by `autoencoder`.
3648
'''
3749

3850
n_param = param_grid.shape[0]
@@ -51,18 +63,22 @@ def initial_condition_latent(param_grid, physics, autoencoder):
5163
return Z0
5264

5365
class MultiLayerPerceptron(torch.nn.Module):
66+
"""A standard multi-layer perceptron (MLP) module."""
5467

5568
def __init__(self, layer_sizes,
5669
act_type='sigmoid', reshape_index=None, reshape_shape=None,
5770
threshold=0.1, value=0.0, num_heads=1):
5871
super(MultiLayerPerceptron, self).__init__()
5972

60-
# including input, hidden, output layers
6173
self.n_layers = len(layer_sizes)
74+
""":obj:`int`: Depth of MLP including input, hidden, and output layers."""
6275
self.layer_sizes = layer_sizes
76+
""":obj:`list(int)`: Widths of each MLP layer, including input, hidden and output layers."""
6377

64-
# Linear features between layers
6578
self.fcs = []
79+
""":obj:`torch.nn.ModuleList`: torch module list of :math:`(self.n\_layers-1)` linear layers,
80+
connecting from input to output layers.
81+
"""
6682
for k in range(self.n_layers-1):
6783
self.fcs += [torch.nn.Linear(layer_sizes[k], layer_sizes[k + 1])]
6884
self.fcs = torch.nn.ModuleList(self.fcs)
@@ -71,14 +87,35 @@ def __init__(self, layer_sizes,
7187
# Reshape input or output layer
7288
assert((reshape_index is None) or (reshape_index in [0, -1]))
7389
assert((reshape_shape is None) or (np.prod(reshape_shape) == layer_sizes[reshape_index]))
90+
7491
self.reshape_index = reshape_index
92+
""":obj:`int`: Index of the layer to reshape.
93+
94+
* 0: Input data is n-dimensional and will be squeezed into 1d tensor for MLP input.
95+
* 1: Output data should be n-dimensional and MLP output will be reshaped as such.
96+
"""
7597
self.reshape_shape = reshape_shape
98+
""":obj:`list(int)`: Shape of the layer to be reshaped.
99+
100+
* :math:`(self.reshape_index=0)`: Shape of the input data that will be squeezed into 1d tensor for MLP input.
101+
* :math:`(self.reshape_index=1)`: Shape of the output data into which MLP output shall be reshaped.
102+
"""
76103

77104
# Initalize activation function
78105
self.act_type = act_type
106+
""":obj:`str`: type of activation function"""
79107
self.use_multihead = False
108+
""":obj:`bool`: switch to use multihead attention.
109+
110+
Warning:
111+
this attribute is obsolete and will be removed in future.
112+
"""
113+
114+
self.act = None
115+
""":obj:`torch.nn.Module`: activation function"""
80116
if act_type == "threshold":
81117
self.act = act_dict[act_type](threshold, value)
118+
82119

83120
elif act_type == "multihead":
84121
self.use_multihead = True
@@ -96,6 +133,20 @@ def __init__(self, layer_sizes,
96133
return
97134

98135
def forward(self, x):
136+
"""Pass the input through the MLP layers.
137+
138+
Args:
139+
x (:obj:`torch.Tensor`): n-dimensional torch.Tensor for input data.
140+
141+
Note:
142+
* If :obj:`self.reshape_index == 0`, then the last n dimensions of :obj:`x` must match :obj:`self.reshape_shape`. In other words, :obj:`list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape`
143+
* If :obj:`self.reshape_index == -1`, then the last layer output :obj:`z` is reshaped into :obj:`self.reshape_shape`. In other words, :obj:`list(z.shape[-len(self.reshape_shape):]) == self.reshape_shape`
144+
145+
Returns:
146+
n-dimensional torch.Tensor for output data.
147+
148+
"""
149+
99150
if (self.reshape_index == 0):
100151
# make sure the input has a proper shape
101152
assert(list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape)
@@ -126,12 +177,28 @@ def apply_attention(self, x, act_idx):
126177
return x
127178

128179
def init_weight(self):
180+
"""Initialize the weights and biases of the linear layers.
181+
182+
Returns:
183+
Does not return a value.
184+
185+
"""
129186
# TODO(kevin): support other initializations?
130187
for fc in self.fcs:
131188
torch.nn.init.xavier_uniform_(fc.weight)
132189
return
133190

134191
class Autoencoder(torch.nn.Module):
192+
"""A standard autoencoder using MLP.
193+
194+
Args:
195+
physics (:obj:`lasdi.physics.Physics`): Physics class that specifies full-order model solution dimensions.
196+
197+
config: (:obj:`dict`): options for autoencoder. It must include the following keys and values.
198+
* :obj:`'hidden_units'`: a list of integers for the widths of hidden layers.
199+
* :obj:`'latent_dimension'`: integer for the latent space dimension.
200+
* :obj:`'activation'`: string for type of activation function.
201+
"""
135202

136203
def __init__(self, physics, config):
137204
super(Autoencoder, self).__init__()
@@ -172,4 +239,98 @@ def export(self):
172239

173240
def load(self, dict_):
174241
self.load_state_dict(dict_['autoencoder_param'])
242+
return
243+
244+
class MLPWithMask(MultiLayerPerceptron):
245+
"""Multi-layer perceptron with additional mask output.
246+
247+
Args:
248+
mlp (:obj:`lasdi.latent_space.MultiLayerPerceptron`): MultiLayerPerceptron class to copy.
249+
The same architecture, activation function, reshaping will be used.
250+
251+
"""
252+
253+
def __init__(self, mlp):
254+
assert(isinstance(mlp, MultiLayerPerceptron))
255+
from copy import deepcopy
256+
torch.nn.Module.__init__(self)
257+
258+
# including input, hidden, output layers
259+
self.n_layers = mlp.n_layers
260+
self.layer_sizes = deepcopy(mlp.layer_sizes)
261+
262+
# Linear features between layers
263+
self.fcs = deepcopy(mlp.fcs)
264+
265+
# Reshape input or output layer
266+
self.reshape_index = deepcopy(mlp.reshape_index)
267+
self.reshape_shape = deepcopy(mlp.reshape_shape)
268+
269+
# Initalize activation function
270+
self.act_type = mlp.act_type
271+
self.use_multihead = mlp.use_multihead
272+
self.act = deepcopy(mlp.act)
273+
274+
self.bool_d = torch.nn.Linear(self.layer_sizes[-2], self.layer_sizes[-1])
275+
""":obj:`torch.nn.Linear`: additional linear layer to output a mask variable."""
276+
torch.nn.init.xavier_uniform_(self.bool_d.weight)
277+
278+
self.sigmoid = torch.nn.Sigmoid()
279+
""":obj:`torch.nn.Sigmoid`: mask output passes through the sigmoid activation function to ensure :math:`[0, 1]`."""
280+
return
281+
282+
def forward(self, x):
283+
"""Pass the input through the MLP layers.
284+
285+
Args:
286+
x (:obj:`torch.Tensor`): n-dimensional torch.Tensor for input data.
287+
288+
Note:
289+
* If :obj:`self.reshape_index == 0`, then the last n dimensions of :obj:`x` must match :obj:`self.reshape_shape`. In other words, :obj:`list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape`
290+
* If :obj:`self.reshape_index == -1`, then the last layer outputs :obj:`xval` and :obj:`xbool` are reshaped into :obj:`self.reshape_shape`. In other words, :obj:`list(xval.shape[-len(self.reshape_shape):]) == self.reshape_shape`
291+
292+
Returns:
293+
xval (:obj:`torch.Tensor`): n-dimensional torch.Tensor for output data.
294+
xbool (:obj:`torch.Tensor`): n-dimensional torch.Tensor for output mask.
295+
296+
"""
297+
if (self.reshape_index == 0):
298+
# make sure the input has a proper shape
299+
assert(list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape)
300+
# we use torch.Tensor.view instead of torch.Tensor.reshape,
301+
# in order to avoid data copying.
302+
x = x.view(list(x.shape[:-len(self.reshape_shape)]) + [self.layer_sizes[self.reshape_index]])
303+
304+
for i in range(self.n_layers-2):
305+
x = self.fcs[i](x) # apply linear layer
306+
if (self.use_multihead):
307+
x = self.apply_attention(self, x, i)
308+
else:
309+
x = self.act(x)
310+
311+
xval, xbool = self.fcs[-1](x), self.bool_d(x)
312+
xbool = self.sigmoid(xbool)
313+
314+
if (self.reshape_index == -1):
315+
# we use torch.Tensor.view instead of torch.Tensor.reshape,
316+
# in order to avoid data copying.
317+
xval = xval.view(list(x.shape[:-1]) + self.reshape_shape)
318+
xbool = xbool.view(list(x.shape[:-1]) + self.reshape_shape)
319+
320+
return xval, xbool
321+
322+
class AutoEncoderWithMask(Autoencoder):
323+
"""Autoencoder class with additional mask output.
324+
325+
Its decoder is :obj:`lasdi.latent_space.MLPWithMask`,
326+
which has an additional mask output.
327+
328+
Note:
329+
Unlike the standard autoencoder, the decoder output will have two outputs (with the same shape of the input of the encoder).
330+
"""
331+
332+
def __init__(self, physics, config):
333+
Autoencoder.__init__(self, physics, config)
334+
335+
self.decoder = MLPWithMask(self.decoder)
175336
return

0 commit comments

Comments
 (0)