add per_batch_map usage in api comment

pull/7618/head
jonyguo 4 years ago
parent 45913d0682
commit f258697901

@ -302,6 +302,19 @@ class Dataset:
>>> # Create a dataset where every 100 rows is combined into a batch
>>> # and drops the last incomplete batch if there is one.
>>> data = data.batch(100, True)
>>>
>>> # resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25)
>>> def np_resize(col, batchInfo):
>>> output = col.copy()
>>> s = (batchInfo.get_batch_num() + 1) ** 2
>>> index = 0
>>> for c in col:
>>> img = Image.fromarray(c.astype('uint8')).convert('RGB')
>>> img = img.resize((s, s), Image.ANTIALIAS)
>>> output[index] = np.array(img)
>>> index += 1
>>> return (output,)
>>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
"""
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
output_columns, column_order, pad_info)

Loading…
Cancel
Save