From 8b3823b22cd2e5a04285dfb90355709f95769917 Mon Sep 17 00:00:00 2001
From: lingyunli63
Date: Tue, 30 Mar 2021 12:18:27 +0800
Subject: [PATCH] optimizeMatmul

---
 .../graph_kernel/graph_kernel_optimization.cc |  8 ++-
 .../optimizer/graph_kernel/optimize_matmul.cc | 64 +++++++++++++++++++
 .../optimizer/graph_kernel/optimize_matmul.h  | 36 +++++++++++
 tests/st/ops/graph_kernel/test_matmul_cast.py | 60 +++++++++++++++++
 4 files changed, 166 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.h
 create mode 100644 tests/st/ops/graph_kernel/test_matmul_cast.py

diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
index c968ef0768..1e39d4167c 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
@@ -30,6 +30,7 @@
 #include "backend/optimizer/graph_kernel/tensor_promotion.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
+#include "backend/optimizer/graph_kernel/optimize_matmul.h"
 #include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
 #include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
@@ -49,8 +50,11 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
   // Change Assign(p, a, U) to Assign(Depend(p, U), a)
   pm->AddPass(std::make_shared<SplitAssign>());
 
-  // Reorder TransData-Cast to Cast-TransData
   if (is_ascend) {
+    // Remove redundant Cast(bias, fp16) for MatMul input
+    pm->AddPass(std::make_shared<OptimizeMatmul>());
+
+    // Reorder TransData-Cast to Cast-TransData
     pm->AddPass(std::make_shared<ReorderOps>());
   }
 
@@ -81,7 +85,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() {
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
 
-  // Cast the input of ReduceSum from float16 to float32 for higher precision*/
+  // Cast the input of ReduceSum from float16 to float32 for higher precision
   pm->AddPass(std::make_shared<RaiseReductionPrecision>());
 
   // Universal arithmetic simplify
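
Note (editorial, not part of the patch): OptimizeMatmul is registered in PreProcess inside
the is_ascend branch, so it only runs when graph kernel fusion is active on the Ascend
backend. In a user script that corresponds to the same context settings the new test at
the end of this patch uses:

    import mindspore.context as context
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_graph_kernel=True)
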
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.cc
new file mode 100644
index 0000000000..94106034d6
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.cc
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/optimize_matmul.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+
+namespace mindspore {
+namespace opt {
+/* MatMul supports fp32 bias, so remove the redundant cast when the cast is used only by MatMul
+ *
+ * %0 = cast(bias_fp32, fp16)
+ * %1 = MatMul(A_fp16, B_fp16, %0)
+ * ------>
+ * %1 = MatMul(A_fp16, B_fp16, bias_fp32)
+ */
+bool OptimizeMatmul::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+  auto changed = false;
+  auto nodes = TopoSort(func_graph->get_return());
+  for (auto node : nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode->size() != 4) {
+      continue;
+    }
+    auto cast_node = cnode->input(3);
+    if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
+      continue;
+    }
+    auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
+    auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
+    if (cast_input_type == kNumberTypeFloat32 && cast_output_type == kNumberTypeFloat16 &&
+        mng->node_users()[cast_node].size() == 1) {
+      mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.h
new file mode 100644
index 0000000000..6f607f1a26
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
+
+#include <memory>
+#include <string>
+
+#include "backend/optimizer/common/pass.h"
+#include "ir/func_graph.h"
+
+namespace mindspore {
+namespace opt {
+class OptimizeMatmul : public Pass {
+ public:
+  OptimizeMatmul() : Pass("optimize_matmul") {}
+  ~OptimizeMatmul() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+};
+using OptimizeMatmulPtr = std::shared_ptr<OptimizeMatmul>;
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
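
The numeric rationale for the pass can be checked with plain numpy, independently of
MindSpore (editorial sketch; variable names are illustrative). float16 has a 10-bit
mantissa, so in [64, 128) its representable values are 2**-4 = 0.0625 apart, and
Cast(bias_fp32, fp16) can move a bias of magnitude ~100 (as in the test below) by up to
half that spacing:

    import numpy as np

    # fp32 bias of the magnitude the new test uses.
    bias_fp32 = np.random.normal(100, 0.1, [128]).astype(np.float32)
    # What Cast(bias, fp16) would feed to MatMul before this pass.
    bias_via_fp16 = bias_fp32.astype(np.float16).astype(np.float32)

    # Round-off is bounded by half the float16 ulp in [64, 128), i.e. 2**-5.
    max_err = np.abs(bias_fp32 - bias_via_fp16).max()
    print(max_err)  # typically around 1e-2
    assert max_err <= 2.0**-5

Letting MatMul consume the fp32 bias directly avoids this rounding entirely and removes
one node from the graph, which is why the cast is dropped when MatMul is its only user.
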
diff --git a/tests/st/ops/graph_kernel/test_matmul_cast.py b/tests/st/ops/graph_kernel/test_matmul_cast.py
new file mode 100644
index 0000000000..49731fbfff
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_matmul_cast.py
@@ -0,0 +1,60 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+from mindspore.common import dtype as mstype
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
+        self.add = P.BiasAdd()
+        self.cast = P.Cast()
+
+    def construct(self, x, y, b):
+        xy = self.matmul(x, y)
+        b16 = self.cast(b, mstype.float16)
+        res = self.add(xy, b16)
+        return self.cast(res, mstype.float32)
+
+def get_output(i0, i1, i2, enable_graph_kernel=False):
+    if enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True, save_graphs=False)
+    net = Net()
+    output = net(i0, i1, i2)
+    return output
+
+def test_basic():
+    i0 = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
+    i1 = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
+    i2 = Tensor(np.random.normal(100, 0.1, [128,]).astype(np.float32))
+    expect = get_output(i0, i1, i2, False)
+    output = get_output(i0, i1, i2, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_basic()
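
A sanity check of the tolerances in test_basic (editorial sketch, numpy only): with x and
y drawn from N(1, 0.01) and an inner dimension of 800, each MatMul output element is
roughly 800, and the bias near 100 brings the sum to roughly 900, so a bias perturbation
of at most ~0.031 is about 3.5e-5 in relative terms, well inside the asserted rtol of
1e-4. The bias-rounding contribution alone, under the same shapes as the test (the fp16
matmul itself is identical on both sides of the comparison and drops out):

    import numpy as np

    x = np.random.normal(1, 0.01, [800, 96]).astype(np.float16)
    y = np.random.normal(1, 0.01, [128, 800]).astype(np.float16)
    b = np.random.normal(100, 0.1, [128]).astype(np.float32)

    xy = x.astype(np.float32).T @ y.astype(np.float32).T  # [96, 128], ~800 per element
    ref = xy + b                                          # bias kept in fp32 (optimized graph)
    alt = xy + b.astype(np.float16).astype(np.float32)    # bias routed through fp16
    print(np.abs(ref - alt).max() / np.abs(ref).min())    # around 3e-5
    assert np.allclose(ref, alt, 1.e-4, 1.e-7)

The new test should run under pytest in an Ascend environment, e.g.
pytest -s tests/st/ops/graph_kernel/test_matmul_cast.py::test_basic_ascend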