!8063 fix ReLUV2 mask error

Merge pull request !8063 from yihuaijie/dev
pull/8063/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 33aa2ae16b

@ -33,12 +33,27 @@ namespace mindspore {
namespace parallel {
Status ReLUV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
}
Strategys stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0);
if (input_strategy[1] != 1) {
MS_LOG(ERROR) << name_ << "The second dimension is not splitable.";
return FAILED;
}
return SUCCESS;
}
Status ReLUV2Info::GetAttrs() { return SUCCESS; }
Status ReLUV2Info::GenerateStrategies(int32_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
// the second dimension is not splitable
input0_split[1] = 0;
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
@ -97,6 +112,7 @@ Status ReLUV2Info::InferForwardCommunication() {
Status ReLUV2Info::InferTensorMap() {
Shape tensor_map_index;
Shape tensor_map_mask;
size_t size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) {
@ -104,9 +120,12 @@ Status ReLUV2Info::InferTensorMap() {
}
inputs_tensor_map_.push_back(tensor_map_index);
// output and mask
outputs_tensor_map_.push_back(tensor_map_index);
// output
outputs_tensor_map_.push_back(tensor_map_index);
tensor_map_mask = tensor_map_index;
// mask format NC1HWC0
tensor_map_mask.push_back(MAP_NONE);
outputs_tensor_map_.push_back(tensor_map_mask);
return SUCCESS;
}
@ -116,7 +135,7 @@ Status ReLUV2Info::InferTensorInfo() {
return FAILED;
}
TensorLayout input_layout, output_layout;
TensorLayout input_layout, output_layout, mask_layout;
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
@ -129,10 +148,15 @@ Status ReLUV2Info::InferTensorInfo() {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
if (mask_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[1], outputs_shape_[1]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
TensorInfo mask_tensor_info(mask_layout);
// output and mask
outputs_tensor_info_.push_back(output_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
outputs_tensor_info_.push_back(mask_tensor_info);
return SUCCESS;
}

@ -30,8 +30,8 @@
namespace mindspore {
namespace parallel {
/*
* The input, output and mask have the same tensormap.
* And all dimensions of input are splitable.
* The second dimension is not splitable, as mask is caculated along it.
* The input and output have the same tensormap (3, 2, 1, 0), mask's tensormap is (3, 2, 1, 0, -1)
*/
class ReLUV2Info : public OperatorInfo {
public:

Loading…
Cancel
Save