!2843 Add TransShape operator

Merge pull request !2843 from fanglei/trans_shape
pull/2843/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 17319d8dfd

@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub";
const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
const char kNameReshape[] = "Reshape";
const char kNameTransShape[] = "TransShape";
const char kNameRealDiv[] = "RealDiv";
const char kNameTile[] = "Tile";
const char kNameCos[] = "Cos";
@ -242,6 +243,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
{string(kNameReshape), ADPT_DESC(Reshape)},
{string(kNameTransShape), ADPT_DESC(TransShape)},
{string(kNameFlattenGrad), ADPT_DESC(Reshape)},
{prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
{string(kNameAddN), ADPT_DESC(AddN)},

@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
// TransShape
INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};
// BiasAdd
INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};

@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape)
DECLARE_OP_ADAPTER(TransShape)
DECLARE_OP_USE_INPUT_ATTR(TransShape)
DECLARE_OP_USE_OUTPUT(TransShape)
DECLARE_OP_ADAPTER(Iou)
DECLARE_OP_USE_OUTPUT(Iou)
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)

@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self):
dx = reverse_sequence_grad(dout, seq_lengths)
return dx, zeros_like(seq_lengths)
return bprop
@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
"""Generate bprop for TransShape"""
op = P.TransShape()
def bprop(x, shape, out, dout):
dx = op(dout, shape_op(x))
return (dx, zeros_like(shape))
return bprop

@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Shape, Size, Slice, Split, TransShape,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,

@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
return x
class TransShape(PrimitiveWithInfer):
"""
Transform the shape of input tensor to target shape.
Inputs:
- **input_x** (Tensor) - A input tensor.
- **out_shape** (tuple[int]) - The shape of output data.
Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is same as the `out_shape`.
"""
@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True
def __infer__(self, x, shape):
shp = shape['value']
dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp,
'dtype': dtype,
'value': None}

@ -1865,6 +1865,12 @@ test_case_array_ops = [
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'],
}),
('TransShape', {
'block': P.TransShape(),
'desc_const': [(1, 12, 24, 24)],
'desc_inputs': [[1, 3, 24, 24]],
'desc_bprop': [[1, 12, 24, 24]],
}),
]
test_case_other_ops = [

Loading…
Cancel
Save