diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index e44b5dd80d..f72a06511f 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -466,6 +466,8 @@ class GraphSplitAscend(GraphSplitByPattern):
     REDUCE_FUSE_DEPTH = 10
 
     def get_default_mode(self, op):
+        if op.prim == "MatMul":
+            return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC
         if op.prim in ("Tile", "BroadcastTo"):
             return self.Area.MODE_COMPOSITE
         return self.Area.MODE_BASIC
diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py
index e11b38ed07..b630c69ec7 100644
--- a/mindspore/_extends/graph_kernel/model/model.py
+++ b/mindspore/_extends/graph_kernel/model/model.py
@@ -88,8 +88,7 @@ class PrimLib:
     ELEMWISE = 2
     BROADCAST = 3
     REDUCE = 4
-    TRANSFORM = 5
-    CONTROL = 6
+    OPAQUE = 5
 
     class Prim:
         """Prim"""
@@ -146,7 +145,6 @@ class PrimLib:
         default_elemwise_broadcast_relation,
         default_reduce_relation,
         unknown_relation,
-        unknown_relation,
     ]
 
     primtives = {
@@ -176,7 +174,6 @@ class PrimLib:
         'ReduceSum': Prim(REDUCE),
         'ReduceMax': Prim(REDUCE),
         'ReduceMin': Prim(REDUCE),
-        'MakeTuple': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
         'Tanh': Prim(ELEMWISE),
         'ExpandDims': Prim(RESHAPE),
@@ -186,9 +183,10 @@ class PrimLib:
         'Squeeze': Prim(RESHAPE),
         'Flatten': Prim(RESHAPE),
         'FlattenGrad': Prim(RESHAPE),
-        'Transpose': Prim(TRANSFORM),
+        'Transpose': Prim(OPAQUE),
         'Tile': Prim(BROADCAST),
         'BroadcastTo': Prim(BROADCAST),
+        'MatMul': Prim(OPAQUE),
     }
 
     default_primtive = Prim(UNKNOWN)
@@ -509,7 +507,7 @@ class AddControlBuddy(GraphVisitor):
         self.buddies = {}  # {op : [ctrl_op]}
 
     def visit(self, op):
-        if PrimLib.iter_type(op) == PrimLib.CONTROL:
+        if op.prim == "MakeTuple":
             assert len(op.output.to_ops) == 1
             owner = op.output.to_ops[0]
             if owner in self.buddies:
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc
index 7067ab74c2..069fd68cd6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc
@@ -177,6 +177,18 @@ void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) {
   AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node);
 }
 
+void SetAkgAttrsForMatMul(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string dst_type;
+  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0);
+  dst_type = TypeId2String(output_type);
+  AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
+  auto left_format = AnfAlgo::GetInputFormat(anf_node, 0);
+  auto right_format = AnfAlgo::GetInputFormat(anf_node, 1);
+  AnfAlgo::SetNodeAttr("left_format", MakeValue(left_format), anf_node);
+  AnfAlgo::SetNodeAttr("right_format", MakeValue(right_format), anf_node);
+}
+
 const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
   {kFour2FiveOpName, SetAkgAttrsForFour2Five},
   {kFive2FourOpName, SetAkgAttrsForFive2Four},
@@ -190,6 +202,7 @@ const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap
+  {kMatMulOpName, SetAkgAttrsForMatMul},
@@ ... @@ std::vector<PrimitivePtr> GetFusibleOpList() {
     prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
     prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
     prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
-    prim::kPrimCast, prim::kPrimRealDiv};
+    prim::kPrimCast, prim::kPrimRealDiv, prim::kPrimMatMul};
 #elif ENABLE_GPU
   std::vector<PrimitivePtr> fusible_basic_ops = {
     prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 6077ea46e9..1fc50d6298 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -265,6 +265,7 @@ constexpr auto kSGDName = "SGD";
 constexpr auto kLARSUpdateName = "LARSUpdate";
 constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
 constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
+constexpr auto kMatMulOpName = "MatMul";
 constexpr auto kMatMulV2OpName = "MatMulV2";
 constexpr auto kBroadcastToOpName = "BroadcastTo";
 constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";
diff --git a/tests/st/ops/graph_kernel/test_matmul.py b/tests/st/ops/graph_kernel/test_matmul.py
new file mode 100644
index 0000000000..777b0aeba7
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_matmul.py
@@ -0,0 +1,88 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+class Net1(Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
+        self.add = P.BiasAdd()
+
+    def construct(self, x, y, bias):
+        res = self.matmul(x, y)
+        return self.add(res, bias)
+
+def get_output(i0, i1, enable_graph_kernel=False):
+    if enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True, save_graphs=False)
+    net = Net()
+    output = net(i0, i1)
+    return output
+
+def get_output1(i0, i1, i2, enable_graph_kernel=False):
+    if enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True, save_graphs=False)
+    net = Net1()
+    output = net(i0, i1, i2)
+    return output
+
+def test_basic():
+    i0 = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
+    i1 = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
+    expect = get_output(i0, i1, False)
+    output = get_output(i0, i1, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+def test_basic1():
+    i0 = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
+    i1 = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
+    i2 = Tensor(np.random.normal(100, 0.01, [128,]).astype(np.float16))
+    expect = get_output1(i0, i1, i2, False)
+    output = get_output1(i0, i1, i2, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 6.e-4, 6.e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_basic()
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_basic1()