From 29bb99d55c792dd4d6db41369b82be8a07c89050 Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Thu, 17 Dec 2020 16:26:59 +0800 Subject: [PATCH] fix ComputeAccidentalHits example --- mindspore/core/abstract/prim_nn.cc | 2 +- mindspore/ops/operations/nn_ops.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index f0fb20ec08..61aa4e49e2 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -529,7 +529,7 @@ AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const auto shape = input->shape(); if (shape->shape().size() != 2) { - MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1."; + MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 2."; } ShapeVector indices_shape = {Shape::SHP_ANY}; ShapeVector min_shape = {1}; diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 820dd7beed..f1e160e10b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3425,8 +3425,9 @@ class ComputeAccidentalHits(PrimitiveWithCheck): >>> sampler = ops.ComputeAccidentalHits(2) >>> output1, output2, output3 = sampler(Tensor(x), Tensor(y)) >>> print(output1, output2, output3) - [0, 0, 1, 1, 2, 2], [1, 2, 0, 4, 3, 3], - [-3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38] + [0 0 1 1 2 2] + [1 2 0 4 3 3] + [-3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38] """