|
|
|
@ -302,6 +302,19 @@ class Dataset:
|
|
|
|
|
>>> # Create a dataset where every 100 rows is combined into a batch
|
|
|
|
|
>>> # and drops the last incomplete batch if there is one.
|
|
|
|
|
>>> data = data.batch(100, True)
|
|
|
|
|
>>>
|
|
|
|
|
>>> # resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25)
|
|
|
|
|
>>> def np_resize(col, batchInfo):
|
|
|
|
|
>>> output = col.copy()
|
|
|
|
|
>>> s = (batchInfo.get_batch_num() + 1) ** 2
|
|
|
|
|
>>> index = 0
|
|
|
|
|
>>> for c in col:
|
|
|
|
|
>>> img = Image.fromarray(c.astype('uint8')).convert('RGB')
|
|
|
|
|
>>> img = img.resize((s, s), Image.ANTIALIAS)
|
|
|
|
|
>>> output[index] = np.array(img)
|
|
|
|
|
>>> index += 1
|
|
|
|
|
>>> return (output,)
|
|
|
|
|
>>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
|
|
|
|
|
"""
|
|
|
|
|
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
|
|
|
|
|
output_columns, column_order, pad_info)
|
|
|
|
|