|
|
|
@ -11,8 +11,9 @@
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
# ============================================================================
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
|
import mindspore as ms
|
|
|
|
|
from mindspore import context, Tensor, Parameter
|
|
|
|
|
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
|
|
|
@ -24,7 +25,7 @@ from mindspore.common.initializer import initializer
|
|
|
|
|
class Net(Cell):
|
|
|
|
|
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.begin_norm_axis = -1
|
|
|
|
|
self.begin_norm_axis = 2
|
|
|
|
|
self.begin_params_axis = 1
|
|
|
|
|
self.mul = P.Mul().set_strategy(strategy1)
|
|
|
|
|
self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
|
|
|
|
@ -64,18 +65,18 @@ def test_layer_norm_data_parallel():
|
|
|
|
|
|
|
|
|
|
def test_layer_norm_model_parallel():
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
|
|
|
|
strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1))
|
|
|
|
|
strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1))
|
|
|
|
|
strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
|
|
|
|
|
strategy1 = ((1, 16, 1, 1), (1, 16, 1, 1))
|
|
|
|
|
strategy2 = ((1, 16, 1, 1), (16, 1, 1), (16, 1, 1))
|
|
|
|
|
strategy3 = ((1, 16, 1, 1), (1, 16, 1, 1))
|
|
|
|
|
net = Net(_w, strategy1, strategy2, strategy3)
|
|
|
|
|
compile(net)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_layer_norm_hybrid_parallel():
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
|
|
|
|
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
|
|
|
|
strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1))
|
|
|
|
|
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
|
|
|
|
strategy1 = ((2, 8, 1, 1), (2, 8, 1, 1))
|
|
|
|
|
strategy2 = ((2, 8, 1, 1), (8, 1, 1), (8, 1, 1))
|
|
|
|
|
strategy3 = ((2, 8, 1, 1), (2, 8, 1, 1))
|
|
|
|
|
net = Net(_w, strategy1, strategy2, strategy3)
|
|
|
|
|
compile(net)
|
|
|
|
|
|
|
|
|
@ -89,8 +90,17 @@ def test_layer_norm_auto_parallel():
|
|
|
|
|
def test_layer_norm_repeat_calc():
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
|
|
|
|
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
|
|
|
|
strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1))
|
|
|
|
|
strategy2 = ((2, 2, 1, 1), (2, 1, 1), (2, 1, 1))
|
|
|
|
|
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
|
|
|
|
net = Net(_w, strategy1, strategy2, strategy3)
|
|
|
|
|
compile(net)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_layer_norm_wrong_strategy():
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
|
|
|
|
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
|
|
|
|
strategy2 = ((1, 2, 1, 2), (2, 1, 2), (2, 1, 2))
|
|
|
|
|
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
|
|
|
|
net = Net(_w, strategy1, strategy2, strategy3)
|
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
|
compile(net)
|
|
|
|
|