!2923 Update arithmetic simplify to use Pattern Matcher

Merge pull request !2923 from Giancarlo/update_arith_simplify
pull/2923/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b5e8e2419e

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -22,158 +22,14 @@
#include <vector> #include <vector>
#include "ir/optimizer_caller.h" #include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h" #include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
class MultiplyByZeroOrOne : public AnfVisitor {
public:
MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {}
~MultiplyByZeroOrOne() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
void Reset();
private:
bool is_zero_{false}, is_one_{false};
ValuePtr zero_, one_;
AnfNodePtr x_{nullptr};
};
// Support class used for checking if all values of a Tensor are equal `check_value_`
// Supported data types: double, float/float32, int/int32
class CheckTensorConstant {
public:
explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
~CheckTensorConstant() = default;
bool IsTensorConstant(const ValuePtr &value);
bool IsTensorScalarConstant(const ValuePtr &value);
private:
int check_value_;
};
class TensorMultiplyBase : public AnfVisitor {
protected:
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false);
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor will "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr);
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
class TensorMultiplyByZero : public TensorMultiplyBase {
public:
TensorMultiplyByZero() : zero_(MakeValue(0)) {}
~TensorMultiplyByZero() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
void Reset();
private:
bool is_zero_{false};
ValuePtr zero_;
};
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class TensorMultiplyByOne : public TensorMultiplyBase {
public:
TensorMultiplyByOne() {}
~TensorMultiplyByOne() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
void Reset();
private:
bool is_one_{false};
};
// {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X}
class AddByZero : public AnfVisitor {
public:
AddByZero() : zero_(MakeValue(0)) {}
~AddByZero() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Reset();
private:
bool is_zero_{false};
ValuePtr zero_;
AnfNodePtr x_{nullptr};
};
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
class TensorAddByZero : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
void Reset();
private:
bool is_zero_{false};
AnfNodePtr x_{nullptr};
};
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
class OptUpdateZeroTensor : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class ConstantDuplicateMul : public AnfVisitor {
public:
// Support function to multiply two constant tensors: partially support broadcasting shapes
template <typename T>
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size);
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3);
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Reset();
private:
AnfNodePtr vnode_;
AnfNodePtr c_p_node_;
};
class PowerOneEliminate : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// grad = AllReduce(grad) / worker_number // grad = AllReduce(grad) / worker_number
// grad = grad + weight * decy // grad = grad + weight * decy
// -> // ->
@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
class ArithmeticSimplify : public OptimizerCaller { class ArithmeticSimplify : public OptimizerCaller {
public: public:
ArithmeticSimplify() AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
: multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
add_by_zero_(std::make_shared<AddByZero>()),
tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
power_one_(std::make_shared<PowerOneEliminate>()) {
eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_);
eliminaters_.emplace_back(tensor_add_by_zero_);
eliminaters_.emplace_back(identity_);
eliminaters_.emplace_back(opt_update_zero_tensor_);
eliminaters_.emplace_back(constant_duplicate_mul_);
eliminaters_.emplace_back(power_one_);
}
~ArithmeticSimplify() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
private:
OptimizerCallerPtr multiply_by_zero_or_one_;
OptimizerCallerPtr tensor_multiply_by_one_;
OptimizerCallerPtr add_by_zero_;
OptimizerCallerPtr tensor_add_by_zero_;
OptimizerCallerPtr identity_;
OptimizerCallerPtr opt_update_zero_tensor_;
OptimizerCallerPtr constant_duplicate_mul_;
OptimizerCallerPtr power_one_;
std::vector<OptimizerCallerPtr> eliminaters_{};
}; };
// Arithmetic Simplifications should be done after step_parallel. // Arithmetic Simplifications should be done after step_parallel.
@ -242,17 +66,9 @@ class ArithmeticSimplify : public OptimizerCaller {
// ArithmeticSimplify and deferred until step_parallel. // ArithmeticSimplify and deferred until step_parallel.
class ArithmeticSimplify2 : public OptimizerCaller { class ArithmeticSimplify2 : public OptimizerCaller {
public: public:
ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
eliminaters_.emplace_back(tensor_multiply_by_zero_);
}
~ArithmeticSimplify2() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
private:
OptimizerCallerPtr tensor_multiply_by_zero_;
std::vector<OptimizerCallerPtr> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

@ -25,10 +25,8 @@
#include "ir/optimizer_caller.h" #include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h" #include "ir/pattern_matcher.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h" #include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {

@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common {
}; };
void SetUp() { void SetUp() {
elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd); elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R); elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P); idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q); Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);

Loading…
Cancel
Save