From e9eb1ebac8f2f082d77fa0e541409865adb62588 Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Tue, 15 Dec 2020 20:29:26 +0800 Subject: [PATCH] fix shape of ComputeAccidentalHits --- .../ascend/executor/ai_cpu_dynamic_kernel.cc | 2 +- mindspore/core/abstract/infer_functions.h | 2 ++ mindspore/core/abstract/prim_nn.cc | 26 +++++++++++++++++++ .../core/abstract/primitive_infer_map.cc | 1 + mindspore/core/base/core_ops.h | 1 + mindspore/ops/operations/nn_ops.py | 23 ++++++++-------- .../test_compute_accidental_hits.py | 4 +-- 7 files changed, 45 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc index 93d89dcc73..097b6d364e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace device { namespace ascend { -std::set<std::string> kComputeDepend = {"Unique"}; +std::set<std::string> kComputeDepend = {"Unique", "ComputeAccidentalHits"}; AiCpuDynamicKernel::~AiCpuDynamicKernel() { // free dev ptr if (ext_info_addr_dev_ == nullptr) { diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index eb2fb97107..c9fa606faf 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -221,6 +221,8 @@ AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const Primiti const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList
&args_spec_list); AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index 4f7cced440..f0fb20ec08 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -519,5 +519,31 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim } return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp)); } + +AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // inputs: true_classes, sampled_candidates + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); + + auto shape = input->shape(); + if (shape->shape().size() != 2) { + MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 2."; + } + ShapeVector indices_shape = {Shape::SHP_ANY}; + ShapeVector min_shape = {1}; + ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1]}; + + auto indices = + std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(indices_shape, min_shape, max_shape)); + + auto weights = std::make_shared<AbstractTensor>(kFloat32, indices_shape); + weights->set_shape(std::make_shared<Shape>(indices_shape, min_shape, max_shape)); + // outputs: indices, ids, weights + AbstractBasePtrList elements = {indices, indices, weights}; + return std::make_shared<AbstractTuple>(elements); +} + } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index c45cbd08d4..b4cab54a7b 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -69,6 +69,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, {prim::kPrimCacheSwapTable,
{InferImplCacheSwapTable, true}}, {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, + {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, {prim::kPrimDiv, {InferImplDiv, true}}, {prim::kPrimRealDiv, {InferImplRealDiv, true}}, {prim::kPrimShape, {InferImplShape, false}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 0a0f3d0796..35980bca49 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -101,6 +101,7 @@ inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); inline const PrimitivePtr kPrimSubAndFilter = std::make_shared<Primitive>("SubAndFilter"); inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx"); inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); +inline const PrimitivePtr kPrimComputeAccidentalHits = std::make_shared<Primitive>("ComputeAccidentalHits"); inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice"); inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 42a163894f..756da0dbf9 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3410,7 +3410,7 @@ class MirrorPad(PrimitiveWithInfer): 'value': None} -class ComputeAccidentalHits(PrimitiveWithInfer): +class ComputeAccidentalHits(PrimitiveWithCheck): """ Compute accidental hits of sampled classes which happen to match target classes.
@@ -3455,17 +3455,18 @@ class ComputeAccidentalHits(PrimitiveWithInfer): self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'], outputs=['indices', 'ids', 'weights']) validator.check_value_type("num_true", num_true, [int], self.name) + validator.check_number("num_true", num_true, 1, Rel.GE, self.name) self.num_true = num_true - def infer_shape(self, true_classes_shape, sampled_candidates_shape): - validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name) - validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name) - validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name) + def check_shape(self, true_classes_shape, sampled_candidates_shape): + validator.check_int(len(true_classes_shape), 2, Rel.EQ, 'dim of true_classes', self.name) + validator.check_int(len(sampled_candidates_shape), 1, Rel.EQ, 'dim of sampled_candidates', self.name) + validator.check("true_classes shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name) indices_len = -1 return (indices_len,), (indices_len,), (indices_len,) - def infer_dtype(self, true_classes_type, sampled_candidates_type): + def check_dtype(self, true_classes_type, sampled_candidates_type): validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name) valid_types = (mstype.int32, mstype.int64) @@ -6107,13 +6108,13 @@ class CTCLoss(PrimitiveWithInfer): >>> ctc_loss = ops.CTCLoss() >>> loss, gradient = ctc_loss(inputs, labels_indices, labels_values, sequence_length) >>> print(loss) - [0.69121575 0.5381993 ] + [0.69121575 0.5381993] >>> print(gradient) - [[[ 0.25831494 0.3623634 -0.62067937] - [ 0.25187883 0.2921483 -0.5440271 ]] + [[[0.25831494 0.3623634 -0.62067937] + [0.25187883 0.2921483 -0.5440271]] - [[ 0.43522435 0.24408469 
0.07787037 ] - [ 0.29642645 0.4232373 0.06138104 ]]] + [[0.43522435 0.24408469 0.07787037] + [0.29642645 0.4232373 0.06138104]]] """ @prim_attr_register diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py b/tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py index 74fbcaf630..4ee1a14fd0 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py @@ -40,8 +40,8 @@ def test_net(): output1_expect = np.array([0, 0, 1, 1, 2, 2]) output2_expect = np.array([1, 2, 0, 4, 3, 3]) - output3_expect = np.array([-3.4028235+38, -3.4028235+38, -3.4028235+38, - -3.4028235+38, -3.4028235+38, -3.4028235+38]).astype(np.float32) + output3_expect = np.array([-3.4028235e+38, -3.4028235e+38, -3.4028235e+38, + -3.4028235e+38, -3.4028235e+38, -3.4028235e+38]).astype(np.float32) assert np.array_equal(output1.asnumpy(), output1_expect) assert np.array_equal(output2.asnumpy(), output2_expect) assert np.array_equal(output3.asnumpy(), output3_expect)