|
|
@ -22,158 +22,14 @@
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
|
|
#include "ir/optimizer_caller.h"
|
|
|
|
#include "ir/optimizer_caller.h"
|
|
|
|
|
|
|
|
#include "ir/pattern_matcher.h"
|
|
|
|
#include "ir/visitor.h"
|
|
|
|
#include "ir/visitor.h"
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
|
|
|
#include "optimizer/irpass.h"
|
|
|
|
#include "optimizer/irpass.h"
|
|
|
|
#include "optimizer/irpass/prim_eliminate.h"
|
|
|
|
#include "optimizer/irpass/prim_eliminate.h"
|
|
|
|
#include "optimizer/optimizer.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace opt {
|
|
|
|
namespace opt {
|
|
|
|
namespace irpass {
|
|
|
|
namespace irpass {
|
|
|
|
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
|
|
|
|
|
|
|
|
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
|
|
|
|
|
|
|
|
class MultiplyByZeroOrOne : public AnfVisitor {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {}
|
|
|
|
|
|
|
|
~MultiplyByZeroOrOne() override = default;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override;
|
|
|
|
|
|
|
|
void Reset();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
bool is_zero_{false}, is_one_{false};
|
|
|
|
|
|
|
|
ValuePtr zero_, one_;
|
|
|
|
|
|
|
|
AnfNodePtr x_{nullptr};
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Support class used for checking if all values of a Tensor are equal `check_value_`
|
|
|
|
|
|
|
|
// Supported data types: double, float/float32, int/int32
|
|
|
|
|
|
|
|
class CheckTensorConstant {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
|
|
|
|
|
|
|
|
~CheckTensorConstant() = default;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool IsTensorConstant(const ValuePtr &value);
|
|
|
|
|
|
|
|
bool IsTensorScalarConstant(const ValuePtr &value);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
int check_value_;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorMultiplyBase : public AnfVisitor {
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Make a new tensor (when possible) with the same shape as of `node`
|
|
|
|
|
|
|
|
// If x is nullptr then fill new tensor will "0"
|
|
|
|
|
|
|
|
// If x is a tensor with empty shape then fill new tensor with the single value of x
|
|
|
|
|
|
|
|
// If x is a tensor with same shape as `node` then return x as result
|
|
|
|
|
|
|
|
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr x_{nullptr};
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
|
|
|
|
|
|
|
|
class TensorMultiplyByZero : public TensorMultiplyBase {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
TensorMultiplyByZero() : zero_(MakeValue(0)) {}
|
|
|
|
|
|
|
|
~TensorMultiplyByZero() override = default;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override;
|
|
|
|
|
|
|
|
void Reset();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
bool is_zero_{false};
|
|
|
|
|
|
|
|
ValuePtr zero_;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
|
|
|
|
|
|
|
|
class TensorMultiplyByOne : public TensorMultiplyBase {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
TensorMultiplyByOne() {}
|
|
|
|
|
|
|
|
~TensorMultiplyByOne() override = default;
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override;
|
|
|
|
|
|
|
|
void Reset();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
bool is_one_{false};
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimScalarAdd, X, 0}
|
|
|
|
|
|
|
|
// {prim::kPrimScalarAdd, 0, X}
|
|
|
|
|
|
|
|
class AddByZero : public AnfVisitor {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
AddByZero() : zero_(MakeValue(0)) {}
|
|
|
|
|
|
|
|
~AddByZero() override = default;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
void Reset();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
bool is_zero_{false};
|
|
|
|
|
|
|
|
ValuePtr zero_;
|
|
|
|
|
|
|
|
AnfNodePtr x_{nullptr};
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
|
|
|
|
|
|
|
|
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
|
|
|
|
|
|
|
|
class TensorAddByZero : public AnfVisitor {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override;
|
|
|
|
|
|
|
|
void Reset();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
bool is_zero_{false};
|
|
|
|
|
|
|
|
AnfNodePtr x_{nullptr};
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
|
|
|
|
|
|
|
|
class OptUpdateZeroTensor : public AnfVisitor {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
|
|
|
|
|
|
|
|
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
|
|
|
|
|
|
|
|
class ConstantDuplicateMul : public AnfVisitor {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
// Support function to multiply two constant tensors: partially support broadcasting shapes
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
|
|
|
|
|
|
|
|
int out_data_size);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
void Reset();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
AnfNodePtr vnode_;
|
|
|
|
|
|
|
|
AnfNodePtr c_p_node_;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PowerOneEliminate : public AnfVisitor {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// grad = AllReduce(grad) / worker_number
|
|
|
|
// grad = AllReduce(grad) / worker_number
|
|
|
|
// grad = grad + weight * decy
|
|
|
|
// grad = grad + weight * decy
|
|
|
|
// ->
|
|
|
|
// ->
|
|
|
@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
|
|
|
|
|
|
|
|
|
|
|
|
class ArithmeticSimplify : public OptimizerCaller {
|
|
|
|
class ArithmeticSimplify : public OptimizerCaller {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
ArithmeticSimplify()
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
: multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
|
|
|
|
|
|
|
|
tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
|
|
|
|
|
|
|
|
add_by_zero_(std::make_shared<AddByZero>()),
|
|
|
|
|
|
|
|
tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
|
|
|
|
|
|
|
|
identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
|
|
|
|
|
|
|
|
opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
|
|
|
|
|
|
|
|
constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
|
|
|
|
|
|
|
|
power_one_(std::make_shared<PowerOneEliminate>()) {
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(multiply_by_zero_or_one_);
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(tensor_multiply_by_one_);
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(add_by_zero_);
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(tensor_add_by_zero_);
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(identity_);
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(opt_update_zero_tensor_);
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(constant_duplicate_mul_);
|
|
|
|
|
|
|
|
eliminaters_.emplace_back(power_one_);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
~ArithmeticSimplify() = default;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
OptimizerCallerPtr multiply_by_zero_or_one_;
|
|
|
|
|
|
|
|
OptimizerCallerPtr tensor_multiply_by_one_;
|
|
|
|
|
|
|
|
OptimizerCallerPtr add_by_zero_;
|
|
|
|
|
|
|
|
OptimizerCallerPtr tensor_add_by_zero_;
|
|
|
|
|
|
|
|
OptimizerCallerPtr identity_;
|
|
|
|
|
|
|
|
OptimizerCallerPtr opt_update_zero_tensor_;
|
|
|
|
|
|
|
|
OptimizerCallerPtr constant_duplicate_mul_;
|
|
|
|
|
|
|
|
OptimizerCallerPtr power_one_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<OptimizerCallerPtr> eliminaters_{};
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// Arithmetic Simplifications should be done after step_parallel.
|
|
|
|
// Arithmetic Simplifications should be done after step_parallel.
|
|
|
@ -242,17 +66,9 @@ class ArithmeticSimplify : public OptimizerCaller {
|
|
|
|
// ArithmeticSimplify and deferred until step_parallel.
|
|
|
|
// ArithmeticSimplify and deferred until step_parallel.
|
|
|
|
class ArithmeticSimplify2 : public OptimizerCaller {
|
|
|
|
class ArithmeticSimplify2 : public OptimizerCaller {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
|
|
|
eliminaters_.emplace_back(tensor_multiply_by_zero_);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
~ArithmeticSimplify2() = default;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
OptimizerCallerPtr tensor_multiply_by_zero_;
|
|
|
|
|
|
|
|
std::vector<OptimizerCallerPtr> eliminaters_{};
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace irpass
|
|
|
|
} // namespace irpass
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace mindspore
|
|
|
|
} // namespace mindspore
|
|
|
|