From 6bd66e79531ce560b50f3a80fbd2a32226054b50 Mon Sep 17 00:00:00 2001 From: Hoai Linh Tran Date: Thu, 2 Jul 2020 15:29:12 -0400 Subject: [PATCH] Fix memcpy calls. Add ut tests for arithmetic_simplify. Split long arithmetic_simplify.h to arithmetic_simplify.cc Code checking --- .../optimizer/irpass/arithmetic_simplify.cc | 680 ++++++++++++++++++ .../optimizer/irpass/arithmetic_simplify.h | 639 ++-------------- .../python/pynative_mode/test_framstruct.py | 116 +++ 3 files changed, 841 insertions(+), 594 deletions(-) create mode 100644 mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc new file mode 100644 index 0000000000..b111a6b67a --- /dev/null +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc @@ -0,0 +1,680 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "optimizer/irpass/arithmetic_simplify.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} +// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} +AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarMul)(node); + + if (is_zero_) { + return NewValueNode(zero_); + } + if (is_one_) { + return x_; + } + return nullptr; +} + +void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { + if (is_one_ || node->isa()) { + x_ = node; + return; + } + + AnfVisitor::Visit(node); + if (!is_one_) { + x_ = node; + } +} + +void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (*value == *zero_) { + is_zero_ = true; + } else if (*value == *one_) { + is_one_ = true; + } +} + +void MultiplyByZeroOrOne::Reset() { + x_ = nullptr; + is_one_ = false; + is_zero_ = false; +} + +// Support class used for checking if all values of a Tensor are equal `check_value_` +// Supported data types: double, float/float32, int/int32 +bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > FLT_EPSILON) { + return false; + } + } + return true; + } else if (tensor_type == TypeId::kNumberTypeFloat64) { + double *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > DBL_EPSILON) { + return false; + } + } + return true; + } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { + int *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] != check_value_) { + return false; + } + } + return true; + } + // input Data Types is not supported + return false; +} + +bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { + return false; + } + return IsTensorConstant(value); +} + +void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { + if (!node->isa()) { + return nullptr; + } + + auto value = node->cast()->value(); + + if (!value->isa()) { + return nullptr; + } + + tensor::TensorPtr tensor_ptr = dyn_cast(value); + return tensor_ptr->data_c(); +} + +// Make a new tensor (when possible) with the same shape as of `node` +// If x is nullptr then fill new tensor will "0" +// If x is a tensor with empty shape then fill new tensor with the single value of x +// If x is a tensor with same shape as `node` then return x as result +AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { + if ((node->abstract() == nullptr) || !node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + if (x == nullptr) { + std::memset(data, 0, mem_size); + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + // x is not nullptr + if (x->isa()) { + if ((x->abstract() == nullptr) || !x->abstract()->isa()) { + return nullptr; + } + auto x_abstract = x->abstract()->cast(); + std::vector x_shape = x_abstract->shape()->shape(); + + if (x_shape != tensor_shape) { + return nullptr; + } + return x; + } + + if (!x->isa()) { + return nullptr; + } + auto x_value = x->cast()->value(); + if (!x_value->isa()) { + return nullptr; + } + + auto x_tensor_ptr = dyn_cast(x_value); + + if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { + return nullptr; + } + char *source_data = reinterpret_cast(GetPointerToTensorData(x)); + if (x_tensor_ptr->DataSize() == 1) { + for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { + memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); + } + } else { + memcpy(data, source_data, mem_size); + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_zero_) { + if (x_->func_graph() != node->func_graph()) { + return nullptr; + } + return NewTensorFilledWithData(node); + } + return nullptr; +} + +void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { + if (is_zero_) { + x_ = node; + return; + } + + if (IsParam(node)) { + x_ = node; + return; + } + + if (IsCNode(node)) { + CNodePtr cnode = node->cast(); + if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { + is_zero_ = true; + return; + } + x_ = node; + return; + } + auto value = node->cast()->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_one_) { + return NewTensorFilledWithData(node, x_); + } + return nullptr; +} + +void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { + if (is_one_) { + x_ = node; + return; + } + + if (IsParam(node) || IsCNode(node)) { + x_ = node; + return; + } + + auto value = node->cast()->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByOne::Reset() { + x_ = nullptr; + is_one_ = false; +} + +// {prim::kPrimScalarAdd, X, 0} +// {prim::kPrimScalarAdd, 0, X} +AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarAdd)(node); + + if (is_zero_) { + return x_; + } + return nullptr; +} + +void AddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && + ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void AddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, +// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} +AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimTensorAdd)(node); + + if (is_zero_) { + return x_; + } + return nullptr; +} + +void TensorAddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void TensorAddByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } +} + +void TensorAddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} +AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { + return nullptr; + } + + // {PrimMomentum, {...}, Y, Z, Xs} + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { + return nullptr; + } + auto y = inputs[2]; + auto z = inputs[3]; + + // {kPrimZerosLike, X} + if (inputs[1]->cast()->size() != 2) { + return nullptr; + } + + // {prim::kPrimMakeTuple, Z, Y} + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); +} + +// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> +// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} +// Support function to multiply two constant tensors: partially support broadcasting shapes +template +void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, + void **out_data, int out_data_size) { + T *data_1 = reinterpret_cast(in_data_1); + T *data_2 = reinterpret_cast(in_data_2); + T *data_out = new T[out_data_size]; + + if (in_data_1_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[i]; + } + } + if (in_data_2_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[i]; + } + } + *out_data = reinterpret_cast(data_out); + return; +} + +AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, + const AnfNodePtr &node_3) { + if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || + (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { + return nullptr; + } + + auto value_1 = GetValueNode(vnode_1); + auto value_2 = GetValueNode(vnode_2); + + if (!value_1->isa() || !value_2->isa()) { + return nullptr; + } + + auto tensor_ptr_1 = dyn_cast(value_1); + auto tensor_ptr_2 = dyn_cast(value_2); + + auto tensor_1_abstract = vnode_1->abstract()->cast(); + auto tensor_2_abstract = vnode_1->abstract()->cast(); + auto tensor_3_abstract = node_3->abstract()->cast(); + + TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); + TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); + TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); + + if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || + (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { + return nullptr; + } + + std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); + + int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + + if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { + return nullptr; + } + if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { + return nullptr; + } + + void *data_out; + + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), + &data_out, data_out_size); + } else { + if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + // Un-support data types + return nullptr; + } + } + } + + auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); + size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + memcpy(data, data_out, mem_size); + + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + // {prim::kPrimMul, Tensor1, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + + if (!IsCNode(c_p_node_)) { + return nullptr; + } + + auto tensor1 = vnode_; + auto mul = c_p_node_->cast(); + + Reset(); + // {prim::kPrimMul, Tensor2, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + auto tensor2 = vnode_; + auto c_p_node = c_p_node_; + + auto PrimMul = GetValueNode(mul->input(0)); + auto fg = node->func_graph(); + + auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); + if (new_mul_tensor == nullptr) { + auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); + return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); + } + return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); +} + +void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { + if (IsValueNode(node)) { + vnode_ = node; + } + + if (IsCNode(node) || IsParam(node)) { + c_p_node_ = node; + } +} + +void ConstantDuplicateMul::Reset() { + vnode_ = nullptr; + c_p_node_ = nullptr; +} + +AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (!IsValueNode(inputs[2])) { + return nullptr; + } + auto scalar = GetValueNode(inputs[2]); + if (scalar->isa() && GetValue(scalar) == 1.0) { + return inputs[1]; + } else if (scalar->isa() && GetValue(scalar) == 1) { + return inputs[1]; + } + return nullptr; +} + +// grad = AllReduce(grad) / worker_number +// grad = grad + weight * decy +// -> +// grad = grad + weight * decy +// grad = AllReduce(grad) / worker_number +// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> +// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} +AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + // {prim::kPrimAddN, Zs} + if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { + return nullptr; + } + auto addn = node->cast(); + if (addn->size() != 2) { + return nullptr; + } + AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); + if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { + return nullptr; + } + auto addn_maketuple = addn->input(1); + + auto fg = all_reduce_fg_; + // addn inputs cross the graph, make the inputs same as allreduce node. + if (z_->isa() && fg != z_->func_graph()) { + auto cnode_z = z_->cast(); + z_ = NewCNode(cnode_z->inputs(), fg); + } + + auto addn_op_node = addn->input(0); + auto make_tuple_op_node = addn->input(1)->cast()->input(0); + + AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); + AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); + AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); + AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); + ProcessDependEdge(fg, addn_maketuple, all_reduce); + return mul; +} + +void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, + const AnfNodePtr &new_node) { + // If has dynamic loss scale. + auto &users_map = fg->manager()->node_users(); + auto it = users_map.find(mul_cnode_); + if (it != users_map.end()) { + auto users = it->second; + for (auto &user_pair : users) { + auto node = user_pair.first; + if (node != addn_maketuple) { + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + fg->manager()->SetEdge(node, user_pair.second, new_node); + } + } + } + } +} + +void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { + if (level_ == 0) { + level_ = 1; + is_reduce_match_ = false; + // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} + AnfVisitor::Match(prim::kPrimMul)(node); + level_ = 0; + if (is_reduce_match_) { + mul_ = node->cast()->input(0); + mul_cnode_ = node->cast(); + y_ = tmp_; + } else { + z_ = node; + } + } + + if (level_ == 1) { + // {prim::kPrimAllReduce, X} + if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { + auto cnode = node->cast(); + if (cnode->size() > 1) { + all_reduce_ = cnode->input(0); + x_ = cnode->input(1); + is_reduce_match_ = true; + all_reduce_fg_ = cnode->func_graph(); + } + } else { + tmp_ = node; + } + } +} + +void AdjustAllReduceMulAdd::Reset() { + level_ = 0; + is_reduce_match_ = false; + x_ = nullptr; + y_ = nullptr; + z_ = nullptr; + tmp_ = nullptr; + all_reduce_fg_ = nullptr; +} + +AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} + +AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 940cdde6d2..f4bdb0d655 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -38,45 +38,11 @@ class MultiplyByZeroOrOne : public AnfVisitor { MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} ~MultiplyByZeroOrOne() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimScalarMul)(node); - - if (is_zero_) { - return NewValueNode(zero_); - } - if (is_one_) { - return x_; - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (is_one_ || node->isa()) { - x_ = node; - return; - } - - AnfVisitor::Visit(node); - if (!is_one_) { - x_ = node; - } - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override { - auto value = vnode->value(); - if (*value == *zero_) { - is_zero_ = true; - } else if (*value == *one_) { - is_one_ = true; - } - } - - void Reset() { - x_ = nullptr; - is_one_ = false; - is_zero_ = false; - } + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); private: bool is_zero_{false}, is_one_{false}; @@ -90,51 +56,9 @@ class CheckTensorConstant { public: explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} ~CheckTensorConstant() = default; - bool IsTensorConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > FLT_EPSILON) { - return false; - } - } - return true; - } else if (tensor_type == TypeId::kNumberTypeFloat64) { - double *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > DBL_EPSILON) { - return false; - } - } - return true; - } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { - int *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (data2[i] != check_value_) { - return false; - } - } - return true; - } - // Un-support Data Types - return false; - } - bool IsTensorScalarConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { - return false; - } - return IsTensorConstant(value); - } + bool IsTensorConstant(const ValuePtr &value); + bool IsTensorScalarConstant(const ValuePtr &value); private: int check_value_; @@ -142,83 +66,13 @@ class CheckTensorConstant { class TensorMultiplyBase : public AnfVisitor { protected: - void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) { - if (!node->isa()) { - return nullptr; - } - - auto value = node->cast()->value(); - - if (!value->isa()) { - return nullptr; - } - - tensor::TensorPtr tensor_ptr = dyn_cast(value); - return tensor_ptr->data_c(); - } + void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); // Make a new tensor (when possible) with the same shape as of `node` // If x is nullptr then fill new tensor will "0" // If x is a tensor with empty shape then fill new tensor with the single value of x // If x is a tensor with same shape as `node` then return x as result - AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) { - if ((node->abstract() == nullptr) || !node->abstract()->isa()) { - return nullptr; - } - - auto tensor_abstract = node->abstract()->cast(); - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - if (x == nullptr) { - std::memset(data, 0, mem_size); - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - // x is not nullptr - if (x->isa()) { - if ((x->abstract() == nullptr) || !x->abstract()->isa()) { - return nullptr; - } - auto x_abstract = x->abstract()->cast(); - std::vector x_shape = x_abstract->shape()->shape(); - - if (x_shape != tensor_shape) { - return nullptr; - } - return x; - } - - if (!x->isa()) { - return nullptr; - } - auto x_value = x->cast()->value(); - if (!x_value->isa()) { - return nullptr; - } - - auto x_tensor_ptr = dyn_cast(x_value); - - if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { - return nullptr; - } - char *source_data = reinterpret_cast(GetPointerToTensorData(x)); - if (x_tensor_ptr->DataSize() == 1) { - for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { - memcpy(source_data, data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr)); - } - } else { - memcpy(source_data, data, mem_size); - } - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); AnfNodePtr x_{nullptr}; }; @@ -228,59 +82,12 @@ class TensorMultiplyByZero : public TensorMultiplyBase { public: TensorMultiplyByZero() : zero_(MakeValue(0)) {} ~TensorMultiplyByZero() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_zero_) { - if (x_->func_graph() != node->func_graph()) { - return nullptr; - } - return NewTensorFilledWithData(node); - } - return nullptr; - } - void Visit(const AnfNodePtr &node) override { - if (is_zero_) { - x_ = node; - return; - } - - if (IsParam(node)) { - x_ = node; - return; - } - - if (IsCNode(node)) { - CNodePtr cnode = node->cast(); - if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { - is_zero_ = true; - return; - } - x_ = node; - return; - } - auto value = node->cast()->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = node; - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = vnode; - } - void Reset() { - x_ = nullptr; - is_zero_ = false; - } + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); private: bool is_zero_{false}; @@ -292,47 +99,11 @@ class TensorMultiplyByOne : public TensorMultiplyBase { public: TensorMultiplyByOne() {} ~TensorMultiplyByOne() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_one_) { - return NewTensorFilledWithData(node, x_); - } - return nullptr; - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const AnfNodePtr &node) override { - if (is_one_) { - x_ = node; - return; - } - - if (IsParam(node) || IsCNode(node)) { - x_ = node; - return; - } - - auto value = node->cast()->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = node; - } - - void Visit(const ValueNodePtr &vnode) override { - auto value = vnode->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = vnode; - } - void Reset() { - x_ = nullptr; - is_one_ = false; - } + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); private: bool is_one_{false}; @@ -345,30 +116,10 @@ class AddByZero : public AnfVisitor { AddByZero() : zero_(MakeValue(0)) {} ~AddByZero() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimScalarAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (node->isa() && - ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { - is_zero_ = true; - return; - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - x_ = node; - } - - void Reset() { - x_ = nullptr; - is_zero_ = false; - } + void Visit(const AnfNodePtr &node) override; + void Reset(); private: bool is_zero_{false}; @@ -380,37 +131,11 @@ class AddByZero : public AnfVisitor { // {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} class TensorAddByZero : public AnfVisitor { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTensorAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const AnfNodePtr &node) override { - if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { - is_zero_ = true; - return; - } - - x_ = node; - } - - void Visit(const ValueNodePtr &vnode) override { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - } - - void Reset() { - x_ = nullptr; - is_zero_ = false; - } + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); private: bool is_zero_{false}; @@ -420,27 +145,7 @@ class TensorAddByZero : public AnfVisitor { // {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} class OptUpdateZeroTensor : public AnfVisitor { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { - return nullptr; - } - - // {PrimMomentum, {...}, Y, Z, Xs} - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { - return nullptr; - } - auto y = inputs[2]; - auto z = inputs[3]; - - // {kPrimZerosLike, X} - if (inputs[1]->cast()->size() != 2) { - return nullptr; - } - - // {prim::kPrimMakeTuple, Z, Y} - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; }; // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> @@ -450,156 +155,14 @@ class ConstantDuplicateMul : public AnfVisitor { // Support function to multiply two constant tensors: partially support broadcasting shapes template void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, - int out_data_size) { - T *data_1 = reinterpret_cast(in_data_1); - T *data_2 = reinterpret_cast(in_data_2); - T *data_out = new T[out_data_size]; - - if (in_data_1_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[i]; - } - } - if (in_data_2_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[i]; - } - } - *out_data = reinterpret_cast(data_out); - return; - } - - AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) { - if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || - (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { - return nullptr; - } - - auto value_1 = GetValueNode(vnode_1); - auto value_2 = GetValueNode(vnode_2); - - if (!value_1->isa() || !value_2->isa()) { - return nullptr; - } - - auto tensor_ptr_1 = dyn_cast(value_1); - auto tensor_ptr_2 = dyn_cast(value_2); - - auto tensor_1_abstract = vnode_1->abstract()->cast(); - auto tensor_2_abstract = vnode_1->abstract()->cast(); - auto tensor_3_abstract = node_3->abstract()->cast(); - - TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); - TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); - TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); - - if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || - (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { - return nullptr; - } - - std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); - - int data_out_size = 1; - for (auto it : tensor_out_shape) { - data_out_size *= it; - } - if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { - return nullptr; - } - if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { - return nullptr; - } - - void *data_out; - - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - // Un-support data types - return nullptr; - } - } - } - - auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); - size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - memcpy(data, data_out, mem_size); - - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } + int out_data_size); - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - // {prim::kPrimMul, Tensor1, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); - if (vnode_ == nullptr || c_p_node_ == nullptr) { - return nullptr; - } - - if (!IsCNode(c_p_node_)) { - return nullptr; - } - - auto tensor1 = vnode_; - auto mul = c_p_node_->cast(); - - Reset(); - // {prim::kPrimMul, Tensor2, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); - if (vnode_ == nullptr || c_p_node_ == nullptr) { - return nullptr; - } - auto tensor2 = vnode_; - auto c_p_node = c_p_node_; - - auto PrimMul = GetValueNode(mul->input(0)); - auto fg = node->func_graph(); - - auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); - if (new_mul_tensor == nullptr) { - auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); - return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); - } - return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); - } + AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); - void Visit(const AnfNodePtr &node) override { - if (IsValueNode(node)) { - vnode_ = node; - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - if (IsCNode(node) || IsParam(node)) { - c_p_node_ = node; - } - } - - void Reset() { - vnode_ = nullptr; - c_p_node_ = nullptr; - } + void Visit(const AnfNodePtr &node) override; + void Reset(); private: AnfNodePtr vnode_; @@ -608,23 +171,7 @@ class ConstantDuplicateMul : public AnfVisitor { class PowerOneEliminate : public AnfVisitor { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (!IsValueNode(inputs[2])) { - return nullptr; - } - auto scalar = GetValueNode(inputs[2]); - if (scalar->isa() && GetValue(scalar) == 1.0) { - return inputs[1]; - } else if (scalar->isa() && GetValue(scalar) == 1) { - return inputs[1]; - } - return nullptr; - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; }; // grad = AllReduce(grad) / worker_number @@ -637,96 +184,11 @@ class PowerOneEliminate : public AnfVisitor { // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} class AdjustAllReduceMulAdd : public AnfVisitor { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - // {prim::kPrimAddN, Zs} - if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { - return nullptr; - } - auto addn = node->cast(); - if (addn->size() != 2) { - return nullptr; - } - AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); - if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { - return nullptr; - } - auto addn_maketuple = addn->input(1); - - auto fg = all_reduce_fg_; - // addn inputs cross the graph, make the inputs same as allreduce node. - if (z_->isa() && fg != z_->func_graph()) { - auto cnode_z = z_->cast(); - z_ = NewCNode(cnode_z->inputs(), fg); - } - - auto addn_op_node = addn->input(0); - auto make_tuple_op_node = addn->input(1)->cast()->input(0); - - AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); - AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); - AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); - AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); - ProcessDependEdge(fg, addn_maketuple, all_reduce); - return mul; - } - void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) { - // If has dynamic loss scale. - auto &users_map = fg->manager()->node_users(); - auto it = users_map.find(mul_cnode_); - if (it != users_map.end()) { - auto users = it->second; - for (auto &user_pair : users) { - auto node = user_pair.first; - if (node != addn_maketuple) { - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - fg->manager()->SetEdge(node, user_pair.second, new_node); - } - } - } - } - } - void Visit(const AnfNodePtr &node) override { - if (level_ == 0) { - level_ = 1; - is_reduce_match_ = false; - // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} - AnfVisitor::Match(prim::kPrimMul)(node); - level_ = 0; - if (is_reduce_match_) { - mul_ = node->cast()->input(0); - mul_cnode_ = node->cast(); - y_ = tmp_; - } else { - z_ = node; - } - } - - if (level_ == 1) { - // {prim::kPrimAllReduce, X} - if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { - auto cnode = node->cast(); - if (cnode->size() > 1) { - all_reduce_ = cnode->input(0); - x_ = cnode->input(1); - is_reduce_match_ = true; - all_reduce_fg_ = cnode->func_graph(); - } - } else { - tmp_ = node; - } - } - } + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Reset() { - level_ = 0; - is_reduce_match_ = false; - x_ = nullptr; - y_ = nullptr; - z_ = nullptr; - tmp_ = nullptr; - all_reduce_fg_ = nullptr; - } + void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); + void Visit(const AnfNodePtr &node) override; + void Reset(); private: int level_{0}; @@ -758,20 +220,18 @@ class ArithmeticSimplify : public OptimizerCaller { } ~ArithmeticSimplify() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; private: - OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_, - opt_update_zero_tensor_, constant_duplicate_mul_, power_one_; + OptimizerCallerPtr multiply_by_zero_or_one_; + OptimizerCallerPtr tensor_multiply_by_one_; + OptimizerCallerPtr add_by_zero_; + OptimizerCallerPtr tensor_add_by_zero_; + OptimizerCallerPtr identity_; + OptimizerCallerPtr opt_update_zero_tensor_; + OptimizerCallerPtr constant_duplicate_mul_; + OptimizerCallerPtr power_one_; + std::vector eliminaters_{}; }; @@ -787,16 +247,7 @@ class ArithmeticSimplify2 : public OptimizerCaller { } ~ArithmeticSimplify2() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; private: OptimizerCallerPtr tensor_multiply_by_zero_; diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index 99d792671a..39a4c97ab9 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -549,6 +549,122 @@ def test_zeros(): assert res == Tensor(np.zeros([2, 3]).astype(np.int32)) +@ms_function +def arithmetic_simplify_01(x, y): + """ arithmetic_simplify_01 """ + return C.zeros_like(x) * y + + +def test_arithmetic_simplify_01(): + """ test_arithmetic_simplify_01 """ + x = Tensor(np.ones([2, 3]).astype(np.int32)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + res = arithmetic_simplify_01(x, y) + expect = np.zeros([2, 3]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + +@ms_function +def arithmetic_simplify_02(x, y): + """ arithmetic_simplify_02 """ + return C.ones_like(x) * y + + +def test_arithmetic_simplify_02(): + """ test_arithmetic_simplify_02 """ + x = Tensor(np.ones([2, 3]).astype(np.int32)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + res = arithmetic_simplify_02(x, y) + expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + +@ms_function +def arithmetic_simplify_03(x, y): + """ arithmetic_simplify_03 """ + return x * C.ones_like(y) + + +def test_arithmetic_simplify_03(): + """ test_arithmetic_simplify_03 """ + x = Tensor(np.ones([2, 3]).astype(np.int32)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + res = arithmetic_simplify_03(x, y) + expect = np.ones([2, 3]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + +@ms_function +def arithmetic_simplify_04(x): + """ arithmetic_simplify_04 """ + return x + 0 + + +def test_arithmetic_simplify_04(): + """ test_arithmetic_simplify_04 """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + res = arithmetic_simplify_04(x) + expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + +@ms_function +def arithmetic_simplify_05(x): + """ arithmetic_simplify_05 """ + return x * 1 + + +def test_arithmetic_simplify_05(): + """ test_arithmetic_simplify_05 """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + res = arithmetic_simplify_05(x) + expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + +@ms_function +def arithmetic_simplify_06(x): + """ arithmetic_simplify_06 """ + return x * 2 * 5 + + +def test_arithmetic_simplify_06(): + """ test_arithmetic_simplify_06 """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + res = arithmetic_simplify_06(x) + expect = np.array([[10, 20, 30], [40, 50, 60]]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + +@ms_function +def arithmetic_simplify_07(x): + """ arithmetic_simplify_07 """ + return (x + 1) * 2 * 5 + + +def test_arithmetic_simplify_07(): + """ test_arithmetic_simplify_07 """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + res = arithmetic_simplify_07(x) + expect = np.array([[20, 30, 40], [50, 60, 70]]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + +@ms_function +def arithmetic_simplify_08(x, y): + """ arithmetic_simplify_08 """ + return 1 * x * 1 * 1 + 1 * 0 * 1 + 0 + y * 1 + + +def test_arithmetic_simplify_08(): + """ test_arithmetic_simplify_08 """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)) + y = Tensor(np.ones([2, 3]).astype(np.int32)) + res = arithmetic_simplify_08(x, y) + expect = np.array([[2, 3, 4], [5, 6, 7]]).astype(np.int32) + assert np.all(res.asnumpy() == expect) + + def test_ScalarGradChecker(): """ test_ScalarGradChecker """