|
|
|
@ -24,40 +24,57 @@
|
|
|
|
|
#include "pre_activate/common/helper.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(add);
|
|
|
|
|
|
|
|
|
|
for (size_t index = 1; index < add->size(); ++index) {
|
|
|
|
|
auto input = add->input(index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
|
|
if (input->isa<CNode>()) {
|
|
|
|
|
auto cnode = input->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
|
|
|
|
|
if (!opt::IsUsedByOthers(graph, cnode)) {
|
|
|
|
|
*mul = cnode;
|
|
|
|
|
*mul_index = index;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace opt {
|
|
|
|
|
const BaseRef MulAddFusion::DefinePattern() const {
|
|
|
|
|
VarPtr mul_x_ = std::make_shared<Var>();
|
|
|
|
|
VarPtr mul_y_ = std::make_shared<Var>();
|
|
|
|
|
VarPtr add_y_ = std::make_shared<Var>();
|
|
|
|
|
|
|
|
|
|
VectorRef mul({prim::kPrimMul, mul_x_, mul_y_});
|
|
|
|
|
VectorRef add({prim::kPrimTensorAdd, mul, add_y_});
|
|
|
|
|
return add;
|
|
|
|
|
VarPtr x = std::make_shared<Var>();
|
|
|
|
|
VarPtr y = std::make_shared<Var>();
|
|
|
|
|
VectorRef pattern({prim::kPrimTensorAdd, x, y});
|
|
|
|
|
return pattern;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
|
|
|
|
|
if (graph == nullptr || node == nullptr || equiv == nullptr) {
|
|
|
|
|
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
|
|
|
|
if (graph == nullptr || node == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto add = node->cast<CNodePtr>();
|
|
|
|
|
if (add == nullptr || add->inputs().size() != kAddInputNum) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto mul_anf = add->input(1);
|
|
|
|
|
if (mul_anf == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto mul = mul_anf->cast<CNodePtr>();
|
|
|
|
|
if (mul == nullptr || mul->inputs().size() != kMulInputNum) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (IsUsedByOthers(graph, mul)) {
|
|
|
|
|
MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse";
|
|
|
|
|
CNodePtr mul = nullptr;
|
|
|
|
|
size_t mul_index = 0;
|
|
|
|
|
if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto prim = std::make_shared<Primitive>(kFusedMulAddOpName);
|
|
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul->input(1), mul->input(2), add->input(2)};
|
|
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
|
|
|
|
for (size_t index = 1; index < mul->size(); ++index) {
|
|
|
|
|
inputs.push_back(mul->input(index));
|
|
|
|
|
}
|
|
|
|
|
inputs.push_back(add->input(add->size() - mul_index));
|
|
|
|
|
auto fusion_node = graph->NewCNode(inputs);
|
|
|
|
|
fusion_node->set_scope(add->scope());
|
|
|
|
|
fusion_node->set_abstract(add->abstract());
|
|
|
|
|