You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
114 lines
4.0 KiB
114 lines
4.0 KiB
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
import sys
|
|
import numpy as np
|
|
|
|
import mindspore.context as context
|
|
import mindspore.dataset as ds
|
|
import mindspore.dataset.vision.c_transforms as vision
|
|
import mindspore.nn as nn
|
|
from mindspore.common.api import _executor
|
|
from mindspore.common.tensor import Tensor
|
|
from mindspore.dataset.vision import Inter
|
|
from mindspore.ops import operations as P
|
|
|
|
# Run every Cell in graph mode on an Ascend device.
context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)

# Directory holding the TFRecord shard and its schema; it is taken from the
# first command-line argument (an IndexError is raised if none is given).
data_path = sys.argv[1]
SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)
|
|
|
|
|
|
def test_me_de_train_dataset():
    """Build the batched image pipeline used by the device-queue test.

    Reads one TFRecord shard under ``data_path`` using the schema at
    ``SCHEMA_DIR``. The ``image/encoded`` column is decoded, resized to
    224x224, rescaled and converted from HWC to CHW layout. The result is
    batched in groups of 32, and an incomplete final batch is dropped.

    Returns:
        The batched MindSpore dataset.
    """
    shard_files = ["{0}/train-00001-of-01024.data".format(data_path)]
    pipeline = ds.TFRecordDataset(shard_files, schema=SCHEMA_DIR,
                                  columns_list=["image/encoded", "image/class/label"])

    height, width = 224, 224
    scale, offset = 1.0 / 255.0, 0.0

    # Image transforms, each applied by its own map() call, in this order.
    image_ops = (
        vision.Decode(),
        vision.Resize((height, width), Inter.LINEAR),  # bilinear interpolation
        vision.Rescale(scale, offset),
        vision.HWC2CHW(),
    )
    for image_op in image_ops:
        pipeline = pipeline.map(operations=image_op, input_columns="image/encoded")

    pipeline = pipeline.repeat(1)
    return pipeline.batch(32, drop_remainder=True)
|
|
|
|
|
|
def convert_type(shapes, types):
    """Translate NumPy element types into the corresponding MindSpore dtypes.

    Each NumPy type is converted by wrapping a one-element zero array of that
    type in a ``Tensor`` and reading the tensor's ``dtype``. Only the element
    type affects the result, so there is no need to allocate an array of the
    full (batch-sized) column shape.

    Args:
        shapes: Per-column shapes. They only pair up with ``types``: as with
            ``zip``, the result has ``min(len(shapes), len(types))`` entries.
        types: Per-column NumPy dtypes.

    Returns:
        A list of MindSpore dtypes, one per (shape, type) pair.
    """
    return [Tensor(np.zeros(1, np_type)).dtype for _, np_type in zip(shapes, types)]
|
|
|
|
|
|
if __name__ == '__main__':
    data_set = test_me_de_train_dataset()
    # Number of batches in one pass; forwarded to init_dataset below instead
    # of a hard-coded count so the run matches whatever shard is supplied.
    dataset_size = data_set.get_dataset_size()
    batch_size = data_set.get_batch_size()

    dataset_shapes = data_set.output_shapes()
    np_types = data_set.output_types()
    dataset_types = convert_type(dataset_shapes, np_types)

    # Feed the pipeline into a device queue and build a GetNext op that reads
    # from that same queue. `2` is the number of outputs fetched per step,
    # matching the two-way unpacking in construct().
    ds1 = data_set.device_que()
    get_next = P.GetNext(dataset_types, dataset_shapes, 2, ds1.queue_name)
    relu = P.ReLU()

    class DataIter(nn.Cell):
        """Cell that fetches one batch via GetNext and applies ReLU to its first output."""

        def construct(self):
            input_, _ = get_next()
            return relu(input_)

    net = DataIter()
    net.set_train()

    _executor.init_dataset(ds1.queue_name, dataset_size, batch_size,
                           dataset_types, dataset_shapes, (), 'dataset')
    ds1.send()

    # Iterate the same pipeline on the host and compare each batch with what
    # the device-side graph returns for the corresponding step.
    for data in data_set.create_tuple_iterator(output_numpy=True, num_epochs=1):
        output = net()
        print(data[0].any())
        print(
            "****************************************************************************************************")
        d = output.asnumpy()
        print(d)
        print(
            "end+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
            d.any())

        # NOTE(review): exact equality presumably relies on the rescaled pixels
        # being non-negative, so that ReLU acts as the identity. Raised
        # explicitly (not `assert`) so the check still runs under `python -O`.
        if not (data[0] == d).all():
            raise AssertionError("TDT test execute failed, please check current code commit")

        print(
            "+++++++++++++++++++++++++++++++++++[INFO] Success+++++++++++++++++++++++++++++++++++++++++++")
|