From 23201ae0c31eeaff47abc823ca7d18427014c017 Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Fri, 5 Feb 2021 18:17:59 +0800 Subject: [PATCH] add data transfer st --- .../data_transfer/test_tdt_data_transfer.py | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 tests/st/data_transfer/test_tdt_data_transfer.py diff --git a/tests/st/data_transfer/test_tdt_data_transfer.py b/tests/st/data_transfer/test_tdt_data_transfer.py new file mode 100644 index 0000000000..8b50908df7 --- /dev/null +++ b/tests/st/data_transfer/test_tdt_data_transfer.py @@ -0,0 +1,166 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""System test: TDT (training-data-transfer) channel between the host dataset
pipeline and the Ascend device queue.

Covers two mismatch scenarios between produced batches and consumed steps:
 * consumer asks for more steps than the pipeline produces -> RuntimeError,
 * pipeline produces more batches than the consumer drains -> clean exit.
"""
import time
import numpy as np
import pytest
from mindspore import context, nn, Tensor
from mindspore import log as logger
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
import mindspore.dataset as de
from mindspore.dataset.vision import c_transforms as c_vision
from mindspore.dataset.transforms import c_transforms as c_trans


# Pre-provisioned CIFAR-10 verify split on the CI machine.
DATA_DIR = "/home/workspace/mindspore_dataset/cifar-10-verify-bin"


def dataset_cifar(dataset_path=None, batch_size=32, repeat_num=1, num_rows=9600, distribution_num=None, shard_id=None,
                  drop_remainder=True, usage=None, shuffle=False, num_workers=8, resize_size=32, pad_info=None):
    """Build the CIFAR-10 pipeline: typecast label, resize/rescale/normalize/
    HWC2CHW the image, then batch and repeat.

    Args:
        dataset_path: dataset root; defaults to DATA_DIR when None.
        batch_size: samples per batch.
        repeat_num: number of epochs to repeat the pipeline.
        num_rows: number of samples to read (num_samples of Cifar10Dataset).
        distribution_num / shard_id: sharding parameters for distributed runs.
        drop_remainder: drop the final partial batch when True.
        usage / shuffle / num_workers / resize_size / pad_info: forwarded to
            the corresponding dataset operations.

    Returns:
        The configured mindspore dataset object.
    """
    if dataset_path is None:
        dataset_path = DATA_DIR

    ds = de.Cifar10Dataset(dataset_path, num_samples=num_rows, num_shards=distribution_num, shard_id=shard_id,
                           shuffle=shuffle, usage=usage, num_parallel_workers=num_workers)

    # Labels come out of the reader as a wider int type; GetNext expects int32.
    typecast_op = c_trans.TypeCast(mstype.int32)
    ds = ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=num_workers)

    image_op_list = [c_vision.Resize(resize_size),
                     c_vision.Rescale(1.0 / 255.0, 0.0),
                     c_vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                     c_vision.HWC2CHW()]
    ds = ds.map(input_columns="image", operations=image_op_list, num_parallel_workers=num_workers)

    ds = ds.batch(batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_workers, pad_info=pad_info)
    ds = ds.repeat(repeat_num)

    return ds


def op_network_with_epoch(network, step_num):
    """Run `network` for `step_num` steps and return the iteration count.

    Each call to `network()` pulls one batch from the device queue via the
    GetNext op embedded in the network; a RuntimeError propagates to the
    caller when the queue underflows.
    """
    iter_num = 0
    network.set_train()
    for _ in range(step_num):
        op_return = network()
        op_return = op_return.asnumpy()
        logger.info("Op_return is : %s", op_return)
        iter_num += 1
        logger.info("Iter Num : %s", iter_num)

    return iter_num


def convert_type(shapes, types):
    """Map numpy dtypes (paired with shapes) to mindspore dtypes.

    Round-trips through a zero Tensor because that is the supported way to
    obtain the corresponding mstype for an arbitrary numpy dtype.
    """
    ms_types = []
    for np_shape, np_type in zip(shapes, types):
        input_np = np.zeros(np_shape, np_type)
        tensor = Tensor(input_np)
        ms_types.append(tensor.dtype)
    return ms_types


def get_dataset_base_value(dataset):
    """Return (dataset_size, batch_size) of `dataset`."""
    dataset_size = dataset.get_dataset_size()
    batch_size = dataset.get_batch_size()
    return dataset_size, batch_size


def dataset_send_tdt(dataset):
    """Start sending one epoch of `dataset` to the device queue.

    The short sleep gives the device-side consumer time to come up before
    data starts flowing.
    """
    time.sleep(1)
    dataset.send(1)


def get_dataset_shapes_and_types(dataset):
    """Return (output_shapes, mindspore output_types) of `dataset`."""
    dataset_shapes = dataset.output_shapes()
    np_types = dataset.output_types()
    dataset_types = convert_type(dataset_shapes, np_types)
    return dataset_shapes, dataset_types


class SingleOpNetwork(nn.Cell):
    """Minimal network: a single Reshape to the dataset's image shape."""

    def __init__(self, shapes):
        super(SingleOpNetwork, self).__init__()
        # shapes[0] is the image column's shape; Reshape needs a tuple.
        self.shapes = tuple(shapes[0])
        self.Op_Reshape_network = P.Reshape()

    def construct(self, network_input):
        return self.Op_Reshape_network(network_input, self.shapes)


