|
|
|
@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
|
|
|
|
|
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
|
|
|
|
|
class ConstantDuplicateMul : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
Reset();
|
|
|
|
|
// {prim::kPrimMul, Tensor1, {...}}
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
|
|
|
|
|
if (vnode_ == nullptr || cnode_ == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto tensor1 = vnode_;
|
|
|
|
|
auto mul = cnode_;
|
|
|
|
|
|
|
|
|
|
Reset();
|
|
|
|
|
// {prim::kPrimMul, Tensor2, {...}}
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
|
|
|
|
|
if (vnode_ == nullptr || cnode_ == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto tensor2 = vnode_;
|
|
|
|
|
auto cnode = cnode_;
|
|
|
|
|
|
|
|
|
|
auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
|
|
|
|
|
return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override {
|
|
|
|
|
if (IsValueNode<tensor::Tensor>(node)) {
|
|
|
|
|
vnode_ = node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsCNode(node)) {
|
|
|
|
|
cnode_ = node->cast<CNodePtr>();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Reset() {
|
|
|
|
|
vnode_ = nullptr;
|
|
|
|
|
cnode_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
AnfNodePtr vnode_;
|
|
|
|
|
CNodePtr cnode_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ArithmeticSimplify {
|
|
|
|
|
public:
|
|
|
|
|
ArithmeticSimplify()
|
|
|
|
@ -186,12 +235,14 @@ class ArithmeticSimplify {
|
|
|
|
|
add_by_zero_(),
|
|
|
|
|
tensor_add_by_zero_(),
|
|
|
|
|
identity_(prim::kPrimIdentity),
|
|
|
|
|
opt_update_zero_tensor_() {
|
|
|
|
|
opt_update_zero_tensor_(),
|
|
|
|
|
constant_duplicate_mul_() {
|
|
|
|
|
eliminaters_.emplace_back(multiply_by_zero_or_one_);
|
|
|
|
|
eliminaters_.emplace_back(add_by_zero_);
|
|
|
|
|
eliminaters_.emplace_back(tensor_add_by_zero_);
|
|
|
|
|
eliminaters_.emplace_back(identity_);
|
|
|
|
|
eliminaters_.emplace_back(opt_update_zero_tensor_);
|
|
|
|
|
eliminaters_.emplace_back(constant_duplicate_mul_);
|
|
|
|
|
}
|
|
|
|
|
~ArithmeticSimplify() = default;
|
|
|
|
|
|
|
|
|
@ -212,6 +263,7 @@ class ArithmeticSimplify {
|
|
|
|
|
TensorAddByZero tensor_add_by_zero_;
|
|
|
|
|
PrimEliminater identity_;
|
|
|
|
|
OptUpdateZeroTensor opt_update_zero_tensor_;
|
|
|
|
|
ConstantDuplicateMul constant_duplicate_mul_;
|
|
|
|
|
std::vector<TransformFuncType> eliminaters_{};
|
|
|
|
|
};
|
|
|
|
|
} // namespace irpass
|
|
|
|
|