From f11336a43111904d6632d1e8496eed5805090415 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Thu, 22 Oct 2020 19:33:46 +0800 Subject: [PATCH] transformer batchmatmul fusion --- mindspore/lite/src/lite_kernel.h | 1 + .../src/runtime/kernel/arm/base/dequant.h | 16 +- mindspore/lite/test/CMakeLists.txt | 1 + mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../lite/tools/converter/anf_transform.cc | 2 + .../optimizer/fusion/batchmatmul_fusion.cc | 189 ++++++++++++++++++ .../optimizer/fusion/batchmatmul_fusion.h | 34 ++++ 7 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc create mode 100644 mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.h diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 8f81cada97..fb05aadc44 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -31,6 +31,7 @@ #include "include/errorcode.h" static constexpr int kPerTensor = 1; +static constexpr size_t kPerBatch = 3; namespace mindspore::kernel { enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h index a130061034..84265ded59 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h @@ -39,7 +39,21 @@ class DequantUtil { MS_LOG(ERROR) << "Malloc failed."; return nullptr; } - if (input_tensor->GetQuantParams().size() != kPerTensor) { + if (input_tensor->shape().size() == kPerBatch && + input_tensor->GetQuantParams().size() == static_cast(input_tensor->shape()[0])) { // per batch matmul + auto per_batch_size = input_tensor->shape()[0]; + auto quant_param = input_tensor->GetQuantParams(); + for (int i = 0; i < per_batch_size; i++) { + auto param = quant_param.at(i); + auto scale = param.scale; + auto zero_point = 
param.zeroPoint; + auto matrix_size = input_tensor->ElementsNum() / per_batch_size; + for (int64_t j = 0; j < matrix_size; j++) { + dequant_datas[i * matrix_size + j] = + static_cast((quant_datas[i * matrix_size + j] - zero_point) * scale); + } + } + } else if (input_tensor->GetQuantParams().size() != kPerTensor) { size_t channels = static_cast(input_tensor->Batch()); if (input_tensor->GetQuantParams().size() != channels) { MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 39f4adf625..9fcea8a0f4 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -180,6 +180,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/layer_norm_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 199bcc3a68..d8610213de 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -40,6 +40,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/constant_folding_fusion.cc ../optimizer/fusion/quant_dtype_cast_fusion.cc ../optimizer/fusion/layer_norm_fusion.cc + ../optimizer/fusion/batchmatmul_fusion.cc ../optimizer/graph/weight_format_transform_pass.cc ../optimizer/graph/weight_format_hardcode_pass.cc ../optimizer/graph/clip_convert_activation_pass.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 
9738bc13b6..e2c352273b 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -26,6 +26,7 @@
 #include "tools/optimizer/fusion/constant_folding_fusion.h"
 #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
 #include "tools/optimizer/fusion/layer_norm_fusion.h"
+#include "tools/optimizer/fusion/batchmatmul_fusion.h"
 #include "tools/optimizer/graph/weight_format_hardcode_pass.h"
 #include "tools/optimizer/graph/weight_format_transform_pass.h"
 #include "tools/optimizer/graph/clip_convert_activation_pass.h"
@@ -59,6 +60,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
     pm->AddPass(std::make_shared());
     pm->AddPass(std::make_shared());
     pm->AddPass(std::make_shared());
+    pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
     pm->AddPass(std::make_shared(true, "conv_relu", schema::PrimitiveType_Activation,
                                  schema::ActivationType_RELU));
     pm->AddPass(std::make_shared(true, "conv_relu6", schema::PrimitiveType_Activation,
diff --git a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc
new file mode 100644
index 0000000000..850d2fc6cd
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc
@@ -0,0 +1,254 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/fusion/batchmatmul_fusion.h"
+#include <memory>
+#include <vector>
+#include "src/ops/primitive_c.h"
+#include "src/param_value_lite.h"
+#include "schema/inner/model_generated.h"
+#include "utils/utils.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "securec/include/securec.h"
+
+namespace mindspore::opt {
+namespace {
+// Returns true when n is a CNode or ValueNode whose primitive type is Stack.
+bool IsStackNode(const BaseRef &n) {
+  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_Stack;
+  }
+  return false;
+}
+
+// Returns true when n is a CNode or ValueNode whose primitive type is FullConnection.
+bool IsFullConnectNode(const BaseRef &n) {
+  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_FullConnection;
+  }
+  return false;
+}
+
+// Returns input(1) of node when node is a CNode that has such an input, otherwise nullptr.
+AnfNodePtr GetFirstDataInput(const AnfNodePtr &node) {
+  if (node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode->inputs().size() < 2) {
+    return nullptr;
+  }
+  return cnode->input(1);
+}
+
+// Returns the tensor address held by the ParamValueLite of input input_index of node.
+// Logs an error and returns nullptr when node is not a CNode, input_index is out of range, the input is not a
+// Parameter, or its default value is not a ParamValueLite.
+void *GetInputAddr(const AnfNodePtr &node, size_t input_index) {
+  MS_ASSERT(node != nullptr);
+  if (!node->isa<CNode>()) {
+    MS_LOG(ERROR) << "GetInputAddr not cnode";
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (input_index >= cnode->inputs().size()) {
+    MS_LOG(ERROR) << "input index error";
+    return nullptr;
+  }
+  if (cnode->input(input_index)->isa<Parameter>()) {
+    auto param_input = cnode->input(input_index)->cast<ParameterPtr>();
+    auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param_input->default_param());
+    if (param_value == nullptr) {
+      MS_LOG(ERROR) << "param not paramValueLite";
+      return nullptr;
+    }
+    return param_value->tensor_addr();
+  }
+  MS_LOG(ERROR) << "input not paramter";
+  return nullptr;
+}
+
+// Concatenates, in stack-input order, the weight tensor (input 2) of every node feeding stack_node into one
+// buffer and installs it as the default value of rmatmul_input. The new shape is the first weight's shape with
+// the number of stacked nodes prepended; type and format are copied from the first weight.
+// Returns RET_OK on success, RET_ERROR (after logging) otherwise.
+STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPtr &rmatmul_input) {
+  MS_ASSERT(stack_node != nullptr);
+  MS_ASSERT(rmatmul_input != nullptr);
+  if (stack_node->inputs().size() < 2) {
+    MS_LOG(ERROR) << "stack node has no input";
+    return RET_ERROR;
+  }
+  auto joint_fullconnect_size = stack_node->inputs().size() - 1;
+  auto fc = stack_node->input(1)->cast<CNodePtr>();
+  if (fc == nullptr || fc->inputs().size() < 3) {
+    MS_LOG(ERROR) << "stack input is not a cnode with a weight input";
+    return RET_ERROR;
+  }
+  auto fc_weight = fc->input(2)->cast<ParameterPtr>();
+  if (fc_weight == nullptr) {
+    MS_LOG(ERROR) << "fullconnect weight is not a parameter";
+    return RET_ERROR;
+  }
+  auto fc_weight_param = std::dynamic_pointer_cast<ParamValueLite>(fc_weight->default_param());
+  if (fc_weight_param == nullptr) {
+    MS_LOG(ERROR) << "param not paramValueLite";
+    return RET_ERROR;
+  }
+  auto tensor_size = fc_weight_param->tensor_size();
+  auto rmatmul_input_shape = fc_weight_param->tensor_shape();
+  // Held by a unique_ptr until it is handed to param_value, so the early returns below cannot leak it.
+  std::unique_ptr<int8_t[]> new_tensor_data(new (std::nothrow) int8_t[joint_fullconnect_size * tensor_size]);
+  if (new_tensor_data == nullptr) {
+    MS_LOG(ERROR) << "tensor_data is nullptr";
+    return RET_ERROR;
+  }
+  // NOTE(review): every weight is copied using the byte size of the first one, i.e. all stacked weights are
+  // assumed to be equally sized -- confirm, or add a per-input size check.
+  for (size_t i = 1; i < joint_fullconnect_size + 1; i++) {
+    auto tensor_addr = GetInputAddr(stack_node->input(i), 2);
+    if (tensor_addr == nullptr) {
+      MS_LOG(ERROR) << "input tensor addr nullptr";
+      return RET_ERROR;
+    }
+    if (EOK != memcpy_s(new_tensor_data.get() + (i - 1) * tensor_size, tensor_size, tensor_addr, tensor_size)) {
+      MS_LOG(ERROR) << "memcpy_s data failed";
+      return RET_ERROR;
+    }
+  }
+  rmatmul_input_shape.insert(rmatmul_input_shape.begin(), joint_fullconnect_size);
+  auto type_ptr = TypeIdToType(fc_weight_param->tensor_type());
+  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, rmatmul_input_shape);
+  rmatmul_input->set_abstract(abstract_tensor);
+  rmatmul_input->set_name(stack_node->fullname_with_scope() + "right_parameter");
+  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
+  MS_ASSERT(param_value != nullptr);
+  param_value->set_tensor_shape(rmatmul_input_shape);
+  param_value->set_tensor_type(fc_weight_param->tensor_type());
+  param_value->set_format(fc_weight_param->format());
+  // From here on the buffer is reachable only through param_value, as with the raw pointer used before.
+  param_value->set_tensor_addr(new_tensor_data.release());
+  param_value->set_tensor_size(joint_fullconnect_size * tensor_size);
+  rmatmul_input->set_default_param(param_value);
+  return RET_OK;
+}
+}  // namespace
+
+// Pattern: a Stack node whose first two inputs are FullConnection nodes, followed by a SeqVar for the rest.
+const BaseRef BatchMatMulFusion::DefinePattern() const {
+  auto pack_var = std::make_shared<CondVar>(IsStackNode);
+  auto left_fullconnect_var = std::make_shared<CondVar>(IsFullConnectNode);
+  auto right_fullconnect_var = std::make_shared<CondVar>(IsFullConnectNode);
+  auto other_fullconnect_var = std::make_shared<SeqVar>();
+  return VectorRef({pack_var, left_fullconnect_var, right_fullconnect_var, other_fullconnect_var});
+}
+
+// Replaces a Stack of FullConnection nodes by one MatMul node.
+// The left MatMul input is input(1) of the node feeding the first FullConnection (called slice here; its type is
+// not verified). If the first FullConnection weight is a Parameter, the weights of all stacked nodes are joined
+// into a new Parameter and transposeB is set; otherwise the right input is found three input(1) hops above it.
+// Returns the new MatMul CNode, nullptr when the matched nodes do not have the expected form, and node itself when
+// the joint weight Parameter cannot be built.
+const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                            const EquivPtr &) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto stack_cnode = node->cast<CNodePtr>();
+  if (stack_cnode == nullptr || stack_cnode->inputs().size() < 2) {
+    return nullptr;
+  }
+  // Every stack input must be a FullConnection CNode with exactly a data and a weight input, because only those
+  // two inputs are forwarded below. Entry [1][0] of each node's input quant params is collected on the way.
+  std::vector<schema::QuantParamT> jointed_quant_params;
+  for (size_t i = 1; i < stack_cnode->inputs().size(); i++) {
+    auto input_node = stack_cnode->input(i);
+    if (!IsFullConnectNode(input_node)) {
+      MS_LOG(WARNING) << "batchmatmulfusion stack node all inputs must fullconnect type";
+      return nullptr;
+    }
+    auto input_cnode = input_node->cast<CNodePtr>();
+    if (input_cnode == nullptr || input_cnode->inputs().size() != 3) {
+      MS_LOG(WARNING) << "batchmatmulfusion fullconnect node must be a cnode with 2 inputs";
+      return nullptr;
+    }
+    auto input_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(input_cnode->input(0));
+    if (input_prim == nullptr) {
+      MS_LOG(WARNING) << "batchmatmulfusion fullconnect primitive is nullptr";
+      return nullptr;
+    }
+    auto fc_input_quantParams = input_prim->GetInputQuantParams();
+    if (fc_input_quantParams.size() > 1 && !fc_input_quantParams[1].empty()) {
+      jointed_quant_params.push_back(fc_input_quantParams[1][0]);
+    }
+  }
+  auto fullconnect_cnode = stack_cnode->input(1)->cast<CNodePtr>();  // non-null, checked by the loop above
+  auto left_matmul_input = GetFirstDataInput(fullconnect_cnode->input(1));
+  auto right_reshape_node = fullconnect_cnode->input(2);
+  if (left_matmul_input == nullptr || right_reshape_node == nullptr) {
+    MS_LOG(WARNING) << "batchmatmulfusion fullconnect inputs do not match";
+    return nullptr;
+  }
+
+  auto matmul_primitive = std::make_unique<schema::PrimitiveT>();
+  auto attr = std::make_unique<schema::MatMulT>();
+  matmul_primitive->value.type = schema::PrimitiveType_MatMul;
+  matmul_primitive->value.value = attr.release();
+  // Owned by a shared_ptr right away so that the early returns below cannot leak the primitive.
+  std::shared_ptr<lite::PrimitiveC> matmul_cvalue(lite::PrimitiveC::Create(matmul_primitive.release()));
+  if (matmul_cvalue == nullptr) {
+    MS_LOG(ERROR) << "create matmul primitive failed";
+    return nullptr;
+  }
+  // MatMul input quant params: those of the first FullConnection with the last two entries (presumably weight
+  // and bias -- confirm) replaced by the joint list collected above. fc_prim was null-checked in the loop.
+  auto fc_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(fullconnect_cnode->input(0));
+  auto rmatmul_quant_params = fc_prim->GetInputQuantParams();
+  for (int k = 0; k < 2 && !rmatmul_quant_params.empty(); k++) {
+    rmatmul_quant_params.pop_back();  // guarded: pop_back on an empty vector is undefined behaviour
+  }
+  rmatmul_quant_params.emplace_back(jointed_quant_params);
+  matmul_cvalue->SetInputQuantParam(rmatmul_quant_params);
+  matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams());
+  auto matmul_value_node = NewValueNode(matmul_cvalue);
+  std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};
+
+  // batchmatmul right node may be const
+  if (right_reshape_node->isa<Parameter>()) {
+    // NOTE(review): the Parameter is presumably registered in func_graph by add_parameter() before it is
+    // filled, so it would stay in the graph when GetRightMatmulInputParamter fails -- confirm.
+    auto rmatmul_paramter = func_graph->add_parameter();
+    if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
+      MS_LOG(ERROR) << "GetRightMatmulInputParamter failed";
+      return node;
+    }
+    matmul_cvalue->GetPrimitiveT()->value.AsMatMul()->transposeB = true;
+    matmul_inputs.push_back(rmatmul_paramter);
+  } else {
+    // Walk reshape -> transpose -> slice (as this code names them; their types are not verified).
+    auto right_matmul_input = GetFirstDataInput(GetFirstDataInput(GetFirstDataInput(right_reshape_node)));
+    if (right_matmul_input == nullptr) {
+      MS_LOG(WARNING) << "batchmatmulfusion right input chain does not match";
+      return nullptr;
+    }
+    matmul_inputs.push_back(right_matmul_input);
+  }
+  auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
+  matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
+  MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
+  return matmul_cnode;
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.h b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.h
new file mode 100644
index 0000000000..9dff451261
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_BATCHMATMUL_FUSION_H_
+#define MINDSPORE_LITE_SRC_PASS_FUSION_BATCHMATMUL_FUSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "tools/converter/converter_context.h"
+
+namespace mindspore::opt {
+// PatternProcessPass constructed under the name "slice_fullconnect_fusion"; multigraph defaults to true.
+class BatchMatMulFusion : public PatternProcessPass {
+ public:
+  explicit BatchMatMulFusion(bool multigraph = true) : PatternProcessPass("slice_fullconnect_fusion", multigraph) {}
+  ~BatchMatMulFusion() override = default;
+  // Both overrides are implemented in batchmatmul_fusion.cc.
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace mindspore::opt
+#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_BATCHMATMUL_FUSION_H_