From 5cccfbc61ba4e67de63eecacd564373b7ddb0e3a Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Fri, 8 May 2020 15:43:17 +0800 Subject: [PATCH] Add TransShape Operator. --- mindspore/ccsrc/transform/convert.cc | 2 ++ mindspore/ccsrc/transform/op_declare.cc | 6 ++++++ mindspore/ccsrc/transform/op_declare.h | 3 +++ mindspore/ops/_grad/grad_array_ops.py | 10 ++++++++++ mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/array_ops.py | 25 +++++++++++++++++++++++++ tests/ut/python/ops/test_ops.py | 6 ++++++ 7 files changed, 53 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 3f6b31303c..f88e31fcd2 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub"; const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus"; const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; const char kNameReshape[] = "Reshape"; +const char kNameTransShape[] = "TransShape"; const char kNameRealDiv[] = "RealDiv"; const char kNameTile[] = "Tile"; const char kNameCos[] = "Cos"; @@ -242,6 +243,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, {string(kNameReshape), ADPT_DESC(Reshape)}, + {string(kNameTransShape), ADPT_DESC(TransShape)}, {string(kNameFlattenGrad), ADPT_DESC(Reshape)}, {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)}, {string(kNameAddN), ADPT_DESC(AddN)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index cac526f1fb..fd8ce624a9 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}}; ATTR_MAP(Reshape) = EMPTY_ATTR_MAP; OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}}; +// TransShape +INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(TransShape) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}}; + // BiasAdd INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}}; ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}}; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index f64dc7b671..baa819f71f 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD) DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD) DECLARE_OP_ADAPTER(Reshape) DECLARE_OP_USE_OUTPUT(Reshape) +DECLARE_OP_ADAPTER(TransShape) +DECLARE_OP_USE_INPUT_ATTR(TransShape) +DECLARE_OP_USE_OUTPUT(TransShape) DECLARE_OP_ADAPTER(Iou) DECLARE_OP_USE_OUTPUT(Iou) DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index a2a808781e..e216a4f0d0 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self): dx = reverse_sequence_grad(dout, seq_lengths) return dx, zeros_like(seq_lengths) return bprop + + +@bprop_getters.register(P.TransShape) +def get_bprop_trans_shape(self): + """Generate bprop for TransShape""" + op = P.TransShape() + def bprop(x, shape, out, dout): + dx = op(dout, shape_op(x)) + return (dx, zeros_like(shape)) + return bprop diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index b2d0fc7382..9053abab06 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, - Shape, Size, Slice, Split, + Shape, Size, Slice, Split, TransShape, Squeeze, StridedSlice, Tile, TensorScatterUpdate, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 8e9ecfea95..70665d3367 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer): validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) return x + + +class TransShape(PrimitiveWithInfer): + """ + Transform the shape of input tensor to target shape. + + Inputs: + - **input_x** (Tensor) - A input tensor. + - **out_shape** (tuple[int]) - The shape of output data. + + Outputs: + Tensor, a tensor whose data type is same as 'input_x', and the shape is same as the `out_shape`. + """ + @prim_attr_register + def __init__(self): + self.__setattr_flag__ = True + + def __infer__(self, x, shape): + shp = shape['value'] + dtype = x['dtype'] + validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name) + self.add_prim_attr('out_shape', tuple(shp)) + return {'shape': shp, + 'dtype': dtype, + 'value': None} diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 5262145c80..c4330e3f65 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1865,6 +1865,12 @@ test_case_array_ops = [ Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], 'skip': ['backward'], }), + ('TransShape', { + 'block': P.TransShape(), + 'desc_const': [(1, 12, 24, 24)], + 'desc_inputs': [[1, 3, 24, 24]], + 'desc_bprop': [[1, 12, 24, 24]], + }), ] test_case_other_ops = [