Skip to content

Commit 390a518

Browse files
authored
docs: Fix docstrings for condition sample_dim. (#1338)
1 parent 2ecfe21 commit 390a518

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

sbi/neural_nets/estimators/base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,15 @@ def _check_input_shape(self, input: Tensor):
119119
class ConditionalDensityEstimator(ConditionalEstimator):
120120
r"""Base class for density estimators.
121121
122-
The density estimator class is a wrapper around neural networks that
123-
allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$
124-
pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
122+
The density estimator class is a wrapper around neural networks that allows to
123+
evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$ pairs. Here
124+
$\theta$ would be the `input` and $x$ would be the `condition`.
125125
126126
Note:
127127
We assume that the input to the density estimator is a tensor of shape
128-
(batch_size, input_size), where input_size is the dimensionality of the input.
129-
The condition is a tensor of shape (batch_size, *condition_shape), where
130-
condition_shape is the shape of the condition tensor.
128+
(sample_dim, batch_dim, *input_shape), where input_shape is the dimensionality
129+
of the input. The condition is a tensor of shape (batch_size, *condition_shape),
130+
where condition_shape is the shape of the condition tensor.
131131
132132
"""
133133

@@ -226,15 +226,15 @@ def sample_and_log_prob(
226226
class ConditionalVectorFieldEstimator(ConditionalEstimator):
227227
r"""Base class for vector field (e.g., score and ODE flow) estimators.
228228
229-
The density estimator class is a wrapper around neural networks that
230-
allows to evaluate the `vector_field`, and provide the `loss` of $\theta,x$
231-
pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
229+
The vector field estimator class is a wrapper around neural networks that allows to
230+
evaluate the `vector_field`, and provide the `loss` of $\theta,x$ pairs. Here
231+
$\theta$ would be the `input` and $x$ would be the `condition`.
232232
233233
Note:
234234
We assume that the input to the density estimator is a tensor of shape
235-
(batch_size, input_size), where input_size is the dimensionality of the input.
236-
The condition is a tensor of shape (batch_size, *condition_shape), where
237-
condition_shape is the shape of the condition tensor.
235+
(sample_dim, batch_dim, *input_shape), where input_shape is the dimensionality
236+
of the input. The condition is a tensor of shape (batch_dim, *condition_shape),
237+
where condition_shape is the shape of the condition tensor.
238238
239239
"""
240240

sbi/neural_nets/estimators/nflows_flow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
8181
Args:
8282
input: Inputs to evaluate the log probability on. Of shape
8383
`(sample_dim, batch_dim, *event_shape)`.
84-
condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`.
84+
condition: Conditions of shape `(batch_dim, *event_shape)`.
8585
8686
Raises:
8787
AssertionError: If `input_batch_dim != condition_batch_dim`.
@@ -126,7 +126,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
126126
127127
Args:
128128
sample_shape: Shape of the samples to return.
129-
condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`.
129+
condition: Conditions of shape `(batch_dim, *event_shape)`.
130130
131131
Returns:
132132
Samples of shape `(*sample_shape, condition_batch_dim)`.
@@ -147,7 +147,7 @@ def sample_and_log_prob(
147147
148148
Args:
149149
sample_shape: Shape of the samples to return.
150-
condition: Conditions of shape (sample_dim, batch_dim, *event_shape).
150+
condition: Conditions of shape (batch_dim, *event_shape).
151151
152152
Returns:
153153
Samples of shape `(*sample_shape, condition_batch_dim, *input_event_shape)`

sbi/neural_nets/estimators/zuko_flow.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
101101
Args:
102102
input: Inputs to evaluate the log probability on. Of shape
103103
`(sample_dim, batch_dim, *event_shape)`.
104-
# TODO: the docstring is not correct here. in the code it seems we
105-
do not have a sample_dim for the condition.
106-
condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`.
104+
condition: Conditions of shape `(batch_dim, *event_shape)`.
107105
108106
Raises:
109107
AssertionError: If `input_batch_dim != condition_batch_dim`.

0 commit comments

Comments
 (0)