Length of BatchSplittingSampler with Poisson sampling #516

@s-zanella

Description

🐛 Bug

The `__len__()` method of a `BatchSplittingSampler` that wraps a `DPDataLoader` is meant to return the number of physical (as opposed to logical) batches in its iterator. Because Poisson sampling produces variable-length logical batches, this length is necessarily approximate and will vary between runs. However, the approximation implemented in `__len__()` is inaccurate:

```python
expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
```
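To see why this linear estimate undercounts, note that a logical batch of size $B$ is split into $\lceil B / m \rceil$ physical batches, and the ceiling is taken per batch, so the mean number of physical batches exceeds $\mathbb{E}[B] / m$. A quick Monte Carlo check illustrates the gap (the parameters below are hypothetical, not taken from this issue):

```python
import math
import random

random.seed(0)

# Hypothetical parameters for illustration (not from the issue).
num_samples, sample_rate, max_batch_size = 200, 0.1, 8

# Naive estimate from __len__: expected logical batch size over m.
naive = sample_rate * num_samples / max_batch_size  # 2.5

# Monte Carlo: draw logical batch sizes B ~ Binomial(n, p) via Poisson
# sampling and count ceil(B / m) physical batches per logical batch.
trials = 20_000
total = 0
for _ in range(trials):
    b = sum(random.random() < sample_rate for _ in range(num_samples))
    total += math.ceil(b / max_batch_size)
simulated = total / trials

print(f"naive={naive:.2f}, simulated={simulated:.2f}")
```

The simulated mean lands noticeably above the naive 2.5, confirming the bias.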

The actual expected number of physical batches per logical batch is:

$$ \sum_{k=1}^\infty k\ \left( F(k\ m) - F((k - 1)\ m) \right) $$

where $m$ is the maximum physical batch size (`self.max_batch_size`) and $F$ is the CDF of the binomial distribution with `self.sampler.num_samples` trials and success probability `self.sampler.sample_rate`. The $k$-th term weights by $k$ the probability that a logical batch splits into exactly $k$ physical batches, i.e., that its size falls in $((k-1)\,m, k\,m]$.

This can be approximated as, e.g.:

```python
from scipy.stats import binom

# Distribution of the logical batch size B ~ Binomial(n, p).
dist = binom(self.sampler.num_samples, self.sampler.sample_rate)

def prob_splits_into(k):
    # P(logical batch splits into exactly k physical batches),
    # i.e. F(k m) - F((k - 1) m) with F the binomial CDF.
    return dist.cdf(k * self.max_batch_size) - dist.cdf((k - 1) * self.max_batch_size)

# Truncate the infinite sum to a window around the naive estimate E[B] / m.
center = int(self.sampler.num_samples * self.sampler.sample_rate / self.max_batch_size)

return int(
    len(self.sampler)
    * sum(k * prob_splits_into(k) for k in range(max(center - 4, 1), center + 5))
)
```
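As a standalone sanity check (with hypothetical parameters, and a pure-Python binomial CDF in place of scipy), the sum can be evaluated over the full support and compared against the naive linear estimate:

```python
import math

# Hypothetical parameters (not from the issue).
n, p, m = 200, 0.1, 8  # num_samples, sample_rate, max_batch_size

def binom_cdf(x: int, n: int, p: float) -> float:
    # Exact CDF of Binomial(n, p) at integer x.
    return sum(math.comb(n, k) * p**k * (1 - p) ** (n - k) for k in range(x + 1))

# E[physical batches per logical batch]
#   = sum_{k >= 1} k * (F(k m) - F((k - 1) m)),
# summed over the full support so no truncation error remains.
expected = sum(
    k * (binom_cdf(k * m, n, p) - binom_cdf((k - 1) * m, n, p))
    for k in range(1, n // m + 2)
)

naive = n * p / m  # 2.5
print(f"naive={naive:.2f}, expected={expected:.2f}")
```

With these numbers the exact expectation is close to 2.9, against a naive estimate of 2.5, so the current `__len__()` systematically underestimates the number of physical batches.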

Please reproduce using our template Colab and post here the link

Here's a notebook built from the Colab template showing the discrepancy between computed and actual lengths:
https://gist.github.com/s-zanella/b70308db3d6d1b1bf15a5a2c8a1cc525

Expected behavior

It's unclear what the desired behavior is. The length approximation currently implemented is clearly incorrect, but even a better approximation only helps so much, because the length of a `BatchSplittingSampler` with Poisson sampling is not fixed. It would be nice to at least warn that the returned length is approximate.

From a user's point of view, if `BatchMemoryManager` is to be a transparent abstraction, I do not care so much about the number of physical batches processed as about the number of logical batches. The current abstraction does not signal the beginning/end of logical batches, which makes it hard (impossible without code introspection?) to keep track of the number of logical batches processed so far. Having a mechanism to signal the beginning/end of a logical batch would solve this issue.
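One possible shape for such a signal (a hypothetical sketch, not an existing Opacus API) is for the splitting iterator to flag the last physical batch of each logical batch, so a training loop can count logical batches without introspection:

```python
from typing import Iterable, Iterator, List, Tuple

def split_with_boundaries(
    logical_batches: Iterable[List[int]], max_batch_size: int
) -> Iterator[Tuple[List[int], bool]]:
    """Yield (physical_batch, is_last) pairs, where is_last marks the
    final physical chunk of the current logical batch."""
    for batch in logical_batches:
        chunks = [
            batch[i : i + max_batch_size]
            for i in range(0, len(batch), max_batch_size)
        ] or [[]]  # an empty Poisson sample still yields one (empty) batch
        for i, chunk in enumerate(chunks):
            yield chunk, i == len(chunks) - 1

# A training loop can then count logical batches by counting is_last flags:
logical = [[1, 2, 3, 4, 5], [6, 7], []]
seen_logical = sum(is_last for _, is_last in split_with_boundaries(logical, 2))
print(seen_logical)  # 3
```

The same flag would also let user code defer per-logical-batch bookkeeping (metrics, step counters) until the logical batch is actually complete.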

Labels: enhancement (New feature or request)