@@ -119,15 +119,15 @@ def _check_input_shape(self, input: Tensor):
119
119
class ConditionalDensityEstimator (ConditionalEstimator ):
120
120
r"""Base class for density estimators.
121
121
122
- The density estimator class is a wrapper around neural networks that
123
- allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$
124
- pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
122
+ The density estimator class is a wrapper around neural networks that allows to
123
+ evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$ pairs. Here
124
+ $\theta$ would be the `input` and $x$ would be the `condition`.
125
125
126
126
Note:
127
127
We assume that the input to the density estimator is a tensor of shape
128
- (batch_size, input_size ), where input_size is the dimensionality of the input.
129
- The condition is a tensor of shape (batch_size, *condition_shape), where
130
- condition_shape is the shape of the condition tensor.
128
+ (sample_dim, batch_dim, *input_shape ), where input_shape is the dimensionality
129
+ of the input. The condition is a tensor of shape (batch_size, *condition_shape),
130
+ where condition_shape is the shape of the condition tensor.
131
131
132
132
"""
133
133
@@ -226,15 +226,15 @@ def sample_and_log_prob(
226
226
class ConditionalVectorFieldEstimator (ConditionalEstimator ):
227
227
r"""Base class for vector field (e.g., score and ODE flow) estimators.
228
228
229
- The density estimator class is a wrapper around neural networks that
230
- allows to evaluate the `vector_field`, and provide the `loss` of $\theta,x$
231
- pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
229
+ The vector field estimator class is a wrapper around neural networks that allows to
230
+ evaluate the `vector_field`, and provide the `loss` of $\theta,x$ pairs. Here
231
+ $\theta$ would be the `input` and $x$ would be the `condition`.
232
232
233
233
Note:
234
234
We assume that the input to the density estimator is a tensor of shape
235
- (batch_size, input_size ), where input_size is the dimensionality of the input.
236
- The condition is a tensor of shape (batch_size , *condition_shape), where
237
- condition_shape is the shape of the condition tensor.
235
+ (sample_dim, batch_dim, *input_shape ), where input_shape is the dimensionality
236
+ of the input. The condition is a tensor of shape (batch_dim , *condition_shape),
237
+ where condition_shape is the shape of the condition tensor.
238
238
239
239
"""
240
240
0 commit comments