diff --git a/tests/ut/python/parallel/test_attention.py b/tests/ut/python/parallel/test_attention.py
deleted file mode 100644
index a484e1fd63..0000000000
--- a/tests/ut/python/parallel/test_attention.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-
-import mindspore as ms
-import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore import context
-from mindspore.common.api import _executor
-from mindspore.context import set_auto_parallel_context
-from mindspore.ops import composite as C
-from mindspore.ops import operations as P
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-
-
-grad_all = C.GradOperation(get_all=True)
-
-
-class NetWithLoss(nn.Cell):
-    def __init__(self, network):
-        super(NetWithLoss, self).__init__()
-        self.loss = VirtualLoss()
-        self.network = network
-
-    def construct(self, x):
-        predict = self.network(x)
-        return self.loss(predict)
-
-
-class GradWrap(nn.Cell):
-    def __init__(self, network):
-        super(GradWrap, self).__init__()
-        self.network = network
-
-    def construct(self, x):
-        return grad_all(self.network)(x)
-
-
-def compile_net(net, x):
-    net.set_auto_parallel()
-    _executor.compile(net, x)
-
-
-class Net(nn.Cell):
-    def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5):
-        super().__init__()
-        self.query_w = Parameter(initializer(
-            "normal", [8, 16], ms.float32), name='query')
-        self.query = P.MatMul().shard(strategy1)
-
-        self.key_w = Parameter(initializer(
-            "normal", [8, 16], ms.float32), name='key')
-        self.key = P.MatMul().shard(strategy2)
-
-        self.value_w = Parameter(initializer(
-            "normal", [8, 16], ms.float32), name='value')
-        self.value = P.MatMul().shard(strategy3)
-
-        self.score = P.MatMul().shard(strategy4)
-        self.context = P.MatMul().shard(strategy5)
-        self.transpose1 = P.Transpose()
-        self.transpose2 = P.Transpose()
-        self.relu = P.ReLU()
-
-    def construct(self, x):
-        q = self.query(x, self.query_w)
-        k = self.key(x, self.key_w)
-        v = self.value(x, self.value_w)
-
-        k = self.transpose1(k, (1, 0))
-        s = self.score(q, k)
-
-        v = self.transpose2(v, (1, 0))
-        c = self.context(v, s)
-        out = self.relu(c)
-
-        return out
-
-
-def test_self_attention_standalone():
-    set_auto_parallel_context(device_num=8, global_rank=0)
-    context.set_auto_parallel_context(parallel_mode="stand_alone")
-    net = GradWrap(NetWithLoss(
-        Net(None, None, None, None, None)))
-
-    x = Tensor(np.ones([32, 8]), dtype=ms.float32)
-
-    compile_net(net, x)
-
-
-def test_self_attention_semi():
-    set_auto_parallel_context(device_num=8, global_rank=0)
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    strategy1 = ((2, 2), (2, 2))
-    strategy2 = ((2, 2), (2, 2))
-    strategy3 = ((2, 2), (2, 2))
-    strategy4 = ((2, 4), (4, 1))
-    strategy5 = ((2, 1), (1, 4))
-
-    net = GradWrap(NetWithLoss(
-        Net(strategy1, strategy2, strategy3, strategy4, strategy5)))
-
-    x = Tensor(np.ones([32, 8]), dtype=ms.float32)
-
-    compile_net(net, x)
-
-
-def test_self_attention_dp():
-    set_auto_parallel_context(device_num=8, global_rank=0)
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    strategy1 = ((8, 1), (1, 1))
-    strategy2 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-
-    net = GradWrap(NetWithLoss(
-        Net(strategy1, strategy2, strategy3, strategy4, strategy5)))
-
-    x = Tensor(np.ones([32, 8]), dtype=ms.float32)
-
-    compile_net(net, x)
-
-
-def test_self_attention_auto():
-    set_auto_parallel_context(device_num=8, global_rank=0)
-    context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    net = GradWrap(NetWithLoss(
-        Net(None, None, None, None, None)))
-
-    x = Tensor(np.ones([32, 8]), dtype=ms.float32)
-
-    compile_net(net, x)
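For context on the strategy tuples in the deleted tests: in MindSpore's semi-auto-parallel mode, `P.MatMul().shard(((a, b), (c, d)))` cuts the first input's two dimensions into `a` and `b` slices and the second input's into `c` and `d`, with the slice-count product covering the configured `device_num`, and the contracting dimension cut identically on both sides. A minimal sketch of that arithmetic (plain Python, no MindSpore needed; the helper `slice_shape` is illustrative, not a library API):

```python
# Sketch: map a MatMul shard strategy to per-device slice shapes,
# assuming strategy[i][j] = number of slices along dim j of input i.

def slice_shape(shape, strategy):
    """Per-device slice shape for one input under its shard strategy."""
    assert len(shape) == len(strategy)
    assert all(dim % cuts == 0 for dim, cuts in zip(shape, strategy))
    return tuple(dim // cuts for dim, cuts in zip(shape, strategy))

# strategy4 from test_self_attention_semi: s = score(q, k_t) on 8 devices.
q_full, kt_full = (32, 16), (16, 32)   # [32, 8] @ [8, 16] -> q; k transposed
strategy4 = ((2, 4), (4, 1))

q_slice = slice_shape(q_full, strategy4[0])    # -> (16, 4)
kt_slice = slice_shape(kt_full, strategy4[1])  # -> (4, 32)

# The contracting dim (16) is cut 4 ways on both inputs, and the device
# grid matches the test's device_num: 2 * 4 * 1 = 8.
print(q_slice, kt_slice)
```

The `((8, 1), (1, 1))` strategies in `test_self_attention_dp` are the pure data-parallel special case of the same rule: only the batch dimension of the left operand is cut, across all 8 devices.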