|
|
|
@ -24,7 +24,8 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const {
|
|
|
|
|
std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv,
|
|
|
|
|
const AnfNodePtr &final_node) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
|
auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
|
|
|
|
|
auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
|
|
|
|
@ -37,7 +38,12 @@ std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const Equ
|
|
|
|
|
auto mul3_x = utils::cast<AnfNodePtr>((*equiv)[mul3_x_]);
|
|
|
|
|
auto mul4_x = utils::cast<AnfNodePtr>((*equiv)[mul4_x_]);
|
|
|
|
|
auto add2_y = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
|
|
|
|
|
auto prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayOpName);
|
|
|
|
|
PrimitivePtr prim = nullptr;
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) {
|
|
|
|
|
prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayAssignOpName);
|
|
|
|
|
} else {
|
|
|
|
|
prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayOpName);
|
|
|
|
|
}
|
|
|
|
|
return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -141,18 +147,152 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const {
|
|
|
|
|
return sub0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Builds the match pattern for the "assign" variant, operand ordering 1.
//
// The returned pattern is rooted at a Depend node and has the shape
//   Depend(Depend(Depend(sub0, Assign(input3_, sub0)),
//                 Assign(input2_, add0)),
//          Assign(input1_, add1))
// Operand orders specific to this variant:
//   Mul(mul0_x_, input2_), Mul(mul1_x_, input0_), Mul(mul2_x_, input1_),
//   Mul(mul3_x_, Square(input0_)), TensorAdd(add2_y_, Sqrt(add1)),
//   Mul(mul4_x_, input3_), Mul(input4_, add3).
const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const {
  auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add0: add0_var_(Mul(mul0_x_, input2_), Mul(mul1_x_, input0_))
  VectorRef scaled_in2({prim::kPrimMul, mul0_x_, input2_});
  VectorRef scaled_in0({prim::kPrimMul, mul1_x_, input0_});
  VectorRef sum0({add0_var_, scaled_in2, scaled_in0});

  // add1: add1_var_(Mul(mul2_x_, input1_), Mul(mul3_x_, Square(input0_)))
  VectorRef in0_squared({prim::kPrimSquare, input0_});
  VectorRef scaled_in1({prim::kPrimMul, mul2_x_, input1_});
  VectorRef scaled_square({prim::kPrimMul, mul3_x_, in0_squared});
  VectorRef sum1({add1_var_, scaled_in1, scaled_square});

  // Divisor: TensorAdd(add2_y_, Sqrt(add1))
  VectorRef root({sqrt_prim, sum1});
  VectorRef divisor({prim::kPrimTensorAdd, add2_y_, root});

  // sub0: sub0_var_(input3_, Mul(input4_, TensorAdd(Mul(mul4_x_, input3_), RealDiv(add0, divisor))))
  VectorRef quotient({real_div_prim, sum0, divisor});
  VectorRef scaled_in3({prim::kPrimMul, mul4_x_, input3_});
  VectorRef sum3({prim::kPrimTensorAdd, scaled_in3, quotient});
  VectorRef step({prim::kPrimMul, input4_, sum3});
  VectorRef diff({sub0_var_, input3_, step});

  // Chain of Depend nodes, each attaching one Assign to the previous result.
  VectorRef store_in3({prim::kPrimAssign, input3_, diff});
  VectorRef after_in3({prim::kPrimDepend, diff, store_in3});
  VectorRef store_in2({prim::kPrimAssign, input2_, sum0});
  VectorRef after_in2({prim::kPrimDepend, after_in3, store_in2});
  VectorRef store_in1({prim::kPrimAssign, input1_, sum1});
  return VectorRef({prim::kPrimDepend, after_in2, store_in1});
}
|
|
|
|
|
|
|
|
|
|
// Builds the match pattern for the "assign" variant, operand ordering 2.
//
// The returned pattern is rooted at a Depend node and has the shape
//   Depend(Depend(Depend(sub0, Assign(input3_, sub0)),
//                 Assign(input2_, add0)),
//          Assign(input1_, add1))
// Operand orders specific to this variant:
//   Mul(input2_, mul0_x_), Mul(input0_, mul1_x_), Mul(input1_, mul2_x_),
//   Mul(mul3_x_, Square(input0_)), TensorAdd(Sqrt(add1), add2_y_),
//   Mul(input3_, mul4_x_), Mul(add3, input4_).
const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const {
  auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add0: add0_var_(Mul(input2_, mul0_x_), Mul(input0_, mul1_x_))
  VectorRef scaled_in2({prim::kPrimMul, input2_, mul0_x_});
  VectorRef scaled_in0({prim::kPrimMul, input0_, mul1_x_});
  VectorRef sum0({add0_var_, scaled_in2, scaled_in0});

  // add1: add1_var_(Mul(input1_, mul2_x_), Mul(mul3_x_, Square(input0_)))
  VectorRef in0_squared({prim::kPrimSquare, input0_});
  VectorRef scaled_in1({prim::kPrimMul, input1_, mul2_x_});
  VectorRef scaled_square({prim::kPrimMul, mul3_x_, in0_squared});
  VectorRef sum1({add1_var_, scaled_in1, scaled_square});

  // Divisor: TensorAdd(Sqrt(add1), add2_y_)
  VectorRef root({sqrt_prim, sum1});
  VectorRef divisor({prim::kPrimTensorAdd, root, add2_y_});

  // sub0: sub0_var_(input3_, Mul(TensorAdd(Mul(input3_, mul4_x_), RealDiv(add0, divisor)), input4_))
  VectorRef quotient({real_div_prim, sum0, divisor});
  VectorRef scaled_in3({prim::kPrimMul, input3_, mul4_x_});
  VectorRef sum3({prim::kPrimTensorAdd, scaled_in3, quotient});
  VectorRef step({prim::kPrimMul, sum3, input4_});
  VectorRef diff({sub0_var_, input3_, step});

  // Chain of Depend nodes, each attaching one Assign to the previous result.
  VectorRef store_in3({prim::kPrimAssign, input3_, diff});
  VectorRef after_in3({prim::kPrimDepend, diff, store_in3});
  VectorRef store_in2({prim::kPrimAssign, input2_, sum0});
  VectorRef after_in2({prim::kPrimDepend, after_in3, store_in2});
  VectorRef store_in1({prim::kPrimAssign, input1_, sum1});
  return VectorRef({prim::kPrimDepend, after_in2, store_in1});
}
|
|
|
|
|
|
|
|
|
|
// Builds the match pattern for the "assign" variant, operand ordering 3.
//
// The returned pattern is rooted at a Depend node and has the shape
//   Depend(Depend(Depend(sub0, Assign(input3_, sub0)),
//                 Assign(input2_, add0)),
//          Assign(input1_, add1))
// Operand orders specific to this variant:
//   Mul(mul0_x_, input2_), Mul(mul1_x_, input0_), Mul(mul2_x_, input1_),
//   Mul(Square(input0_), mul3_x_), TensorAdd(Sqrt(add1), add2_y_),
//   Mul(mul4_x_, input3_), Mul(add3, input4_).
const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const {
  auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add0: add0_var_(Mul(mul0_x_, input2_), Mul(mul1_x_, input0_))
  VectorRef scaled_in2({prim::kPrimMul, mul0_x_, input2_});
  VectorRef scaled_in0({prim::kPrimMul, mul1_x_, input0_});
  VectorRef sum0({add0_var_, scaled_in2, scaled_in0});

  // add1: add1_var_(Mul(mul2_x_, input1_), Mul(Square(input0_), mul3_x_))
  VectorRef in0_squared({prim::kPrimSquare, input0_});
  VectorRef scaled_in1({prim::kPrimMul, mul2_x_, input1_});
  VectorRef scaled_square({prim::kPrimMul, in0_squared, mul3_x_});
  VectorRef sum1({add1_var_, scaled_in1, scaled_square});

  // Divisor: TensorAdd(Sqrt(add1), add2_y_)
  VectorRef root({sqrt_prim, sum1});
  VectorRef divisor({prim::kPrimTensorAdd, root, add2_y_});

  // sub0: sub0_var_(input3_, Mul(TensorAdd(Mul(mul4_x_, input3_), RealDiv(add0, divisor)), input4_))
  VectorRef quotient({real_div_prim, sum0, divisor});
  VectorRef scaled_in3({prim::kPrimMul, mul4_x_, input3_});
  VectorRef sum3({prim::kPrimTensorAdd, scaled_in3, quotient});
  VectorRef step({prim::kPrimMul, sum3, input4_});
  VectorRef diff({sub0_var_, input3_, step});

  // Chain of Depend nodes, each attaching one Assign to the previous result.
  VectorRef store_in3({prim::kPrimAssign, input3_, diff});
  VectorRef after_in3({prim::kPrimDepend, diff, store_in3});
  VectorRef store_in2({prim::kPrimAssign, input2_, sum0});
  VectorRef after_in2({prim::kPrimDepend, after_in3, store_in2});
  VectorRef store_in1({prim::kPrimAssign, input1_, sum1});
  return VectorRef({prim::kPrimDepend, after_in2, store_in1});
}
|
|
|
|
|
|
|
|
|
|
// Builds the match pattern for the "assign" variant, operand ordering 4.
//
// The returned pattern is rooted at a Depend node and has the shape
//   Depend(Depend(Depend(sub0, Assign(input3_, sub0)),
//                 Assign(input2_, add0)),
//          Assign(input1_, add1))
// Operand orders specific to this variant:
//   Mul(mul0_x_, input2_), Mul(mul1_x_, input0_), Mul(mul2_x_, input1_),
//   Mul(mul3_x_, Square(input0_)), TensorAdd(add2_y_, Sqrt(add1)),
//   Mul(mul4_x_, input3_), Mul(add3, input4_).
const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const {
  auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add0: add0_var_(Mul(mul0_x_, input2_), Mul(mul1_x_, input0_))
  VectorRef scaled_in2({prim::kPrimMul, mul0_x_, input2_});
  VectorRef scaled_in0({prim::kPrimMul, mul1_x_, input0_});
  VectorRef sum0({add0_var_, scaled_in2, scaled_in0});

  // add1: add1_var_(Mul(mul2_x_, input1_), Mul(mul3_x_, Square(input0_)))
  VectorRef in0_squared({prim::kPrimSquare, input0_});
  VectorRef scaled_in1({prim::kPrimMul, mul2_x_, input1_});
  VectorRef scaled_square({prim::kPrimMul, mul3_x_, in0_squared});
  VectorRef sum1({add1_var_, scaled_in1, scaled_square});

  // Divisor: TensorAdd(add2_y_, Sqrt(add1))
  VectorRef root({sqrt_prim, sum1});
  VectorRef divisor({prim::kPrimTensorAdd, add2_y_, root});

  // sub0: sub0_var_(input3_, Mul(TensorAdd(Mul(mul4_x_, input3_), RealDiv(add0, divisor)), input4_))
  VectorRef quotient({real_div_prim, sum0, divisor});
  VectorRef scaled_in3({prim::kPrimMul, mul4_x_, input3_});
  VectorRef sum3({prim::kPrimTensorAdd, scaled_in3, quotient});
  VectorRef step({prim::kPrimMul, sum3, input4_});
  VectorRef diff({sub0_var_, input3_, step});

  // Chain of Depend nodes, each attaching one Assign to the previous result.
  VectorRef store_in3({prim::kPrimAssign, input3_, diff});
  VectorRef after_in3({prim::kPrimDepend, diff, store_in3});
  VectorRef store_in2({prim::kPrimAssign, input2_, sum0});
  VectorRef after_in2({prim::kPrimDepend, after_in3, store_in2});
  VectorRef store_in1({prim::kPrimAssign, input1_, sum1});
  return VectorRef({prim::kPrimDepend, after_in2, store_in1});
}
|
|
|
|
|
|
|
|
|
|
// Builds the match pattern for the "assign" variant, operand ordering 5.
//
// The returned pattern is rooted at a Depend node and has the shape
//   Depend(Depend(Depend(sub0, Assign(input3_, sub0)),
//                 Assign(input2_, add0)),
//          Assign(input1_, add1))
// Operand orders specific to this variant:
//   Mul(mul0_x_, input2_), Mul(mul1_x_, input0_), Mul(mul2_x_, input1_),
//   Mul(mul3_x_, Square(input0_)), TensorAdd(Sqrt(add1), add2_y_),
//   Mul(mul4_x_, input3_), Mul(add3, input4_).
const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const {
  auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add0: add0_var_(Mul(mul0_x_, input2_), Mul(mul1_x_, input0_))
  VectorRef scaled_in2({prim::kPrimMul, mul0_x_, input2_});
  VectorRef scaled_in0({prim::kPrimMul, mul1_x_, input0_});
  VectorRef sum0({add0_var_, scaled_in2, scaled_in0});

  // add1: add1_var_(Mul(mul2_x_, input1_), Mul(mul3_x_, Square(input0_)))
  VectorRef in0_squared({prim::kPrimSquare, input0_});
  VectorRef scaled_in1({prim::kPrimMul, mul2_x_, input1_});
  VectorRef scaled_square({prim::kPrimMul, mul3_x_, in0_squared});
  VectorRef sum1({add1_var_, scaled_in1, scaled_square});

  // Divisor: TensorAdd(Sqrt(add1), add2_y_)
  VectorRef root({sqrt_prim, sum1});
  VectorRef divisor({prim::kPrimTensorAdd, root, add2_y_});

  // sub0: sub0_var_(input3_, Mul(TensorAdd(Mul(mul4_x_, input3_), RealDiv(add0, divisor)), input4_))
  VectorRef quotient({real_div_prim, sum0, divisor});
  VectorRef scaled_in3({prim::kPrimMul, mul4_x_, input3_});
  VectorRef sum3({prim::kPrimTensorAdd, scaled_in3, quotient});
  VectorRef step({prim::kPrimMul, sum3, input4_});
  VectorRef diff({sub0_var_, input3_, step});

  // Chain of Depend nodes, each attaching one Assign to the previous result.
  VectorRef store_in3({prim::kPrimAssign, input3_, diff});
  VectorRef after_in3({prim::kPrimDepend, diff, store_in3});
  VectorRef store_in2({prim::kPrimAssign, input2_, sum0});
  VectorRef after_in2({prim::kPrimDepend, after_in3, store_in2});
  VectorRef store_in1({prim::kPrimAssign, input1_, sum1});
  return VectorRef({prim::kPrimDepend, after_in2, store_in1});
}
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|
|
const EquivPtr &equiv) const {
|
|
|
|
|
if (graph == nullptr || node == nullptr || equiv == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
|
|
|
|
auto sub0 = node;
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
|
|
|
|
auto iter_sub0 = (*equiv).find(sub0_var_);
|
|
|
|
|
if (iter_sub0 == (*equiv).end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched.";
|
|
|
|
|
}
|
|
|
|
|
sub0 = utils::cast<AnfNodePtr>(iter_sub0->second);
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(sub0);
|
|
|
|
|
if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv);
|
|
|
|
|
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv, node);
|
|
|
|
|
auto fusion_node = graph->NewCNode(inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fusion_node);
|
|
|
|
|
fusion_node->set_scope(node->scope());
|
|
|
|
|
fusion_node->set_scope(sub0->scope());
|
|
|
|
|
|
|
|
|
|
auto iter_add0 = (*equiv).find(add0_var_);
|
|
|
|
|
if (iter_add0 == (*equiv).end()) {
|
|
|
|
@ -167,9 +307,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
|
|
|
|
|
auto add1 = utils::cast<AnfNodePtr>(iter_add1->second);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(add1);
|
|
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0),
|
|
|
|
|
AnfAlgo::GetOutputInferDataType(node, 0)};
|
|
|
|
|
AnfAlgo::GetOutputInferDataType(sub0, 0)};
|
|
|
|
|
auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0),
|
|
|
|
|
AnfAlgo::GetOutputInferShape(node, 0)};
|
|
|
|
|
AnfAlgo::GetOutputInferShape(sub0, 0)};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> fusion_node_outputs;
|
|
|
|
|