Description
🐛 Bug
The `__len__()` method of a `BatchSplittingSampler` that wraps a `DPDataLoader` is meant to return the number of physical (as opposed to logical) batches in its iterator. Because Poisson sampling produces variable-length logical batches, this length is necessarily approximate and will vary between runs. However, the approximation implemented in `__len__()` is inaccurate:
```python
expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
```
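Since $\lceil B/m \rceil \geq B/m$ for any logical batch size $B$ and physical batch size limit $m$, the per-logical-batch factor `expected_batch_size / self.max_batch_size` systematically underestimates the expected number of physical batches. A quick Monte Carlo check illustrates the gap (the parameter values below are illustrative, not taken from the notebook):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative values: dataset size, Poisson sample rate, max physical batch size
n, p, m = 50_000, 0.001, 32

# Per-logical-batch factor used by the current __len__(): E[B] / m
naive = n * p / m

# Monte Carlo estimate of E[ceil(B / m)] with B ~ Binomial(n, p),
# i.e. the actual expected number of physical batches per logical batch
batch_sizes = rng.binomial(n, p, size=100_000)
actual = np.ceil(batch_sizes / m).mean()

print(f"naive = {naive:.3f}, actual ≈ {actual:.3f}")
```

For these parameters the naive factor is about 1.56 while the simulated expectation is noticeably larger, so the reported length undercounts physical batches.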
The actual expected number of physical batches per logical batch is:

$$\mathbb{E}\!\left[\left\lceil \frac{B}{m} \right\rceil\right] = \sum_{k \geq 1} k \left( F(km) - F((k-1)m) \right),$$

where $m$ is `self.max_batch_size`, $B \sim \mathrm{Binomial}(n, p)$ is the logical batch size with $n =$ `self.sampler.num_samples` trials and $p =$ `self.sampler.sample_rate` success probability, and $F$ is the CDF of $B$.
This can be approximated as, e.g.,

```python
from scipy.stats import binom

def F(k):
    return binom(self.sampler.num_samples, self.sampler.sample_rate).cdf(k * self.max_batch_size) - \
        binom(self.sampler.num_samples, self.sampler.sample_rate).cdf((k - 1) * self.max_batch_size)

expected_physical_batches = int(self.sampler.num_samples * self.sampler.sample_rate / self.max_batch_size)
return int(
    len(self.sampler) *
    sum(i * F(i) for i in range(expected_physical_batches - 4, expected_physical_batches + 4))
)
```
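The same estimate can be packaged as a standalone function and checked against simulation. Here `n`, `p`, and `m` stand in for the sampler attributes, and the ±4 summation window around the mean is clamped at 1 (terms with $k \leq 0$ contribute nothing anyway):

```python
import numpy as np
from scipy.stats import binom

def expected_physical_batches_per_logical(n: int, p: float, m: int) -> float:
    """Approximate E[ceil(B / m)] for B ~ Binomial(n, p) by summing
    k * P(physical batch count == k) over a window around the mean."""
    dist = binom(n, p)
    center = int(n * p / m)
    lo = max(1, center - 4)
    return sum(
        k * (dist.cdf(k * m) - dist.cdf((k - 1) * m))
        for k in range(lo, center + 5)
    )

# Illustrative parameters: dataset size, sample rate, max physical batch size
n, p, m = 50_000, 0.001, 32
analytic = expected_physical_batches_per_logical(n, p, m)

# Cross-check against a direct Monte Carlo simulation
rng = np.random.default_rng(0)
simulated = np.ceil(rng.binomial(n, p, size=100_000) / m).mean()

print(f"analytic ≈ {analytic:.3f}, simulated ≈ {simulated:.3f}")
```

The two estimates agree closely, and both exceed the naive `n * p / m` factor used by the current implementation.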
Please reproduce using our template Colab and post here the link
Here's a notebook built from the Colab template showing the discrepancy between computed and actual lengths:
https://gist.github.com/s-zanella/b70308db3d6d1b1bf15a5a2c8a1cc525
Expected behavior
It's unclear what the desired behavior should be. The length approximation currently implemented is clearly incorrect, but a better approximation doesn't help much, because the length of a `BatchSplittingSampler` with Poisson sampling is not fixed. It would be nice to at least warn that the returned length is approximate.
From a user's point of view, if `BatchMemoryManager` is to be a transparent abstraction, I do not care so much about the number of physical batches processed, but about the number of logical batches. The current abstraction does not signal the beginning/end of logical batches, which makes it hard (impossible without code introspection?) to keep track of the number of logical batches processed so far. Having a mechanism to signal the beginning/end of a logical batch would solve this issue.