fix shape of CTCGreedyDecoder

pull/10086/head
yanzhenxiang2020 4 years ago
parent 4a88def82c
commit b8b608f672

@ -32,6 +32,7 @@ namespace mindspore {
// op name. Op which not exists in operator/ops.h, so define it's name here // op name. Op which not exists in operator/ops.h, so define it's name here
constexpr auto kUniqueOpName = "Unique"; constexpr auto kUniqueOpName = "Unique";
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits"; constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
constexpr auto kFour2FiveOpName = "Four2Five"; constexpr auto kFour2FiveOpName = "Four2Five";
constexpr auto kFive2FourOpName = "Five2Four"; constexpr auto kFive2FourOpName = "Five2Four";
constexpr auto kConv2DOpName = "Conv2D"; constexpr auto kConv2DOpName = "Conv2D";
@ -486,7 +487,7 @@ const std::set<std::string> kHWSpecialFormatSet = {
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
kPadAndShiftOpName}; kPadAndShiftOpName, kCTCGreedyDecoderOpName};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};

@ -205,6 +205,8 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@ -486,6 +486,42 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim
return std::make_shared<AbstractTuple>(elements); return std::make_shared<AbstractTuple>(elements);
} }
AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// inputs: inputs, sequence_length
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto shape = input->shape();
if (shape->shape().size() != 3) {
MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 3.";
}
ShapeVector indices_shape = {Shape::SHP_ANY, 2};
ShapeVector min_shape = {1, 2};
ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1], 2};
auto decoded_indices =
std::make_shared<AbstractTensor>(kInt64, std::make_shared<Shape>(indices_shape, min_shape, max_shape));
ShapeVector values_shape = {Shape::SHP_ANY};
ShapeVector values_min_shape = {1};
ShapeVector values_max_shape = {shape->shape()[0] * shape->shape()[1]};
ShapePtr values_shapes = std::make_shared<Shape>(values_shape, values_min_shape, values_max_shape);
auto decoded_values = std::make_shared<AbstractTensor>(kInt64, values_shapes);
ShapeVector decoded_shape_shape = {2};
auto decoded_shape = std::make_shared<AbstractTensor>(kInt64, decoded_shape_shape);
ShapeVector log_probability_shape = {shape->shape()[1], 1};
auto log_probability =
std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(log_probability_shape));
// outputs: decoded_indices, decoded_values, decoded_shape, log_probability
AbstractBasePtrList elements = {decoded_indices, decoded_values, decoded_shape, log_probability};
return std::make_shared<AbstractTuple>(elements);
}
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();

@ -120,6 +120,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
{prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
{prim::kPrimSGD, {InferImplSGD, true}}, {prim::kPrimSGD, {InferImplSGD, true}},
{prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}},
// Others // Others
{prim::kPrimIdentity, {InferImplIdentity, true}}, {prim::kPrimIdentity, {InferImplIdentity, true}},
// Set impl to null as it will use PartialEvaluator; // Set impl to null as it will use PartialEvaluator;

@ -160,6 +160,7 @@ inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive
inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput"); inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput");
inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter"); inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter");
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder");
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =

@ -6195,7 +6195,7 @@ class CTCLoss(PrimitiveWithInfer):
return inputs, inputs return inputs, inputs
class CTCGreedyDecoder(PrimitiveWithInfer): class CTCGreedyDecoder(PrimitiveWithCheck):
""" """
Performs greedy decoding on the logits given in inputs. Performs greedy decoding on the logits given in inputs.
@ -6221,29 +6221,22 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
containing sequence log-probability, has the same type as `inputs`. containing sequence log-probability, has the same type as `inputs`.
Examples: Examples:
>>> class CTCGreedyDecoderNet(nn.Cell):
... def __init__(self):
... super(CTCGreedyDecoderNet, self).__init__()
... self.ctc_greedy_decoder = P.CTCGreedyDecoder()
... self.assert_op = ops.Assert(300)
...
... def construct(self, inputs, sequence_length):
... out = self.ctc_greedy_decoder(inputs,sequence_length)
... self.assert_op(True, (out[0], out[1], out[2], out[3]))
... return out[2]
...
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32) >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> net = CTCGreedyDecoderNet() >>> ctc_greedy_decoder = ops.CTCGreedyDecoder()
>>> output = net(inputs, sequence_length) >>> out1, out2, out3, out4 = ctc_greedy_decoder(inputs, sequence_length)
>>> print(output) >>> print(out1, out2, out3, out4)
[[0 0] [0 1] [1 0]]
[0 1 0]
[2 2]
[[-0.7443749] [0.18251707]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, merge_repeated=True): def __init__(self, merge_repeated=True):
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name) self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
def infer_shape(self, inputs_shape, sequence_length_shape): def check_shape(self, inputs_shape, sequence_length_shape):
validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name) validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name)
validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name) validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name)
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size', validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
@ -6255,7 +6248,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
log_probability_shape = [inputs_shape[1], 1] log_probability_shape = [inputs_shape[1], 1]
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
def infer_dtype(self, inputs_dtype, sequence_length_dtype): def check_dtype(self, inputs_dtype, sequence_length_dtype):
validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name) validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name)
validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name) validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name)
decoded_type = mstype.tensor_type(mstype.int64) decoded_type = mstype.tensor_type(mstype.int64)

Loading…
Cancel
Save