How To Sample Batch From Only One Class At Each Iteration

I'd like to train a classifier on the ImageNet dataset (1000 classes, each with around 1300 images). For some reason, I need each batch to contain 64 images from the same class, and

Solution 1:

If I understand your solution correctly, I don't think it can work, because tf.data.experimental.sample_from_datasets expects a list of values for its weights, not a Tensor.

However, if you don't mind having 1000 Datasets as in your proposed solution, I would suggest simply to:

  • create one Dataset per class;
  • batch each of these datasets, so that every batch contains samples from a single class;
  • combine all of them into one big Dataset of batches (for example with Dataset.concatenate);
  • shuffle this Dataset, with a buffer large enough to mix the classes: the shuffling operates on whole batches, not on individual samples, so every batch stays single-class.
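To make the four steps concrete, here is a framework-free sketch in plain Python, not tf.data: lists stand in for the per-class Datasets, and the helper name single_class_batches is hypothetical. It drops incomplete trailing batches, which is an assumption on my part (in tf.data the analogue would be batch(batch_size, drop_remainder=True)).

```python
import random
from collections import defaultdict

def single_class_batches(samples, batch_size, seed=0):
    """Hypothetical helper: split by class, batch each class,
    combine all batches, then shuffle at batch granularity."""
    # Step 1: one "dataset" (here a plain list) per class.
    per_class = defaultdict(list)
    for x, label in samples:
        per_class[label].append((x, label))

    # Step 2: batch each class separately; the incomplete tail batch
    # is dropped (assumption), so every batch is full and single-class.
    batches = []
    for label in sorted(per_class):
        items = per_class[label]
        for start in range(0, len(items) - batch_size + 1, batch_size):
            # Step 3: all batches end up in one combined list.
            batches.append(items[start:start + batch_size])

    # Step 4: shuffle whole batches; the class order changes,
    # but the contents of each batch are untouched.
    random.Random(seed).shuffle(batches)
    return batches

# Usage: 30 samples, 10 for each of 3 classes.
data = [(i, i % 3) for i in range(30)]
batches = single_class_batches(data, batch_size=4)
print(len(batches))  # 6: two full batches per class, remainders dropped
```

Because the shuffle permutes whole batches, it can only change which class comes next; it never mixes classes inside a batch.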

A more sophisticated way is to rely on tf.data.experimental.group_by_window. Let me illustrate that with a synthetic example.

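# TF 1.x API (tf.InteractiveSession, make_one_shot_iterator)
import numpy as np
import tensorflow as tf

def gen():
  # Infinite stream of (scalar feature, label in [0, 10)) pairs
  while True:
    x = np.random.normal()
    label = np.random.randint(10)
    yield x, label

batch_size = 4
batch = (tf.data.Dataset
  .from_generator(gen, (tf.float32, tf.int64), (tf.TensorShape([]), tf.TensorShape([])))
  # Group elements by label; each window of batch_size same-label elements becomes one batch
  .apply(tf.data.experimental.group_by_window(
    key_func=lambda x, label: label,
    reduce_func=lambda key, d: d.batch(batch_size),
    window_size=batch_size))
  .make_one_shot_iterator()
  .get_next())

sess = tf.InteractiveSession()
sess.run(batch)
# (array([ 0.04058843,  0.2843775 , -1.8626076 ,  1.1154234 ], dtype=float32),
# array([6, 6, 6, 6], dtype=int64))
sess.run(batch)
# (array([ 1.3600663,  0.5935658, -0.6740045,  1.174328 ], dtype=float32),
# array([3, 3, 3, 3], dtype=int64))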
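The key property of group_by_window can be modelled in a few lines of plain Python. This is a simplified sketch of the semantics as I understand them, not TensorFlow's implementation; in particular, flushing partial windows at the end of a finite input (and the order in which they are flushed) is an assumption of this model.

```python
from collections import defaultdict

def group_by_window(stream, key_func, window_size):
    """Plain-Python model (not the tf.data op): buffer elements per key and
    emit a window as soon as that key has collected window_size elements."""
    windows = defaultdict(list)
    for element in stream:
        key = key_func(element)
        windows[key].append(element)
        if len(windows[key]) == window_size:
            # Window is full: hand it to the "reduce" step and start a new one.
            yield key, windows.pop(key)
    # End of a finite input: flush the partially filled windows (assumption).
    for key, window in windows.items():
        yield key, window

stream = [(0.1, 6), (0.2, 3), (0.3, 6), (0.4, 6), (0.5, 3), (0.6, 6)]
for key, window in group_by_window(stream, key_func=lambda e: e[1], window_size=4):
    print(key, [x for x, _ in window])
# 6 [0.1, 0.3, 0.4, 0.6]
# 3 [0.2, 0.5]
```

With the infinite generator above, no window is ever flushed early, so every batch has exactly batch_size elements; on a finite dataset, the last batch of each class may be smaller.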

Solution 2:

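P-Gn's solution of creating a separate dataset for each class is probably optimal. However, the separate datasets can be avoided by bucketing the elements by label with tf.data.experimental.bucket_by_sequence_length, as follows: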

import tensorflow as tf

# Init some dataset: one element per class, features = 10 * label + small noise
num_classes = 10
label = tf.range(num_classes, dtype=tf.int32)
features = tf.cast(label * 10, dtype=tf.float32) + tf.random.uniform(shape=[tf.shape(label)[0]], maxval=0.01)
dataset = tf.data.Dataset.from_tensor_slices({'label': label, 'features': features})
dataset = dataset.repeat()

# Split to buckets by label: the "sequence length" is the label itself,
# so the boundaries 1..num_classes-1 give exactly one bucket per class
batch_size = 4
dataset = dataset.apply(tf.data.experimental.bucket_by_sequence_length(
    element_length_func=lambda s: s['label'],
    bucket_boundaries=list(range(1, num_classes)),
    bucket_batch_sizes=[batch_size] * num_classes,
))

# Show result
iterator = dataset.as_numpy_iterator()
for i in range(5):
    print(next(iterator))

# {'label': array([0, 0, 0, 0], dtype=int32), 'features': array([0.00370963, 0.00370963, 0.00370963, 0.00370963], dtype=float32)}
# {'label': array([1, 1, 1, 1], dtype=int32), 'features': array([10.009371, 10.009371, 10.009371, 10.009371], dtype=float32)}
# {'label': array([2, 2, 2, 2], dtype=int32), 'features': array([20.001854, 20.001854, 20.001854, 20.001854], dtype=float32)}
# {'label': array([3, 3, 3, 3], dtype=int32), 'features': array([30.005934, 30.005934, 30.005934, 30.005934], dtype=float32)}
# {'label': array([4, 4, 4, 4], dtype=int32), 'features': array([40.001686, 40.001686, 40.001686, 40.001686], dtype=float32)}
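Why do these arguments produce one bucket per class? A plain-Python sketch of the bucket lookup helps, assuming that bucket_by_sequence_length places a length in bucket i when bucket_boundaries[i-1] <= length < bucket_boundaries[i], with open ends for the first and last bucket. The helper bucket_id is hypothetical; it only mirrors that rule using bisect.

```python
import bisect

def bucket_id(length, bucket_boundaries):
    """Hypothetical helper: index i of the bucket with
    boundaries[i-1] <= length < boundaries[i] (open-ended at both extremes)."""
    # bisect_right counts the boundaries that are <= length
    return bisect.bisect_right(bucket_boundaries, length)

num_classes = 10
boundaries = list(range(1, num_classes))   # [1, 2, ..., 9]
ids = [bucket_id(label, boundaries) for label in range(num_classes)]
print(ids)                 # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: label i -> bucket i
print(len(boundaries) + 1) # 10 buckets in total, one per class
```

Since there are len(bucket_boundaries) + 1 buckets, bucket_batch_sizes needs num_classes entries, which is why the code passes [batch_size] * num_classes.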

