fix shape of ComputeAccidentalHits

pull/10014/head
yanzhenxiang2020 5 years ago
parent 02d191bf9d
commit e9eb1ebac8

@@ -28,7 +28,7 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-std::set<std::string> kComputeDepend = {"Unique"};
+std::set<std::string> kComputeDepend = {"Unique", "ComputeAccidentalHits"};
 AiCpuDynamicKernel::~AiCpuDynamicKernel() {
   // free dev ptr
   if (ext_info_addr_dev_ == nullptr) {
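Adding ComputeAccidentalHits to kComputeDepend tells the Ascend dynamic-shape machinery that the op's output shape depends on the *values* fed to it, not just their shapes, so the real shape can only be read back after the AICPU kernel has run. A minimal NumPy sketch of the op's semantics shows why; `accidental_hits_ref` is a hypothetical reference written for this note, not MindSpore's kernel, and reporting hit ids as candidate values is an assumption (it matches the test at the end of this commit):

```python
# Reference sketch of ComputeAccidentalHits in NumPy -- a hypothetical
# illustration, not the AICPU kernel. Two inputs with identical shapes
# can produce outputs of different shapes, hence kComputeDepend.
import numpy as np

def accidental_hits_ref(true_classes, sampled_candidates):
    indices, ids = [], []
    for row, classes in enumerate(true_classes):   # shape [batch, num_true]
        for label in classes:
            for candidate in sampled_candidates:   # shape [num_sampled]
                if candidate == label:
                    indices.append(row)
                    ids.append(candidate)          # assumption: ids as values
    # hits get the most negative float32 so they can be masked out of
    # sampled-softmax logits (-3.4028235e+38, as in the test below)
    weights = np.full(len(ids), np.finfo(np.float32).min, dtype=np.float32)
    return np.array(indices), np.array(ids), weights

print(accidental_hits_ref(np.array([[1, 2]]), np.array([0, 3]))[0].shape)  # (0,)
print(accidental_hits_ref(np.array([[1, 2]]), np.array([1, 2]))[0].shape)  # (2,)
```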

@@ -221,6 +221,8 @@ AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const Primiti
                                         const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -519,5 +519,31 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim
   }
   return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
 }
+
+AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list) {
+  // inputs: true_classes, sampled_candidates
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+
+  auto shape = input->shape();
+  if (shape->shape().size() != 2) {
+    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 2.";
+  }
+
+  ShapeVector indices_shape = {Shape::SHP_ANY};
+  ShapeVector min_shape = {1};
+  ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1]};
+
+  auto indices =
+    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(indices_shape, min_shape, max_shape));
+
+  auto weights = std::make_shared<AbstractTensor>(kFloat32, indices_shape);
+  weights->set_shape(std::make_shared<Shape>(indices_shape, min_shape, max_shape));
+  // outputs: indices, ids, weights
+  AbstractBasePtrList elements = {indices, indices, weights};
+  return std::make_shared<AbstractTuple>(elements);
+}
 }  // namespace abstract
 }  // namespace mindspore
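The bounds follow from the op's semantics: the dynamic dimension is Shape::SHP_ANY (-1), the minimum is pinned at 1, and with distinct sampled candidates each of the shape[0] * shape[1] true-class entries can collide at most once, so the hit count is capped at shape[0] * shape[1]. A worked instance, assuming a hypothetical [3, 2] true_classes (the shape the test at the end of this commit uses):

```python
# Worked instance of the min/max bounds computed by
# InferImplComputeAccidentalHits, assuming true_classes has shape [3, 2].
batch, num_true = 3, 2         # shape[0], shape[1]
min_hits = 1                   # min_shape = {1}
max_hits = batch * num_true    # max_shape = {shape[0] * shape[1]} -> 6
dynamic = -1                   # Shape::SHP_ANY: length fixed only at runtime
print(min_hits, max_hits)      # 1 6 -- the test below hits the cap exactly
```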

@@ -69,6 +69,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
     {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
     {prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
+    {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}},
     {prim::kPrimDiv, {InferImplDiv, true}},
     {prim::kPrimRealDiv, {InferImplRealDiv, true}},
     {prim::kPrimShape, {InferImplShape, false}},

@@ -101,6 +101,7 @@ inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
 inline const PrimitivePtr kPrimSubAndFilter = std::make_shared<Primitive>("SubAndFilter");
 inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx");
 inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache");
+inline const PrimitivePtr kPrimComputeAccidentalHits = std::make_shared<Primitive>("ComputeAccidentalHits");
 inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable");
 inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice");
 inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");

@@ -3410,7 +3410,7 @@ class MirrorPad(PrimitiveWithInfer):
                 'value': None}


-class ComputeAccidentalHits(PrimitiveWithInfer):
+class ComputeAccidentalHits(PrimitiveWithCheck):
     """
     Compute accidental hits of sampled classes which happen to match target classes.
@@ -3455,17 +3455,18 @@ class ComputeAccidentalHits(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'],
                                 outputs=['indices', 'ids', 'weights'])
         validator.check_value_type("num_true", num_true, [int], self.name)
+        validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
         self.num_true = num_true

-    def infer_shape(self, true_classes_shape, sampled_candidates_shape):
-        validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name)
-        validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name)
-        validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name)
+    def check_shape(self, true_classes_shape, sampled_candidates_shape):
+        validator.check_int(len(true_classes_shape), 2, Rel.EQ, 'dim of true_classes', self.name)
+        validator.check_int(len(sampled_candidates_shape), 1, Rel.EQ, 'dim of sampled_candidates', self.name)
+        validator.check("true_classes shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
         indices_len = -1
         return (indices_len,), (indices_len,), (indices_len,)

-    def infer_dtype(self, true_classes_type, sampled_candidates_type):
+    def check_dtype(self, true_classes_type, sampled_candidates_type):
         validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
         validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name)
         valid_types = (mstype.int32, mstype.int64)
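With the move from PrimitiveWithInfer to PrimitiveWithCheck, the Python side only validates; the output shape now comes from the C++ InferImplComputeAccidentalHits registered above. A usage sketch; the input values are assumptions, chosen to be consistent with the expected outputs in the test hunk at the end of this commit:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# num_true must equal true_classes.shape[1], or check_shape raises
true_classes = Tensor(np.array([[1, 2], [0, 4], [3, 3]]), ms.int32)
sampled_candidates = Tensor(np.array([0, 1, 2, 3, 4]), ms.int32)
compute_accidental_hits = ops.ComputeAccidentalHits(num_true=2)
indices, ids, weights = compute_accidental_hits(true_classes, sampled_candidates)
# indices: [0 0 1 1 2 2]; ids: [1 2 0 4 3 3]; weights: -3.4028235e+38 x 6
```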
@@ -6107,13 +6108,13 @@ class CTCLoss(PrimitiveWithInfer):
         >>> ctc_loss = ops.CTCLoss()
         >>> loss, gradient = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
         >>> print(loss)
-        [0.69121575 0.5381993 ]
+        [0.69121575 0.5381993]
         >>> print(gradient)
-        [[[ 0.25831494 0.3623634 -0.62067937]
-          [ 0.25187883 0.2921483 -0.5440271 ]]
+        [[[0.25831494 0.3623634 -0.62067937]
+          [0.25187883 0.2921483 -0.5440271]]

-         [[ 0.43522435 0.24408469 0.07787037 ]
-          [ 0.29642645 0.4232373 0.06138104 ]]]
+         [[0.43522435 0.24408469 0.07787037]
+          [0.29642645 0.4232373 0.06138104]]]
     """

     @prim_attr_register

@@ -40,8 +40,8 @@ def test_net():
     output1_expect = np.array([0, 0, 1, 1, 2, 2])
     output2_expect = np.array([1, 2, 0, 4, 3, 3])
-    output3_expect = np.array([-3.4028235+38, -3.4028235+38, -3.4028235+38,
-                               -3.4028235+38, -3.4028235+38, -3.4028235+38]).astype(np.float32)
+    output3_expect = np.array([-3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
+                               -3.4028235e+38, -3.4028235e+38, -3.4028235e+38]).astype(np.float32)

     assert np.array_equal(output1.asnumpy(), output1_expect)
     assert np.array_equal(output2.asnumpy(), output2_expect)
     assert np.array_equal(output3.asnumpy(), output3_expect)
