Add TransShape Operator.

pull/2843/head
leilei_snow 5 years ago committed by leilei_snow
parent f42e36b629
commit 5cccfbc61b

@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub";
const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
const char kNameReshape[] = "Reshape";
const char kNameTransShape[] = "TransShape";
const char kNameRealDiv[] = "RealDiv";
const char kNameTile[] = "Tile";
const char kNameCos[] = "Cos";
@ -242,6 +243,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
{string(kNameReshape), ADPT_DESC(Reshape)},
{string(kNameTransShape), ADPT_DESC(TransShape)},
{string(kNameFlattenGrad), ADPT_DESC(Reshape)},
{prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
{string(kNameAddN), ADPT_DESC(AddN)},

@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
// TransShape
INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};
// BiasAdd
INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};

@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape)
DECLARE_OP_ADAPTER(TransShape)
DECLARE_OP_USE_INPUT_ATTR(TransShape)
DECLARE_OP_USE_OUTPUT(TransShape)
DECLARE_OP_ADAPTER(Iou)
DECLARE_OP_USE_OUTPUT(Iou)
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)

@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self):
dx = reverse_sequence_grad(dout, seq_lengths)
return dx, zeros_like(seq_lengths)
return bprop
@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
"""Generate bprop for TransShape"""
op = P.TransShape()
def bprop(x, shape, out, dout):
dx = op(dout, shape_op(x))
return (dx, zeros_like(shape))
return bprop

@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Shape, Size, Slice, Split, TransShape,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,

@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
return x
class TransShape(PrimitiveWithInfer):
"""
Transform the shape of input tensor to target shape.
Inputs:
- **input_x** (Tensor) - A input tensor.
- **out_shape** (tuple[int]) - The shape of output data.
Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is same as the `out_shape`.
"""
@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True
def __infer__(self, x, shape):
shp = shape['value']
dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp,
'dtype': dtype,
'value': None}

@ -1865,6 +1865,12 @@ test_case_array_ops = [
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'],
}),
('TransShape', {
'block': P.TransShape(),
'desc_const': [(1, 12, 24, 24)],
'desc_inputs': [[1, 3, 24, 24]],
'desc_bprop': [[1, 12, 24, 24]],
}),
]
test_case_other_ops = [

Loading…
Cancel
Save