@@ -261,27 +261,47 @@ class MixedDataLoader(cebra_data.Loader):
261261 1. Positive pairs always share their discrete variable.
262262 2. Positive pairs are drawn only based on their conditional,
263263 not discrete variable.
264+
265+ Args:
266+ conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
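267+ time_offset (int): Time offset passed as ``time_delta`` to the mixed time-delta distribution (used when ``positive_sampling`` is "conditional"). See :py:attr:`cebra.CEBRA.time_offsets`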
268+ positive_sampling (str): either "discrete_variable" (default) or "conditional"
268+ discrete_sampling_prior (str): either "empirical" or "uniform" (default). Only used when ``positive_sampling`` is "discrete_variable".
264270 """
265271
266272 conditional : str = dataclasses .field (default = "time_delta" )
267273 time_offset : int = dataclasses .field (default = 10 )
274+ positive_sampling : str = dataclasses .field (default = "discrete_variable" )
275+ discrete_sampling_prior : str = dataclasses .field (default = "uniform" )
268276
@property
def discrete_index(self):
    """Return the discrete index exposed by the underlying dataset."""
    dataset = self.dataset
    return dataset.discrete_index
273280
@property
def continuous_index(self):
    """Return the continuous index exposed by the underlying dataset."""
    dataset = self.dataset
    return dataset.continuous_index
278284
def __post_init__(self):
    """Validate the sampling options and build ``self.distribution``.

    The distribution depends on ``positive_sampling``:

    - ``"conditional"``: a ``MixedTimeDeltaDistribution`` built from the
      discrete index, the continuous index and ``time_offset``.
    - ``"discrete_variable"``: a ``DiscreteEmpirical`` or
      ``DiscreteUniform`` distribution over the discrete index, selected
      by ``discrete_sampling_prior``.

    Raises:
        ValueError: If ``positive_sampling`` is neither ``"conditional"``
            nor ``"discrete_variable"``, or if ``positive_sampling`` is
            ``"discrete_variable"`` and ``discrete_sampling_prior`` is
            neither ``"empirical"`` nor ``"uniform"``.
    """
    super().__post_init__()

    if self.positive_sampling == "conditional":
        # The prior setting is not consulted in this mode.
        self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
            discrete=self.discrete_index,
            continuous=self.continuous_index,
            time_delta=self.time_offset)
        return

    if self.positive_sampling != "discrete_variable":
        # Quote the offending value so it is visible even when it is empty
        # or contains whitespace.
        raise ValueError(
            f"Invalid positive sampling mode: '{self.positive_sampling}'. "
            "Valid options are 'conditional' or 'discrete_variable'.")

    # From here on, positives are drawn based on the discrete variable only.
    if self.discrete_sampling_prior == "empirical":
        self.distribution = cebra.distributions.DiscreteEmpirical(
            self.discrete_index)
    elif self.discrete_sampling_prior == "uniform":
        self.distribution = cebra.distributions.DiscreteUniform(
            self.discrete_index)
    else:
        raise ValueError(
            f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but "
            "only accept 'uniform' or 'empirical' as potential values.")
285305
def get_indices(self, num_samples: int) -> BatchIndex:
    """Samples indices for reference, positive and negative examples.

    The sampling scheme depends on ``positive_sampling``:

    - ``"conditional"``: reference and negative indices are two
      independent draws of ``num_samples`` from the prior; positives are
      drawn from the conditional given the reference indices.
    - otherwise: a single draw of ``2 * num_samples`` from the prior is
      split into reference (first half) and negative (second half)
      indices; positives are drawn from the conditional given the
      discrete values at the reference indices.

    Args:
        num_samples: Number of reference samples to draw.

    Returns:
        A ``BatchIndex`` holding the reference, positive and negative
        indices.
    """
    sampler = self.distribution

    if self.positive_sampling != "conditional":
        # Same scheme as the DiscreteDataLoader get_indices function.
        drawn = sampler.sample_prior(num_samples * 2)
        anchor_idx, negative_idx = drawn[:num_samples], drawn[num_samples:]
        anchor_labels = self.discrete_index[anchor_idx]
        positive_idx = sampler.sample_conditional(anchor_labels)
        return BatchIndex(reference=anchor_idx,
                          positive=positive_idx,
                          negative=negative_idx)

    # Draw order (reference, negative, positive) is kept as-is so that the
    # sequence of sampler calls is unchanged.
    anchor_idx = sampler.sample_prior(num_samples)
    negative_idx = sampler.sample_prior(num_samples)
    positive_idx = sampler.sample_conditional(anchor_idx)
    return BatchIndex(
        reference=anchor_idx,
        negative=negative_idx,
        positive=positive_idx,
    )
315346
316347
317348@dataclasses .dataclass
0 commit comments