class NetWithTDT(nn.Cell):
    """Wrap `network` with a GetNext op that feeds it from the TDT queue."""

    def __init__(self, network, dataset_types, dataset_shapes, shared_name=''):
        super(NetWithTDT, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_shapes), shared_name)
        self.Op_network = network

    def construct(self):
        # GetNext yields (image, label); only the image feeds the network.
        next_input, _ = self.get_next()
        return self.Op_network(next_input)


def op_network_with_step_num(dataset, step_num):
    """Wire `dataset` to the device queue and run a GetNext network.

    Returns the number of iterations completed; raises RuntimeError when
    `step_num` exceeds the number of batches the pipeline delivers.
    """
    dataset_shapes, dataset_types = get_dataset_shapes_and_types(dataset)
    _, batch_size = get_dataset_base_value(dataset)
    dataset = dataset.device_que()
    queue_name = dataset.queue_name

    net = SingleOpNetwork(dataset_shapes)
    net_with_dataset = NetWithTDT(net, dataset_types, dataset_shapes, queue_name)
    # On Davinci (Ascend) the net must contain a GetNext op before
    # init_dataset is called, otherwise the queue binding fails.
    _executor.init_dataset(queue_name, 1, batch_size, dataset_types, dataset_shapes, (), "")
    dataset_send_tdt(dataset)
    return op_network_with_epoch(net_with_dataset, step_num)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tdt_consume_beyond_produce():
    """Consuming more steps than the pipeline produces must raise RuntimeError."""
    context.set_context(mode=context.GRAPH_MODE)

    batch_size = 64
    repeat_num = 1
    num_rows = 640  # 640 / 64 = 10 batches available
    beyond_step_num = 1000  # far more steps than batches
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)

    with pytest.raises(RuntimeError) as err:
        iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
        logger.info("out_iter_num:%s", iter_num)
    logger.info("when dataset batch num is less than train loop, error msg is %s", err.value)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tdt_produce_beyond_consume():
    """Producing more batches than consumed must still complete all steps."""
    context.set_context(mode=context.GRAPH_MODE)

    batch_size = 64
    repeat_num = 1
    num_rows = 6400  # 6400 / 64 = 100 batches available
    beyond_step_num = 10  # consume only a fraction of them
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)

    iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
    logger.info("out_iter_num:%s", iter_num)
    assert iter_num == beyond_step_num