From 6cdc86383bae5cc739d3a1af0dd21595554f6cd2 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Thu, 30 Jul 2020 11:49:33 +0800 Subject: [PATCH] add anf pass --- mindspore/lite/src/gllo/common/utils.cc | 20 ++- mindspore/lite/src/gllo/common/utils.h | 3 + .../src/gllo/fusion/conv_activation_fusion.cc | 64 +++++++++ .../src/gllo/fusion/conv_activation_fusion.h | 38 ++++++ .../src/gllo/fusion/conv_biasadd_fusion.cc | 4 +- .../lite/src/gllo/fusion/conv_scale_fusion.cc | 126 ++++++++++++++++++ .../lite/src/gllo/fusion/conv_scale_fusion.h | 40 ++++++ mindspore/lite/test/CMakeLists.txt | 2 + mindspore/lite/tools/converter/CMakeLists.txt | 2 + 9 files changed, 295 insertions(+), 4 deletions(-) create mode 100644 mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc create mode 100644 mindspore/lite/src/gllo/fusion/conv_activation_fusion.h create mode 100644 mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc create mode 100644 mindspore/lite/src/gllo/fusion/conv_scale_fusion.h diff --git a/mindspore/lite/src/gllo/common/utils.cc b/mindspore/lite/src/gllo/common/utils.cc index 69d3e58a25..9433a04918 100644 --- a/mindspore/lite/src/gllo/common/utils.cc +++ b/mindspore/lite/src/gllo/common/utils.cc @@ -16,9 +16,10 @@ #include #include #include "src/gllo/common/utils.h" -#include "mindspore/lite/src/ir/primitive_t_value.h" +#include "src/ir/primitive_t_value.h" #include "frontend/operator/ops.h" +using PrimitiveTValuePtr = std::shared_ptr; namespace mindspore { namespace opt { @@ -74,7 +75,11 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) { } } } - + if (a.m_ptr->isa()) { + auto a_value_node_ptr = a.m_ptr->cast(); + auto b_value_node_ptr = b.m_ptr->cast(); + return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type; + } return a == b; } @@ -203,5 +208,16 @@ void CheckInputSize(const CNodePtr &node, const int size) { } } +schema::PrimitiveType GetCNodeType(const CNodePtr &node) { + auto value_primitive = node->input(0); + auto value_node = value_primitive->cast(); + MS_ASSERT(value_node != nullptr); + auto value = value_node->value(); + MS_ASSERT(value != nullptr); + auto primitive = value->cast(); + MS_ASSERT(primitive != nullptr); + return primitive->GetPrimitiveT()->value.type; +} + } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/src/gllo/common/utils.h b/mindspore/lite/src/gllo/common/utils.h index ffd57de618..3147bd329f 100644 --- a/mindspore/lite/src/gllo/common/utils.h +++ b/mindspore/lite/src/gllo/common/utils.h @@ -21,6 +21,7 @@ #include "ir/func_graph.h" #include "src/common/utils.h" #include "src/gllo/common/pattern_engine.h" +#include "schema/inner/model_generated.h" namespace mindspore { namespace opt { @@ -42,6 +43,8 @@ void CheckIfVarIsNull(const VarPtr &var); void CheckInputSize(const CNodePtr &node, const int size); +schema::PrimitiveType GetCNodeType(const CNodePtr &node); + } // namespace opt } // namespace mindspore #endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_ diff --git a/mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc new file mode 100644 index 0000000000..46d1273f99 --- /dev/null +++ b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/gllo/fusion/conv_activation_fusion.h" +#include +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "mindspore/ccsrc/utils/utils.h" +#include "src/gllo/common/utils.h" + +namespace mindspore { +namespace opt { +const BaseRef ConvActivationFusion::DefinePattern() const { + VarPtr X = std::make_shared(); + // conv2d inputs may be 2 or 3 inputs,match move to process + auto prim = new schema::PrimitiveT(); + prim->value.type = primitive_type; + auto prim_value = std::make_shared(prim); + + return VectorRef({prim_value, X}); +} + +const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type]; + CheckIfFuncGraphIsNull(func_graph); + + CheckIfAnfNodeIsNull(node); + auto act_node = node->cast(); + CheckIfCNodeIsNull(act_node); + CheckInputSize(act_node, 2); + + auto act_primitive = GetValueNode>(act_node->input(0)); + if (act_primitive->GetPrimitiveT()->value.AsActivation()->type != activation_type) { + return node; + } + AnfNodePtr pre_node = act_node->input(1); + CheckIfAnfNodeIsNull(pre_node); + if (pre_node != nullptr && pre_node->isa()) { + auto conv_node = pre_node->cast(); + auto node_type = GetCNodeType(conv_node); + if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) { + auto primitiveT_value = GetValueNode>(conv_node->input(0)); + primitiveT_value->GetPrimitiveT()->value.AsConv2D()->activationType = activation_type; + return pre_node; + } + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/src/gllo/fusion/conv_activation_fusion.h b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.h new file mode 100644 index 0000000000..a3e605db44 --- /dev/null +++ b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ + +#include "src/gllo/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvActivationFusion : public PatternProcessPass { + public: + explicit ConvActivationFusion(bool multigraph = true, + schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU, + schema::ActivationType activation = schema::ActivationType_LEAKY_RELU) : primitive_type( + primitive), activation_type(activation), PatternProcessPass("conv_activation_fusion", multigraph) {} + ~ConvActivationFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + schema::PrimitiveType primitive_type; + schema::ActivationType activation_type; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ diff --git a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc index 0b20a2baf8..ca03d96367 100644 --- a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc @@ -15,8 +15,8 @@ */ #include "src/gllo/fusion/conv_biasadd_fusion.h" #include -#include "mindspore/lite/schema/inner/model_generated.h" -#include "mindspore/lite/src/ir/primitive_t_value.h" +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" #include "mindspore/ccsrc/utils/utils.h" #include "src/gllo/common/utils.h" diff --git a/mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc new file mode 100644 index 0000000000..6d6d81d8b0 --- /dev/null +++ b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/gllo/fusion/conv_scale_fusion.h" +#include +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "src/param_value_lite.h" +#include "mindspore/ccsrc/utils/utils.h" +#include "src/gllo/common/utils.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace opt { +const BaseRef ConvScaleFusion::DefinePattern() const { + VarPtr X = std::make_shared(); + // conv2d inputs may be 2 or 3 inputs,match move to process + auto prim = new schema::PrimitiveT(); + prim->value.type = schema::PrimitiveType_Scale; + auto prim_value = std::make_shared(prim); + + return VectorRef({prim_value, X}); +} + +const AnfNodePtr ConvScaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "conv activation pass process"; + CheckIfFuncGraphIsNull(func_graph); + + CheckIfAnfNodeIsNull(node); + auto scale_node = node->cast(); + CheckIfCNodeIsNull(scale_node); + CheckInputSize(scale_node, 2); + + AnfNodePtr pre_node = scale_node->input(1); + CheckIfAnfNodeIsNull(pre_node); + if (pre_node != nullptr && pre_node->isa()) { + auto conv_node = pre_node->cast(); + auto node_type = GetCNodeType(conv_node); + if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) { + return DoFusion(conv_node, scale_node); + } + } + return node; +} +const AnfNodePtr ConvScaleFusion::DoFusion(const CNodePtr &conv_node, const CNodePtr &scale_node) const { + if (scale_node->inputs().size() == 3) { + GetTransParam(scale_node->input(2), nullptr); + } else if (scale_node->inputs().size() == 4) { + // todo add bias fusion zhengjun10 + GetTransParam(scale_node->input(2), scale_node->input(3)); + } else { + MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString(); + return nullptr; + } + + AnfNodePtr conv_weight_node; + if (conv_node->inputs().size() == 3) { + conv_weight_node = conv_node->input(2); + } else { + MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString(); + return nullptr; + } + auto conv_weight_param = conv_weight_node->cast()->default_param(); + auto weight_value = std::dynamic_pointer_cast(conv_weight_param); + auto old_conv_weight = reinterpret_cast(weight_value->tensor_addr()); + + auto new_conv_weight = new(std::nothrow) float[weight_value->tensor_shape_size()]; + CalNewWeightTensor(old_conv_weight, new_conv_weight, weight_value->tensor_shape_size()); + weight_value->set_tensor_addr(new_conv_weight); + return conv_node; +} + +const lite::STATUS ConvScaleFusion::GetTransParam(const AnfNodePtr &scale_weight_node, + const AnfNodePtr &scale_bias_node) const { + if (!scale_weight_node->isa()) { + MS_LOG(EXCEPTION) << "scale weight node not paramter node"; + } + if (scale_bias_node != nullptr && !scale_bias_node->isa()) { + MS_LOG(EXCEPTION) << "scale bias node not paramter node"; + } + auto scale_weight_param = scale_weight_node->cast()->default_param(); + auto weight_value = std::dynamic_pointer_cast(scale_weight_param); + auto weight_data = reinterpret_cast(weight_value->tensor_addr()); + + if (0 != memcpy_s(trans_scale, kernel_nums * sizeof(float), weight_data, kernel_nums * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s transScale failed"; + return lite::RET_ERROR; + } + return lite::RET_OK; +} + +const lite::STATUS ConvScaleFusion::CalNewWeightTensor(const float *oldWeightTensor, float *newWeightTensor, + const size_t tensor_shape_size) const { + MS_ASSERT(oldWeightTensor != nullptr); + if (0 != memset_s(newWeightTensor, tensor_shape_size * sizeof(float), 0, tensor_shape_size * sizeof(float))) { + MS_LOG(ERROR) << "memset newWeightData failed"; + return lite::RET_ERROR; + } + if (kernel_nums == 0) { + MS_LOG(ERROR) << "kernel nums is 0"; + return lite::RET_ERROR; + } + auto kernel_size = tensor_shape_size / kernel_nums; + for (size_t i = 0; i < kernel_nums; i++) { + for (size_t j = 0; j < kernel_size; j++) { + newWeightTensor[i * kernel_size + j] = oldWeightTensor[i * kernel_size + j] * trans_scale[i]; + } + } + return lite::RET_OK; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/src/gllo/fusion/conv_scale_fusion.h b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.h new file mode 100644 index 0000000000..8012fc2a03 --- /dev/null +++ b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ + +#include "src/gllo/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvScaleFusion : public PatternProcessPass { + public: + explicit ConvScaleFusion(bool multigraph = true) : PatternProcessPass("conv_scale_fusion", multigraph) {} + ~ConvScaleFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + const AnfNodePtr DoFusion(const CNodePtr &, const CNodePtr &) const; + const lite::STATUS GetTransParam(const AnfNodePtr &, const AnfNodePtr &) const; + const lite::STATUS CalNewWeightTensor(const float *, float *, const size_t) const; + private: + float *trans_scale = nullptr; + int kernel_nums = 0; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ + diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 7311bfec1a..f8b197e248 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -193,6 +193,8 @@ if(BUILD_CONVERTER) ${LITE_DIR}/src/gllo/common/visit.cc ${LITE_DIR}/src/gllo/common/utils.cc ${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc + ${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc + ${LITE_DIR}/src/gllo/fusion/conv_scale_fusion.cc ) endif() ### train diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 856e730b87..cbca13f231 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -78,6 +78,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/visit.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_scale_fusion.cc ) add_subdirectory(parser/caffe)