!14084 [GraphKernel] support matmul on D

From: @lingyunli63
Reviewed-by: 
Signed-off-by:
pull/14084/MERGE
mindspore-ci-bot, committed by Gitee
commit ad140a8bf4

@@ -466,6 +466,8 @@ class GraphSplitAscend(GraphSplitByPattern):
     REDUCE_FUSE_DEPTH = 10
 
     def get_default_mode(self, op):
+        if op.prim == "MatMul":
+            return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC
         if op.prim in ("Tile", "BroadcastTo"):
             return self.Area.MODE_COMPOSITE
         return self.Area.MODE_BASIC
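
Read in plain terms, the new branch makes a MatMul whose first input is float16 a composite (fusible) area, while any other MatMul stays basic. A minimal standalone sketch of that rule, using hypothetical stub types rather than the real GraphSplitAscend/Area classes:

# Sketch only: OpStub/TensorStub are hypothetical stand-ins for the model's op
# and tensor objects; MODE_BASIC/MODE_COMPOSITE mirror Area.MODE_* above.
from collections import namedtuple

OpStub = namedtuple("OpStub", ["prim", "inputs"])
TensorStub = namedtuple("TensorStub", ["dtype"])
MODE_BASIC, MODE_COMPOSITE = "basic", "composite"

def default_mode(op):
    if op.prim == "MatMul":
        return MODE_COMPOSITE if op.inputs[0].dtype == "float16" else MODE_BASIC
    if op.prim in ("Tile", "BroadcastTo"):
        return MODE_COMPOSITE
    return MODE_BASIC

assert default_mode(OpStub("MatMul", [TensorStub("float16")])) == MODE_COMPOSITE
assert default_mode(OpStub("MatMul", [TensorStub("float32")])) == MODE_BASIC
assert default_mode(OpStub("Add", [TensorStub("float32")])) == MODE_BASIC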

@@ -88,8 +88,7 @@ class PrimLib:
     ELEMWISE = 2
     BROADCAST = 3
     REDUCE = 4
-    TRANSFORM = 5
-    CONTROL = 6
+    OPAQUE = 5
 
     class Prim:
         """Prim"""
@@ -146,7 +145,6 @@ class PrimLib:
         default_elemwise_broadcast_relation,
         default_reduce_relation,
         unknown_relation,
-        unknown_relation,
     ]
 
     primtives = {
@@ -176,7 +174,6 @@ class PrimLib:
         'ReduceSum': Prim(REDUCE),
         'ReduceMax': Prim(REDUCE),
         'ReduceMin': Prim(REDUCE),
-        'MakeTuple': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
         'Tanh': Prim(ELEMWISE),
         'ExpandDims': Prim(RESHAPE),
@@ -186,9 +183,10 @@ class PrimLib:
         'Squeeze': Prim(RESHAPE),
         'Flatten': Prim(RESHAPE),
         'FlattenGrad': Prim(RESHAPE),
-        'Transpose': Prim(TRANSFORM),
+        'Transpose': Prim(OPAQUE),
         'Tile': Prim(BROADCAST),
         'BroadcastTo': Prim(BROADCAST),
+        'MatMul': Prim(OPAQUE),
     }
 
     default_primtive = Prim(UNKNOWN)
@@ -509,7 +507,7 @@ class AddControlBuddy(GraphVisitor):
         self.buddies = {}  # {op : [ctrl_op]}
 
     def visit(self, op):
-        if PrimLib.iter_type(op) == PrimLib.CONTROL:
+        if op.prim == "MakeTuple":
             assert len(op.output.to_ops) == 1
             owner = op.output.to_ops[0]
             if owner in self.buddies:

@@ -177,6 +177,18 @@ void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) {
   AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node);
 }
 
+void SetAkgAttrsForMatMul(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string dst_type;
+  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0);
+  dst_type = TypeId2String(output_type);
+  AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
+  auto left_format = AnfAlgo::GetInputFormat(anf_node, 0);
+  auto right_format = AnfAlgo::GetInputFormat(anf_node, 1);
+  AnfAlgo::SetNodeAttr("left_format", MakeValue(left_format), anf_node);
+  AnfAlgo::SetNodeAttr("right_format", MakeValue(right_format), anf_node);
+}
+
 const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
   {kFour2FiveOpName, SetAkgAttrsForFour2Five},
   {kFive2FourOpName, SetAkgAttrsForFive2Four},
@@ -190,6 +202,7 @@ const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_n
   {kConvBN1OpName, SetAkgAttrsForConvBN1},
   {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
   {kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
+  {kMatMulOpName, SetAkgAttrsForMatMul},
 };
 }  // namespace
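
Together, the two hunks above add and register an attrs callback for MatMul: before AKG codegen the node gets a "dst_type" attribute derived from its output device dtype, plus "left_format"/"right_format" taken from the formats of its two inputs. Purely as an illustration (the concrete values are assumptions, not taken from this patch), the recorded attributes for a float16 MatMul might look like:

# Assumed example values; the real strings come from AnfAlgo at kernel-build time.
matmul_akg_attrs = {
    "dst_type": "float16",           # TypeId2String(output device dtype)
    "left_format": "DefaultFormat",  # AnfAlgo::GetInputFormat(node, 0), assumed
    "right_format": "DefaultFormat", # AnfAlgo::GetInputFormat(node, 1), assumed
}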

@@ -575,7 +575,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
     prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
     prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
     prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
-    prim::kPrimCast, prim::kPrimRealDiv};
+    prim::kPrimCast, prim::kPrimRealDiv, prim::kPrimMatMul};
 #elif ENABLE_GPU
   std::vector<PrimitivePtr> fusible_basic_ops = {
     prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,

@@ -265,6 +265,7 @@ constexpr auto kSGDName = "SGD";
 constexpr auto kLARSUpdateName = "LARSUpdate";
 constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
 constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
+constexpr auto kMatMulOpName = "MatMul";
 constexpr auto kMatMulV2OpName = "MatMulV2";
 constexpr auto kBroadcastToOpName = "BroadcastTo";
 constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";

@@ -0,0 +1,88 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+class Net1(Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
+        self.add = P.BiasAdd()
+
+    def construct(self, x, y, bias):
+        res = self.matmul(x, y)
+        return self.add(res, bias)
+
+def get_output(i0, i1, enable_graph_kernel=False):
+    if enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True, save_graphs=False)
+    net = Net()
+    output = net(i0, i1)
+    return output
+
+def get_output1(i0, i1, i2, enable_graph_kernel=False):
+    if enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True, save_graphs=False)
+    net = Net1()
+    output = net(i0, i1, i2)
+    return output
+
+def test_basic():
+    i0 = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
+    i1 = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
+    expect = get_output(i0, i1, False)
+    output = get_output(i0, i1, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+def test_basic1():
+    i0 = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
+    i1 = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
+    i2 = Tensor(np.random.normal(100, 0.01, [128,]).astype(np.float16))
+    expect = get_output1(i0, i1, i2, False)
+    output = get_output1(i0, i1, i2, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 6.e-4, 6.e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_basic()
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_basic1()
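
Beyond the pytest entry points, the same path can be exercised directly; a small usage sketch (assuming an Ascend environment is available), mirroring the cases above:

# Standalone usage sketch (not part of the patch): run a float16 MatMul with
# graph kernel fusion enabled on Ascend.
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P

class MatMulNet(Cell):
    def __init__(self):
        super(MatMulNet, self).__init__()
        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)

    def construct(self, x, y):
        return self.matmul(x, y)

# Enable graph kernel so the float16 MatMul is eligible for composite fusion.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_graph_kernel=True)
x = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
y = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
out = MatMulNet()(x, y)  # x.T @ y.T -> shape (96, 128)
print(out.shape)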