|
|
@ -33,12 +33,27 @@ namespace mindspore {
|
|
|
|
namespace parallel {
|
|
|
|
namespace parallel {
|
|
|
|
Status ReLUV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
|
|
|
Status ReLUV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
|
|
|
|
|
|
|
|
|
|
|
Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
|
|
|
Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Strategys stra = strategy->GetInputDim();
|
|
|
|
|
|
|
|
Dimensions input_strategy = stra.at(0);
|
|
|
|
|
|
|
|
if (input_strategy[1] != 1) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << name_ << "The second dimension is not splitable.";
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Status ReLUV2Info::GetAttrs() { return SUCCESS; }
|
|
|
|
Status ReLUV2Info::GetAttrs() { return SUCCESS; }
|
|
|
|
|
|
|
|
|
|
|
|
Status ReLUV2Info::GenerateStrategies(int32_t stage_id) {
|
|
|
|
Status ReLUV2Info::GenerateStrategies(int32_t stage_id) {
|
|
|
|
Shape input0_split(inputs_shape_[0].size(), 1);
|
|
|
|
Shape input0_split(inputs_shape_[0].size(), 1);
|
|
|
|
|
|
|
|
// the second dimension is not splitable
|
|
|
|
|
|
|
|
input0_split[1] = 0;
|
|
|
|
Shapes splittable_inputs = {input0_split};
|
|
|
|
Shapes splittable_inputs = {input0_split};
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<StrategyPtr> sp_vector;
|
|
|
|
std::vector<StrategyPtr> sp_vector;
|
|
|
@ -97,6 +112,7 @@ Status ReLUV2Info::InferForwardCommunication() {
|
|
|
|
|
|
|
|
|
|
|
|
Status ReLUV2Info::InferTensorMap() {
|
|
|
|
Status ReLUV2Info::InferTensorMap() {
|
|
|
|
Shape tensor_map_index;
|
|
|
|
Shape tensor_map_index;
|
|
|
|
|
|
|
|
Shape tensor_map_mask;
|
|
|
|
size_t size = inputs_shape_.at(0).size();
|
|
|
|
size_t size = inputs_shape_.at(0).size();
|
|
|
|
// such as 4: tensor_map_index [3,2,1,0]
|
|
|
|
// such as 4: tensor_map_index [3,2,1,0]
|
|
|
|
for (size_t i = 0; i < size; ++i) {
|
|
|
|
for (size_t i = 0; i < size; ++i) {
|
|
|
@ -104,9 +120,12 @@ Status ReLUV2Info::InferTensorMap() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
inputs_tensor_map_.push_back(tensor_map_index);
|
|
|
|
inputs_tensor_map_.push_back(tensor_map_index);
|
|
|
|
// output and mask
|
|
|
|
// output
|
|
|
|
outputs_tensor_map_.push_back(tensor_map_index);
|
|
|
|
|
|
|
|
outputs_tensor_map_.push_back(tensor_map_index);
|
|
|
|
outputs_tensor_map_.push_back(tensor_map_index);
|
|
|
|
|
|
|
|
tensor_map_mask = tensor_map_index;
|
|
|
|
|
|
|
|
// mask format NC1HWC0
|
|
|
|
|
|
|
|
tensor_map_mask.push_back(MAP_NONE);
|
|
|
|
|
|
|
|
outputs_tensor_map_.push_back(tensor_map_mask);
|
|
|
|
return SUCCESS;
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -116,7 +135,7 @@ Status ReLUV2Info::InferTensorInfo() {
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TensorLayout input_layout, output_layout;
|
|
|
|
TensorLayout input_layout, output_layout, mask_layout;
|
|
|
|
// infer tensor layout
|
|
|
|
// infer tensor layout
|
|
|
|
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
|
|
|
|
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
|
|
|
@ -129,10 +148,15 @@ Status ReLUV2Info::InferTensorInfo() {
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (mask_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[1], outputs_shape_[1]) != SUCCESS) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
TensorInfo output_tensor_info(output_layout);
|
|
|
|
TensorInfo output_tensor_info(output_layout);
|
|
|
|
|
|
|
|
TensorInfo mask_tensor_info(mask_layout);
|
|
|
|
// output and mask
|
|
|
|
// output and mask
|
|
|
|
outputs_tensor_info_.push_back(output_tensor_info);
|
|
|
|
outputs_tensor_info_.push_back(output_tensor_info);
|
|
|
|
outputs_tensor_info_.push_back(output_tensor_info);
|
|
|
|
outputs_tensor_info_.push_back(mask_tensor_info);
|
|
|
|
return SUCCESS;
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|