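Fix LayerNorm bug: CheckStrategy now checks every input axis from begin_norm_axis_ onward (input_strategy[i]) instead of re-checking only input_strategy[begin_norm_axis_]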

pull/696/head
yangzhenzhang 5 years ago
parent 348b0ef53c
commit 4750861054

@ -69,7 +69,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
// check input strategy // check input strategy
for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) { for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
if (input_strategy[begin_norm_axis_] != NO_SPLIT_STRATEGY) { if (input_strategy[i] != NO_SPLIT_STRATEGY) {
MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy); MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
return FAILED; return FAILED;
} }

@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================
import numpy as np import numpy as np
import pytest
import mindspore as ms import mindspore as ms
from mindspore import context, Tensor, Parameter from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum from mindspore.nn import Cell, TrainOneStepCell, Momentum
@ -24,7 +25,7 @@ from mindspore.common.initializer import initializer
class Net(Cell): class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None): def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
super().__init__() super().__init__()
self.begin_norm_axis = -1 self.begin_norm_axis = 2
self.begin_params_axis = 1 self.begin_params_axis = 1
self.mul = P.Mul().set_strategy(strategy1) self.mul = P.Mul().set_strategy(strategy1)
self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2) self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
@ -64,18 +65,18 @@ def test_layer_norm_data_parallel():
def test_layer_norm_model_parallel(): def test_layer_norm_model_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1)) strategy1 = ((1, 16, 1, 1), (1, 16, 1, 1))
strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1)) strategy2 = ((1, 16, 1, 1), (16, 1, 1), (16, 1, 1))
strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1)) strategy3 = ((1, 16, 1, 1), (1, 16, 1, 1))
net = Net(_w, strategy1, strategy2, strategy3) net = Net(_w, strategy1, strategy2, strategy3)
compile(net) compile(net)
def test_layer_norm_hybrid_parallel(): def test_layer_norm_hybrid_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy1 = ((2, 8, 1, 1), (2, 8, 1, 1))
strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1)) strategy2 = ((2, 8, 1, 1), (8, 1, 1), (8, 1, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy3 = ((2, 8, 1, 1), (2, 8, 1, 1))
net = Net(_w, strategy1, strategy2, strategy3) net = Net(_w, strategy1, strategy2, strategy3)
compile(net) compile(net)
@ -89,8 +90,17 @@ def test_layer_norm_auto_parallel():
def test_layer_norm_repeat_calc(): def test_layer_norm_repeat_calc():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1)) strategy2 = ((2, 2, 1, 1), (2, 1, 1), (2, 1, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
net = Net(_w, strategy1, strategy2, strategy3) net = Net(_w, strategy1, strategy2, strategy3)
compile(net) compile(net)
def test_layer_norm_wrong_strategy():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
strategy2 = ((1, 2, 1, 2), (2, 1, 2), (2, 1, 2))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
net = Net(_w, strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile(net)

Loading…
Cancel
Save