Hi,
I was looking at the implementation of MPerClassSampler, and I noticed the following issue: consecutive batches often overlap in the classes they use. For example, with batch_size=16 and m=4, the first batch might consist of classes [1, 5, 3, 7], while the second might be [1, 9, 8, 2]. In small datasets, this means examples from class 1 could end up being seen more often than examples from other classes.
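A quick simulation (not the actual MPerClassSampler code) illustrates the imbalance when each batch's classes are drawn independently:

```python
import numpy as np

# Simulation only: draw each batch's classes independently, then count
# how often each class appears over one epoch. The names and sizes here
# are illustrative, matching the example above (batch_size=16, m=4,
# i.e. 4 distinct classes per batch).
rng = np.random.default_rng(0)
num_classes, classes_per_batch, num_batches = 10, 4, 5

batches = [rng.choice(num_classes, size=classes_per_batch, replace=False)
           for _ in range(num_batches)]
counts = np.bincount(np.concatenate(batches), minlength=num_classes)
print(counts)  # per-class counts over the epoch; typically uneven
```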
I think this can be easily overcome by generating ((length_before_new_iter // batch_size) * m) // num_unique_labels + 1 arrays of the unique labels, shuffling each of them independently, and then concatenating them. The sampler can then take the labels at positions i*m to (i+1)*m for batch i, and be certain that after one epoch, examples from any given class have been seen either (length_before_new_iter // batch_size) * batch_size // num_unique_labels or ((length_before_new_iter // batch_size) * batch_size // num_unique_labels) + 1 times, minimizing the initial issue.
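A minimal sketch of this scheme, assuming NumPy; the function name and signature are illustrative, not part of pytorch-metric-learning:

```python
import numpy as np

def balanced_label_sequence(unique_labels, m, num_batches, rng=None):
    """Hypothetical sketch of the proposed fix: concatenate enough
    independently shuffled permutations of the unique labels to cover
    num_batches * m label slots, then slice off m labels per batch.
    Over one epoch, per-class counts then differ by at most one."""
    rng = np.random.default_rng() if rng is None else rng
    unique_labels = np.asarray(unique_labels)
    # Number of shuffled copies needed to cover all label slots,
    # mirroring the formula in the text above.
    num_copies = (num_batches * m) // len(unique_labels) + 1
    seq = np.concatenate([rng.permutation(unique_labels)
                          for _ in range(num_copies)])
    # Batch i draws its classes from positions i*m to (i+1)*m.
    return [seq[i * m:(i + 1) * m] for i in range(num_batches)]
```

For example, with 10 unique labels, m=4, and 5 batches, all 20 label slots are filled by exactly two full permutations, so every class is used exactly twice per epoch.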
I'm pretty certain the difference in performance would be minimal, if any. Does this make sense?