diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc index fed361616b..14483e97a1 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc @@ -32,7 +32,7 @@ namespace parallel { * prelu has 2 input * A: A float tensor of shape [NCHW] representing the output of the preview layer. * w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. - * the strategy of w should equal to the channel dimension of strategy of A + * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 */ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { @@ -52,7 +52,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { } return FAILED; } - if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0]) { + if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; } else { @@ -107,7 +107,11 @@ Status PReLUInfo::InferTensorMap() { } TensorMap param_tensor_map; - param_tensor_map.push_back(input_tensor_map.at(1)); + if (inputs_shape_[1][0] == 1) { + param_tensor_map.push_back(-1); + } else { + param_tensor_map.push_back(input_tensor_map.at(1)); + } inputs_tensor_map_.push_back(input_tensor_map); inputs_tensor_map_.push_back(param_tensor_map); outputs_tensor_map_.push_back(input_tensor_map); diff --git a/tests/ut/python/parallel/test_prelu.py b/tests/ut/python/parallel/test_prelu.py index d3ad1cc710..5638c9cdbd 100755 --- a/tests/ut/python/parallel/test_prelu.py +++ b/tests/ut/python/parallel/test_prelu.py @@ -166,3 +166,21 @@ def test_prelu_parallel_success4(): w = Tensor(np.random.rand(16),dtype=ms.float32) net = GradWrap(NetWithLoss(Net(strategy))) _executor.compile(net, x, w) + +def test_prelu_parallel_success5(): 
+    class Net(nn.Cell):
+        def __init__(self, strategy):
+            super().__init__()
+            self.prelu = P.PReLU().set_strategy(strategy)
+        def construct(self, x, y):
+            out = self.prelu(x, y)
+            return out
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(device_num=64, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    strategy = ((2, 4, 4, 2), (1, ))
+    x = Tensor(np.random.rand(4, 16, 32, 64),dtype=ms.float32)
+    w = Tensor(np.random.rand(1),dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net(strategy)))
+    _executor.compile(net, x, w)
+