pull/13757/head
mwang 4 years ago
parent 0f79635dd7
commit fdb7bbf422

@ -72,8 +72,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
type_cast_op = C2.TypeCast(mstype.int32) type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=24)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=24)
# apply batch operations # apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True) data_set = data_set.batch(batch_size, drop_remainder=True)

Loading…
Cancel
Save