|
|
|
@ -32,7 +32,7 @@ namespace parallel {
|
|
|
|
|
* prelu has 2 input
|
|
|
|
|
* A: A float tensor of shape [NCHW] representing the output of the preview layer.
|
|
|
|
|
* w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input.
|
|
|
|
|
* the strategy of w should equal to the channel dimension of strategy of A
|
|
|
|
|
* the strategy of w should equal to the channel dimension of strategy of A, or equal to 1
|
|
|
|
|
*/
|
|
|
|
|
Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
|
|
|
@ -52,7 +52,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
}
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0]) {
|
|
|
|
|
if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": Invalid channel strategy.";
|
|
|
|
|
} else {
|
|
|
|
@ -107,7 +107,11 @@ Status PReLUInfo::InferTensorMap() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TensorMap param_tensor_map;
|
|
|
|
|
param_tensor_map.push_back(input_tensor_map.at(1));
|
|
|
|
|
if (inputs_shape_[1][0] == 1) {
|
|
|
|
|
param_tensor_map.push_back(-1);
|
|
|
|
|
} else {
|
|
|
|
|
param_tensor_map.push_back(input_tensor_map.at(1));
|
|
|
|
|
}
|
|
|
|
|
inputs_tensor_map_.push_back(input_tensor_map);
|
|
|
|
|
inputs_tensor_map_.push_back(param_tensor_map);
|
|
|
|
|
outputs_tensor_map_.push_back(input_tensor_map);
|
|
|
|
|