From 4fa6238dde8622a2d0e5eaa067df41e4f8fef525 Mon Sep 17 00:00:00 2001
From: tom__chen
Date: Mon, 16 Nov 2020 14:52:51 -0500
Subject: [PATCH] change Split op to extend PrimitiveWithCheck

---
 .../gpu/arrays/split_gpu_kernel.h          |  4 -
 mindspore/core/abstract/infer_functions.h  |  2 +
 mindspore/core/abstract/prim_arrays.cc     | 43 ++++++++++
 .../core/abstract/primitive_infer_map.cc   |  1 +
 mindspore/core/base/core_ops.h             |  1 +
 mindspore/ops/operations/array_ops.py      | 28 ++-----
 tests/st/ops/gpu/test_split.py             | 82 +++++++++++++++++++
 7 files changed, 134 insertions(+), 27 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
index a487e5febc..a36c75c6b5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
@@ -132,10 +132,6 @@ class SplitGpuFwdKernel : public GpuKernel {
       MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_];
       return false;
     }
-    if (input_shape[axis_] % output_num_ != 0) {
-      MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_];
-      return false;
-    }
     if (output_num_ != output_num) {
       MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
       return false;
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 4ac2b659ba..98ddb7eae6 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -259,6 +259,8 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
                                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index 1de54a9814..d2a62d9b77 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -695,5 +695,48 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
   auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
   return ret;
 }
+
+AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  ShapeVector x_shape = input_x->shape()->shape();
+  ShapeVector x_shape_min = input_x->shape()->min_shape();
+  if (x_shape_min.empty()) {
+    x_shape_min = x_shape;
+  }
+  ShapeVector x_shape_max = input_x->shape()->max_shape();
+  if (x_shape_max.empty()) {
+    x_shape_max = x_shape;
+  }
+  int64_t rank = SizeToLong(x_shape.size());
+
+  ValuePtr axis = primitive->GetAttr("axis");
+  int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
+  axis_value = GetPositiveAxis(axis_value, LongToSize(rank));
+  int64_t output_num_value = primitive->GetAttr("output_num")->cast<Int64ImmPtr>()->value();
+  if ((x_shape[axis_value] != Shape::SHP_ANY) && (x_shape[axis_value] % output_num_value != 0)) {
+    MS_LOG(EXCEPTION) << "x_shape[" << axis_value << "] = " << x_shape[axis_value]
+                      << " must be divisible by output_num = " << output_num_value;
+  }
+
+  ShapeVector output_shape = x_shape;
+  if (output_shape[axis_value] != Shape::SHP_ANY) {
+    output_shape[axis_value] = static_cast<int64_t>(x_shape[axis_value] / output_num_value);
+  }
+  ShapeVector output_shape_min = x_shape_min;
+  output_shape_min[axis_value] = static_cast<int64_t>(x_shape_min[axis_value] / output_num_value);
+  ShapeVector output_shape_max = x_shape_max;
+  output_shape_max[axis_value] = static_cast<int64_t>(x_shape_max[axis_value] / output_num_value);
+
+  AbstractBasePtrList output_list;
+  for (int64_t i = 0; i < output_num_value; ++i) {
+    auto output = input_x->Broaden();
+    output->set_shape(std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max));
+    output_list.push_back(output);
+  }
+  return std::make_shared<AbstractTuple>(output_list);
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index eeb21b9fb0..e083383bb3 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -69,6 +69,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
     {prim::kPrimTranspose, {InferImplTranspose, true}},
     {prim::kPrimReshape, {InferImplReshape, true}},
+    {prim::kPrimSplit, {InferImplSplit, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 32b699d785..2ec7b17a4f 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -118,6 +118,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared("Dynam
 inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad");
 inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
 inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
+inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");

 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index a20bc73b40..c82273a981 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -873,18 +873,17 @@ class UniqueWithPad(PrimitiveWithInfer):
         return out


-class Split(PrimitiveWithInfer):
+class Split(PrimitiveWithCheck):
     """
     Splits the input tensor into output_num of tensors along the given axis and output numbers.

     Args:
         axis (int): Index of the split position. Default: 0.
-        output_num (int): The number of output tensors. Default: 1.
+        output_num (int): The number of output tensors. Must be a positive int. Default: 1.

     Raises:
         ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
-            or if the `output_num` is less than or equal to 0, or if the
-            dimension which to split cannot be evenly divided by `output_num`.
+            or if the `output_num` is less than or equal to 0.

     Inputs:
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -913,32 +912,15 @@
         """Initialize Split"""
         validator.check_value_type("axis", axis, [int], self.name)
         validator.check_value_type("output_num", output_num, [int], self.name)
+        validator.check_positive_int(output_num, "output_num", self.name)
         self.axis = axis
         self.output_num = output_num

-    def __infer__(self, x):
+    def __check__(self, x):
         validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         dim = len(x_shape)
         validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
-        validator.check_positive_int(self.output_num, "output_num", self.name)
-        output_valid_check = x_shape[self.axis] % self.output_num
-        if output_valid_check != 0:
-            raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
-                             f" output_num {self.output_num}")
-
-        x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
-        out_shapes = []
-        out_dtypes = []
-        for _ in range(self.output_num):
-            out_shapes.append(tuple(x_shape))
-            out_dtypes.append(x['dtype'])
-        out_shapes = tuple(out_shapes)
-        out_dtypes = tuple(out_dtypes)
-        out = {'shape': out_shapes,
-               'dtype': out_dtypes,
-               'value': None}
-        return out


 class Rank(PrimitiveWithInfer):
diff --git a/tests/st/ops/gpu/test_split.py b/tests/st/ops/gpu/test_split.py
index f9e3cfce2f..c403786ad5 100644
--- a/tests/st/ops/gpu/test_split.py
+++ b/tests/st/ops/gpu/test_split.py
@@ -18,6 +18,7 @@ import pytest
 import mindspore.context as context
 from mindspore import Tensor
 import mindspore.nn as nn
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops import operations as P


@@ -30,6 +31,18 @@ class Net(nn.Cell):
         return self.split(x)


+class NetDynamic(nn.Cell):
+    def __init__(self, axis=0, out_nums=1):
+        super(NetDynamic, self).__init__()
+        self.conv = inner.GpuConvertToDynamicShape()
+        self.split = P.Split(axis, out_nums)
+
+    def construct(self, x):
+        x_conv = self.conv(x)
+        x_split = self.split(x_conv)
+        return x_split
+
+
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


@@ -47,6 +60,9 @@ def test_split():
         assert (out.asnumpy() == x[i]).all()


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_split_4d():
     x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
     y = np.split(x_np, 3, axis=1)
@@ -56,3 +72,69 @@

     for i, out in enumerate(outputs):
         assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
+
+    net = NetDynamic(0, 3)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == x[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic_axis1():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
+    y = np.split(x, 2, axis=1)
+
+    net = NetDynamic(1, 2)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic_axis2():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
+    y = np.split(x, 3, axis=2)
+
+    net = NetDynamic(2, 3)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_invalid_input():
+    with pytest.raises(TypeError):
+        _ = Net(0.1, 3)
+
+    with pytest.raises(TypeError):
+        _ = Net(0, 3.0)
+
+    with pytest.raises(ValueError):
+        _ = Net(0, -3)
+
+    x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
+    split_net = Net(2, 2)
+    with pytest.raises(ValueError):
+        _ = split_net(Tensor(x))
+
+    with pytest.raises(TypeError):
+        _ = split_net(x)
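
For reference, a minimal, illustrative usage sketch (assuming a GPU build of MindSpore). With this change, the Python-side __check__ only validates the axis range and that output_num is a positive int; the requirement that the split axis be evenly divisible by output_num is enforced during backend shape inference (InferImplSplit), and only when that axis length is statically known.

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

x = Tensor(np.arange(12).reshape(2, 6).astype(np.float32))
# Split the length-6 axis into 3 tensors, each of shape (2, 2).
out0, out1, out2 = P.Split(axis=1, output_num=3)(x)
# P.Split(axis=1, output_num=4)(x) would still be rejected, but by the
# backend divisibility check rather than by the removed Python __infer__.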