!11794 remove useless code of dot

From: @yuan_shen_zhou
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
pull/11794/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 066ebe516e

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
# #
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -86,7 +86,6 @@ convert_object_map = {
T.floordiv: multitype_ops.floordiv, T.floordiv: multitype_ops.floordiv,
T.mod: multitype_ops.mod, T.mod: multitype_ops.mod,
T.pow: multitype_ops.pow_, T.pow: multitype_ops.pow_,
T.matmul: F.dot,
T.lshift: NO_IMPLEMENT, T.lshift: NO_IMPLEMENT,
T.rshift: NO_IMPLEMENT, T.rshift: NO_IMPLEMENT,
T.and_: multitype_ops.logical_and, T.and_: multitype_ops.logical_and,

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,7 +32,7 @@ bool CNodeHasTupleInput(const CNodePtr &cnode) {
} }
if (IsValueNode<Primitive>(inputs[i])) { if (IsValueNode<Primitive>(inputs[i])) {
// unexpected high order primitvie as cnode input when transform graph // unexpected high order primitvie as cnode input when transform graph
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitve as input" << cnode->DebugString(); MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitive as input" << cnode->DebugString();
return false; return false;
} }
auto abs = inputs[i]->abstract(); auto abs = inputs[i]->abstract();

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -170,7 +170,6 @@ BuiltInTypeMap &GetMethodMap() {
{"__ge__", std::string("ge")}, // C.ge {"__ge__", std::string("ge")}, // C.ge
{"expand_as", std::string("expand_tensor_as")}, // C.expand_as {"expand_as", std::string("expand_tensor_as")}, // C.expand_as
{"view", std::string("view")}, // C.view {"view", std::string("view")}, // C.view
{"__matmul__", prim::kPrimDot}, // P.dot,
{"__len__", prim::kPrimArrayLen}, // P.array_len, {"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
@ -352,7 +351,7 @@ void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) { if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
return; return;
} }
MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString(); MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString();
pynative_short_life_primitives_.insert(prim); pynative_short_life_primitives_.insert(prim);
pynative_new_primtives_squence_.push_back(prim->ToString()); pynative_new_primtives_squence_.push_back(prim->ToString());
} }

@ -27,8 +27,6 @@ namespace mindspore {
namespace abstract { namespace abstract {
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -34,35 +34,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
return abs_base; return abs_base;
} }
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
ShapePtr x_shp = input_x->shape();
auto x_shp_value = x_shp->shape();
ShapePtr y_shp = input_y->shape();
auto y_shp_value = y_shp->shape();
// Should be matrix which shape size is 2.
if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
<< x_shp_value.size() << ", " << y_shp_value.size() << " ";
}
if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
}
auto x_element = input_x->element();
MS_EXCEPTION_IF_NULL(x_element);
(void)x_element->Join(input_y->element());
auto param = {x_shp_value[0], y_shp_value[1]};
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
}
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim, AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: condition, true branch, false branch // Inputs: condition, true branch, false branch

@ -26,7 +26,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = { static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements // Statements
{prim::kPrimReturn, {InferImplReturn, true}}, {prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}}, {prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
{prim::kPrimIs_, {InferImplIs_, true}}, {prim::kPrimIs_, {InferImplIs_, true}},

@ -67,7 +67,6 @@ inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalO
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot"); inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute"); inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col"); inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im"); inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1"); inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -188,12 +188,6 @@ def bprop_array_to_scalar(x, out, dout):
return (F.scalar_to_array(dout),) return (F.scalar_to_array(dout),)
@bprops.register("dot")
def bprop_dot(x, y, out, dout):
"""Backpropagator for primitive `dot`."""
return F.dot(dout, F.transpose(y, (1, 0))), F.dot(F.transpose(x, (1, 0)), dout)
@bprops.register("reshape") @bprops.register("reshape")
def bprop_reshape(xs, shp, out, dout): def bprop_reshape(xs, shp, out, dout):
"""Backpropagator for primitive `reshape`.""" """Backpropagator for primitive `reshape`."""

@ -142,7 +142,6 @@ in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict") not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast") mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs') broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce') array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike() zeros_like = P.ZerosLike()
distribute = Primitive('distribute') distribute = Primitive('distribute')

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -287,11 +287,6 @@ TEST_F(TestOps, TransposeTest) {
ASSERT_EQ(prim->name(), kPrimTranspose->name()); ASSERT_EQ(prim->name(), kPrimTranspose->name());
} }
TEST_F(TestOps, DotTest) {
auto prim = std::make_shared<Primitive>("dot");
ASSERT_EQ(prim->name(), kPrimDot->name());
}
TEST_F(TestOps, Im2ColTest) { TEST_F(TestOps, Im2ColTest) {
auto prim = std::make_shared<Primitive>("im2col"); auto prim = std::make_shared<Primitive>("im2col");
ASSERT_EQ(prim->name(), kPrimIm2Col->name()); ASSERT_EQ(prim->name(), kPrimIm2Col->name());

@ -169,11 +169,6 @@ TEST_F(TestAD, test_prim_array_to_scalar) {
AssertExpect("test_prim_array_to_scalar", dg); AssertExpect("test_prim_array_to_scalar", dg);
} }
TEST_F(TestAD, test_prim_dot) {
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDot), resourcePtr);
AssertExpect("test_prim_dot", dg);
}
TEST_F(TestAD, test_prim_distribute) { TEST_F(TestAD, test_prim_distribute) {
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDistribute), resourcePtr); FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDistribute), resourcePtr);
AssertExpect("test_prim_distribute", dg); AssertExpect("test_prim_distribute", dg);

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -291,22 +291,6 @@ TEST_F(TestPrim, test_J_2) {
ASSERT_TRUE(res_J_1 != nullptr); ASSERT_TRUE(res_J_1 != nullptr);
} }
TEST_F(TestPrim, test_dot) {
auto dot = std::make_shared<Primitive>("dot");
FuncGraphPtr func_graph = MakeFuncGraph(dot, 2);
auto a1 = UTPrimUtils::ArrayFloat64Of({2, 3});
auto a2 = UTPrimUtils::ArrayFloat64Of({3, 4});
std::vector<int64_t> expectedA = {2, 4};
auto expected = UTPrimUtils::ArrayFloat64Of({2, 4});
AbstractBasePtrList args_spec_list = {a1, a2};
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
}
// tail half // tail half
TEST_F(TestPrim, test_switch1) { TEST_F(TestPrim, test_switch1) {
PrimitivePtr switch_ = std::make_shared<Primitive>("switch"); PrimitivePtr switch_ = std::make_shared<Primitive>("switch");

Loading…
Cancel
Save