parent dada71eec3
commit 6d522f0a4f
@@ -0,0 +1,76 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_

#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "parallel/auto_parallel/operator_costmodel.h"
#include "parallel/ops_info/operator_info.h"
#include "parallel/strategy.h"

namespace mindspore {
namespace parallel {
constexpr size_t LAYER_NORM_INPUT_SIZE = 3;
constexpr size_t LAYER_NORM_INPUT_INDEX = 0;
constexpr size_t LAYER_NORM_GAMMA_INDEX = 1;
constexpr size_t LAYER_NORM_BETA_INDEX = 2;
constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis";

// The input tensor's dimensions from begin_norm_axis onward cannot be split; the preceding dimensions can be
// split arbitrarily. Gamma and beta must be sharded like the matching input dimensions to meet the broadcast
// requirements of the internal mul and add.
class LayerNormInfo : public OperatorInfo {
 public:
  LayerNormInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                const PrimitiveAttrs& attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>(true)),
        begin_norm_axis_(0) {}
  ~LayerNormInfo() override = default;

  Status Init(const StrategyPtr& strategy) override;
  Status InitForCostModel(const StrategyPtr& strategy) override;
  Status GenerateStrategies(int32_t) override;
  Status SetCostUnderStrategy(const StrategyPtr&) override;

 protected:
  Status GetAttrs() override;
  Status CheckStrategy(const StrategyPtr& strategy) override;
  Status InferMirrorOps() override;
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status InferAsLossDivisor() override;
  Status CreateTensorMap(size_t input_index);
  Status CreateTensorInfo(size_t input_index);
  Status CreateMirrorOp(size_t input_index);
  Status GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr>& sp_vector);
  Status InitShapes();

 private:
  size_t begin_norm_axis_;
  Shape input_shape_;
  Shape gamma_shape_;
  Shape beta_shape_;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
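The comment in the header above states the split rule: the input may not be sharded on the normalized dimensions (begin_norm_axis onward), and gamma/beta must be sharded exactly like the matching trailing dimensions of the input. A minimal sketch of that check in plain Python — the helper name and the negative example are hypothetical, not part of this commit:

def is_valid_layer_norm_strategy(input_strategy, param_strategy, begin_norm_axis):
    """Sketch of the split constraint described in layer_norm_info.h (hypothetical helper)."""
    rank = len(input_strategy)
    axis = begin_norm_axis % rank  # normalize a negative axis such as -1
    # The normalized dimensions (begin_norm_axis onward) must stay whole.
    if any(split != 1 for split in input_strategy[axis:]):
        return False
    # Gamma/beta must shard like the matching trailing dimensions of the input.
    offset = rank - len(param_strategy)
    return tuple(param_strategy) == tuple(input_strategy[offset:])

# Mirrors strategy2 from test_layer_norm_model_parallel in the test below.
assert is_valid_layer_norm_strategy((1, 1, 16, 1), (1, 16, 1), -1)
# Splitting the normalized (last) dimension is rejected.
assert not is_valid_layer_norm_strategy((1, 1, 1, 16), (1, 1, 16), -1)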
@@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.common.api import _executor
from mindspore.common.initializer import initializer


class Net(Cell):
    def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.begin_norm_axis = -1
        self.begin_params_axis = 1
        self.mul = P.Mul().set_strategy(strategy1)
        self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
        self.mul2 = P.Mul().set_strategy(strategy3)
        self.mul_weight = Parameter(mul_weight, "w1")
        self.normalized_shape = [64, 32, 16]
        self.gamma = Parameter(initializer('ones', self.normalized_shape), name="gamma")
        self.beta = Parameter(initializer('zeros', self.normalized_shape), name="beta")

    def construct(self, x, b):
        out = self.mul(x, self.mul_weight)
        out, _, _ = self.layer_norm(out, self.gamma, self.beta)
        out = self.mul2(out, b)
        return out


_x = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
_w = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)


def compile(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_layer_norm_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1, 1), (16, 1, 1, 1))
    strategy2 = ((16, 1, 1, 1), (1, 1, 1), (1, 1, 1))
    strategy3 = ((16, 1, 1, 1), (16, 1, 1, 1))
    net = Net(_w, strategy1, strategy2, strategy3)
    compile(net)


def test_layer_norm_model_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1))
    strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1))
    strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
    net = Net(_w, strategy1, strategy2, strategy3)
    compile(net)


def test_layer_norm_hybrid_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
    strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1))
    strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
    net = Net(_w, strategy1, strategy2, strategy3)
    compile(net)


def test_layer_norm_auto_parallel():
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    net = Net(_w)
    compile(net)


def test_layer_norm_repeat_calc():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
    strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1))
    strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
    net = Net(_w, strategy1, strategy2, strategy3)
    compile(net)