Decouple batch size and number of negatives #263
base: main
Changes from 8 commits
723bcfb
dbabb6e
540b006
07212f2
6c2d559
0dba5fc
e259e45
f2af3b6
1412484
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| """Base classes for datasets and loaders.""" | ||
|
|
||
| import abc | ||
| from typing import Iterator | ||
|
|
||
| import literate_dataclasses as dataclasses | ||
| import torch | ||
|
|
@@ -239,6 +240,12 @@ class Loader(abc.ABC, cebra.io.HasDevice): | |
| batch_size: int = dataclasses.field(default=None, | ||
| doc="""The total batch size.""") | ||
|
|
||
| num_negatives: int = dataclasses.field( | ||
| default=None, | ||
| doc=("The number of negative samples to draw for each reference. " | ||
| "If not specified, the batch size is used."), | ||
| ) | ||
|
|
||
| def __post_init__(self): | ||
| if self.num_steps is None or self.num_steps <= 0: | ||
| raise ValueError( | ||
|
|
@@ -248,28 +255,41 @@ def __post_init__(self): | |
| raise ValueError( | ||
| f"Batch size has to be None, or a positive value. Got {self.batch_size}." | ||
| ) | ||
| if self.num_negatives is not None and self.num_negatives <= 0: | ||
| raise ValueError( | ||
| f"Number of negatives has to be None, or a positive value. Got {self.num_negatives}." | ||
| ) | ||
|
|
||
| if self.num_negatives is None: | ||
| self.num_negatives = self.batch_size | ||
|
|
||
| def __len__(self): | ||
| """The number of batches returned when calling as an iterator.""" | ||
| return self.num_steps | ||
|
|
||
| - def __iter__(self) -> Batch: | ||
| + def __iter__(self) -> Iterator[Batch]: | ||
| for _ in range(len(self)): | ||
| - index = self.get_indices(num_samples=self.batch_size) | ||
| + index = self.get_indices() | ||
| yield self.dataset.load_batch(index) | ||
|
|
||
| @abc.abstractmethod | ||
| - def get_indices(self, num_samples: int): | ||
| + def get_indices(self, num_samples: int = None): | ||
| """Sample and return the specified number of indices. | ||
|
|
||
| The elements of the returned `BatchIndex` will be used to index the | ||
| `dataset` of this data loader. | ||
|
|
||
| Args: | ||
| - num_samples: The size of each of the reference, positive and | ||
| - negative samples. | ||
| + num_samples: Deprecated. Use ``batch_size`` on the instance level | ||
| + instead. | ||
|
|
||
| Returns: | ||
| batch indices for the reference, positive and negative sample. | ||
|
|
||
| Note: | ||
| From version 0.7.0 onwards, specifying the ``num_samples`` | ||
| directly is deprecated and will be removed in version 0.8.0. | ||
|
||
| Please set ``batch_size`` and ``num_negatives`` on the instance | ||
| level instead. | ||
| """ | ||
| raise NotImplementedError() | ||
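
The fallback added to `__post_init__` above can be sketched in isolation. This is a minimal illustration that uses the standard-library `dataclasses` module instead of `literate_dataclasses`; `SketchLoader` is a hypothetical stand-in and not the actual CEBRA `Loader` class. It only mirrors the validation and the `num_negatives` default shown in the diff.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class SketchLoader:
    """Hypothetical stand-in for the loader; not part of CEBRA."""

    batch_size: Optional[int] = None
    num_negatives: Optional[int] = None

    def __post_init__(self):
        # Both sizes may be None, but must be strictly positive otherwise.
        if self.batch_size is not None and self.batch_size <= 0:
            raise ValueError(
                f"Batch size has to be None, or a positive value. "
                f"Got {self.batch_size}.")
        if self.num_negatives is not None and self.num_negatives <= 0:
            raise ValueError(
                f"Number of negatives has to be None, or a positive value. "
                f"Got {self.num_negatives}.")
        # Fall back to the batch size when no negative count is given,
        # which keeps the previous (coupled) behaviour as the default.
        if self.num_negatives is None:
            self.num_negatives = self.batch_size


coupled = SketchLoader(batch_size=512)
decoupled = SketchLoader(batch_size=512, num_negatives=2048)
print(coupled.num_negatives, decoupled.num_negatives)  # 512 2048
```

Under this sketch, existing callers that only pass `batch_size` see no change, while new callers can request more (or fewer) negatives than references.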
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| import abc | ||||||||||||||||||
| import warnings | ||||||||||||||||||
| from typing import Iterator | ||||||||||||||||||
|
|
||||||||||||||||||
| import literate_dataclasses as dataclasses | ||||||||||||||||||
| import torch | ||||||||||||||||||
|
|
@@ -138,7 +139,7 @@ def _init_distribution(self): | |||||||||||||||||
| f"Invalid choice of prior distribution. Got '{self.prior}', but " | ||||||||||||||||||
| f"only accept 'uniform' or 'empirical' as potential values.") | ||||||||||||||||||
|
|
||||||||||||||||||
| - def get_indices(self, num_samples: int) -> BatchIndex: | ||||||||||||||||||
| + def get_indices(self) -> BatchIndex: | ||||||||||||||||||
| """Samples indices for reference, positive and negative examples. | ||||||||||||||||||
|
|
||||||||||||||||||
| The reference samples will be sampled from the empirical or uniform prior | ||||||||||||||||||
|
|
@@ -154,13 +155,15 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||||||||||||||||
| Args: | ||||||||||||||||||
| num_samples: The number of samples (batch size) of the returned | ||||||||||||||||||
| :py:class:`cebra.data.datatypes.BatchIndex`. | ||||||||||||||||||
| num_negatives: The number of negative samples. If None, defaults to num_samples. | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns: | ||||||||||||||||||
| Indices for reference, positive and negatives samples. | ||||||||||||||||||
|
||||||||||||||||||
| Indices for reference, positive and negatives samples. | |
| The number of reference samples (batch size) and the number of negative samples | |
| are determined by the instance attributes ``batch_size`` and ``num_negatives``, respectively. | |
| Returns: | |
| Indices for reference, positive and negative samples. |
Copilot
AI
Aug 11, 2025
Similar to the previous comment, the docstring references 'num_samples' which no longer exists. This should be updated to reflect the current implementation.
| Indices for reference, positive and negatives samples. | |
| The number of reference samples (batch size) is determined by the | |
| instance's ``batch_size`` attribute. The number of negative samples | |
| is determined by the instance's ``num_negatives`` attribute. | |
| Returns: | |
| Indices for reference, positive and negative samples as a | |
| :py:class:`cebra.data.datatypes.BatchIndex`. |
Copilot
AI
Aug 11, 2025
Another instance where the docstring incorrectly references 'num_samples'. This should be updated to reflect that num_negatives is an instance attribute.
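
The behaviour these docstrings should describe, with the reference count taken from `batch_size` and the negative count from `num_negatives`, can be sketched as follows. Everything here is illustrative: `SketchBatchIndex` and `sketch_get_indices` are hypothetical names, sampling is uniform, and the "next time step" rule for positives is a placeholder rather than the conditional distribution the real loaders use.

```python
import random
from typing import List, NamedTuple, Optional


class SketchBatchIndex(NamedTuple):
    """Hypothetical stand-in for cebra.data.datatypes.BatchIndex."""

    reference: List[int]
    positive: List[int]
    negative: List[int]


def sketch_get_indices(num_items: int,
                       batch_size: int,
                       num_negatives: Optional[int] = None,
                       seed: Optional[int] = None) -> SketchBatchIndex:
    """Draw index lists whose sizes are decoupled (assumed uniform prior)."""
    rng = random.Random(seed)
    if num_negatives is None:
        num_negatives = batch_size  # same fallback as in the diff
    # One reference index per batch element.
    reference = [rng.randrange(num_items) for _ in range(batch_size)]
    # Placeholder rule: the positive is the next time step, clipped at the end.
    positive = [min(i + 1, num_items - 1) for i in reference]
    # Negatives are drawn independently, so their count can differ.
    negative = [rng.randrange(num_items) for _ in range(num_negatives)]
    return SketchBatchIndex(reference, positive, negative)


batch = sketch_get_indices(num_items=1000, batch_size=4, num_negatives=16, seed=0)
print(len(batch.reference), len(batch.positive), len(batch.negative))  # 4 4 16
```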
Copilot
AI
Aug 11, 2025
The condition check for `batch_size` in `FullDataLoader.__post_init__()` is incorrect. Since `FullDataLoader` inherits from `ContinuousDataLoader` and the base `Loader` class sets `batch_size = None` by default when not specified, this check will always be true when `batch_size` is explicitly set to `None` in the constructor call. The check should be `if self.batch_size is not None:` to properly validate that `batch_size` was not set.
The `get_indices` method signature change introduces a potential breaking change by making `num_samples` optional with a default of `None`. While the deprecation is documented, the method should handle the case where `num_samples` is passed but shouldn't be used, potentially issuing a deprecation warning to guide users toward the new API.
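
One way to act on this suggestion is sketched below, using the standard-library `warnings` module. `DeprecatingLoader` is a hypothetical class, not the PR's implementation; it returns only the two sizes so the example stays self-contained. The version numbers in the message are the ones quoted in the docstring note of the diff.

```python
import warnings
from typing import Optional, Tuple


class DeprecatingLoader:
    """Hypothetical loader used only to illustrate the warning."""

    def __init__(self, batch_size: int, num_negatives: Optional[int] = None):
        self.batch_size = batch_size
        self.num_negatives = batch_size if num_negatives is None else num_negatives

    def get_indices(self, num_samples: Optional[int] = None) -> Tuple[int, int]:
        if num_samples is not None:
            # The argument is accepted for backwards compatibility but ignored.
            warnings.warn(
                "Passing num_samples to get_indices() is deprecated since "
                "0.7.0 and will be removed in 0.8.0. Set batch_size and "
                "num_negatives on the loader instance instead.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller
            )
        # Sizes always come from the instance attributes.
        return self.batch_size, self.num_negatives


loader = DeprecatingLoader(batch_size=512, num_negatives=2048)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sizes = loader.get_indices(num_samples=64)
print(sizes, len(caught))  # (512, 2048) 1
```

Ignoring the passed value, rather than honouring it, is a design choice made here for illustration; whichever behaviour the PR settles on should be stated in the warning message.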