diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 5669722fc1..50ce81c912 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -32,6 +32,7 @@ namespace mindspore { // op name. Op which not exists in operator/ops.h, so define it's name here constexpr auto kUniqueOpName = "Unique"; constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits"; +constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder"; constexpr auto kFour2FiveOpName = "Four2Five"; constexpr auto kFive2FourOpName = "Five2Four"; constexpr auto kConv2DOpName = "Conv2D"; @@ -486,7 +487,7 @@ const std::set kHWSpecialFormatSet = { const std::set kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; const std::set kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, - kPadAndShiftOpName}; + kPadAndShiftOpName, kCTCGreedyDecoderOpName}; const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 92dc45a53b..d2f81ddab1 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -205,6 +205,8 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index 61aa4e49e2..d745bca2de 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -486,6 +486,42 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim return std::make_shared(elements); } +AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // inputs: inputs, sequence_length + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTensorPtr input = CheckArg(op_name, args_spec_list, 0); + + auto shape = input->shape(); + if (shape->shape().size() != 3) { + MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 3."; + } + + ShapeVector indices_shape = {Shape::SHP_ANY, 2}; + ShapeVector min_shape = {1, 2}; + ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1], 2}; + auto decoded_indices = + std::make_shared(kInt64, std::make_shared(indices_shape, min_shape, max_shape)); + + ShapeVector values_shape = {Shape::SHP_ANY}; + ShapeVector values_min_shape = {1}; + ShapeVector values_max_shape = {shape->shape()[0] * shape->shape()[1]}; + ShapePtr values_shapes = std::make_shared(values_shape, values_min_shape, values_max_shape); + auto decoded_values = std::make_shared(kInt64, values_shapes); + + ShapeVector decoded_shape_shape = {2}; + auto decoded_shape = std::make_shared(kInt64, decoded_shape_shape); + + ShapeVector log_probability_shape = {shape->shape()[1], 1}; + auto log_probability = + std::make_shared(input->element(), std::make_shared(log_probability_shape)); + + // outputs: decoded_indices, decoded_values, decoded_shape, log_probability + AbstractBasePtrList elements = {decoded_indices, decoded_values, decoded_shape, log_probability}; + return std::make_shared(elements); +} + AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 7fcf58fe76..8fc454c3fe 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -120,6 +120,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, {prim::kPrimSGD, {InferImplSGD, true}}, + {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}}, // Others {prim::kPrimIdentity, {InferImplIdentity, true}}, // Set impl to null as it will use PartialEvaluator; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 8ab89abffc..fb19b204b2 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -160,6 +160,7 @@ inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv3DBackpropInput"); inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared("Conv3DBackpropFilter"); inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); +inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared("CTCGreedyDecoder"); inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = std::make_shared("DepthwiseConv2dNativeBackpropFilter"); inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 784598b8df..79d2e0eab0 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -6195,7 +6195,7 @@ class CTCLoss(PrimitiveWithInfer): return inputs, inputs -class CTCGreedyDecoder(PrimitiveWithInfer): +class CTCGreedyDecoder(PrimitiveWithCheck): """ Performs greedy decoding on the logits given in inputs. @@ -6221,29 +6221,22 @@ class CTCGreedyDecoder(PrimitiveWithInfer): containing sequence log-probability, has the same type as `inputs`. Examples: - >>> class CTCGreedyDecoderNet(nn.Cell): - ... def __init__(self): - ... super(CTCGreedyDecoderNet, self).__init__() - ... self.ctc_greedy_decoder = P.CTCGreedyDecoder() - ... self.assert_op = ops.Assert(300) - ... - ... def construct(self, inputs, sequence_length): - ... out = self.ctc_greedy_decoder(inputs,sequence_length) - ... self.assert_op(True, (out[0], out[1], out[2], out[3])) - ... return out[2] - ... >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32) >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) - >>> net = CTCGreedyDecoderNet() - >>> output = net(inputs, sequence_length) - >>> print(output) + >>> ctc_greedy_decoder = ops.CTCGreedyDecoder() + >>> out1, out2, out3, out4 = ctc_greedy_decoder(inputs, sequence_length) + >>> print(out1, out2, out3, out4) + [[0 0] [0 1] [1 0]] + [0 1 0] + [2 2] + [[-0.7443749] [0.18251707]] """ @prim_attr_register def __init__(self, merge_repeated=True): self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name) - def infer_shape(self, inputs_shape, sequence_length_shape): + def check_shape(self, inputs_shape, sequence_length_shape): validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name) validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name) validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size', @@ -6255,7 +6248,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer): log_probability_shape = [inputs_shape[1], 1] return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape - def infer_dtype(self, inputs_dtype, sequence_length_dtype): + def check_dtype(self, inputs_dtype, sequence_length_dtype): validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name) validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name) decoded_type = mstype.tensor_type(mstype.int64)