|
|
|
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
import numpy as np
|
|
|
|
|
import sys
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
# Path to the dataset schema used by the TFRecord reader below.
SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)
|
|
|
|
|
|
|
|
|
|
def test_me_de_train_dataset():
    """Build and return the decoded/resized/rescaled, CHW, batched train dataset.

    Reads one TFRecord shard, applies Decode -> Resize -> Rescale -> HWC2CHW,
    repeats once, and batches with drop_remainder so every batch is full.

    Returns:
        A batched ``mindspore.dataset`` pipeline yielding
        ("image/encoded", "image/class/label") columns.
    """
    data_list = ["{0}/train-00001-of-01024.data".format(data_path)]
    data_set_new = ds.TFRecordDataset(data_list, schema=SCHEMA_DIR,
                                      columns_list=["image/encoded", "image/class/label"])

    resize_height = 224
    resize_width = 224
    # NOTE(review): the normalization constants were lost in the garbled diff;
    # 1/255 scale with zero shift is the standard [0, 1] rescale — confirm
    # against the original file.
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    decode_op = vision.Decode()
    # Resize expects the target size as a single (height, width) tuple.
    resize_op = vision.Resize((resize_height, resize_width),
                              Inter.LINEAR)  # Bilinear as default
    rescale_op = vision.Rescale(rescale, shift)

    # apply map operations on images
    data_set_new = data_set_new.map(input_columns="image/encoded", operations=decode_op)
    data_set_new = data_set_new.map(input_columns="image/encoded", operations=resize_op)
    data_set_new = data_set_new.map(input_columns="image/encoded", operations=rescale_op)
    hwc2chw_op = vision.HWC2CHW()
    data_set_new = data_set_new.map(input_columns="image/encoded", operations=hwc2chw_op)
    data_set_new = data_set_new.repeat(1)

    # apply batch operations; drop_remainder keeps only full batches of 32
    batch_size_new = 32
    data_set_new = data_set_new.batch(batch_size_new, drop_remainder=True)
    return data_set_new
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_type(shapes, types):
|
|
|
|
|