|
|
|
@ -32,7 +32,7 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
|
Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
|
|
|
|
|
Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
|
|
|
|
|
auto manual_split_without_offset_iter = attrs_.find("manual_split");
|
|
|
|
|
if (manual_split_without_offset_iter != attrs_.end()) {
|
|
|
|
|
manual_split_ = true;
|
|
|
|
@ -68,7 +68,7 @@ Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::GetManualSplitAttr() {
|
|
|
|
|
Status GatherPInfo::GetManualSplitAttr() {
|
|
|
|
|
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
|
|
|
|
|
if (manual_split_with_offset_iter != attrs_.end()) {
|
|
|
|
|
manual_split_ = true;
|
|
|
|
@ -118,7 +118,7 @@ Status GatherV2PInfo::GetManualSplitAttr() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::GetAttrs() {
|
|
|
|
|
Status GatherPInfo::GetAttrs() {
|
|
|
|
|
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
|
|
|
|
if (target_ != CPU) {
|
|
|
|
|
if (input_value_.at(2) == nullptr) {
|
|
|
|
@ -172,7 +172,7 @@ Status GatherV2PInfo::GetAttrs() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
|
|
|
|
Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
|
|
|
|
|
if (strategy.size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
|
|
|
|
|
return FAILED;
|
|
|
|
@ -228,7 +228,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
@ -306,7 +306,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferMirrorOps() {
|
|
|
|
|
Status GatherPInfo::InferMirrorOps() {
|
|
|
|
|
// There is no mirror operators for manual split
|
|
|
|
|
if (manual_split_) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -336,7 +336,7 @@ Status GatherV2PInfo::InferMirrorOps() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferDevMatrixShape() {
|
|
|
|
|
Status GatherPInfo::InferDevMatrixShape() {
|
|
|
|
|
dev_matrix_shape_.clear();
|
|
|
|
|
out_dev_matrix_shape_.clear();
|
|
|
|
|
// infer input dev_matrix_shape
|
|
|
|
@ -386,7 +386,7 @@ Status GatherV2PInfo::InferDevMatrixShape() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GatherV2PInfo::InferInputsTensorMap() {
|
|
|
|
|
void GatherPInfo::InferInputsTensorMap() {
|
|
|
|
|
// infer input tensor map
|
|
|
|
|
// param_strategy(axis) != 1
|
|
|
|
|
size_t param_size = inputs_shape_.at(0).size();
|
|
|
|
@ -413,7 +413,7 @@ void GatherV2PInfo::InferInputsTensorMap() {
|
|
|
|
|
inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GatherV2PInfo::InferOutputsTensorMap() {
|
|
|
|
|
void GatherPInfo::InferOutputsTensorMap() {
|
|
|
|
|
// infer output tensor map
|
|
|
|
|
size_t param_size = inputs_shape_.at(0).size();
|
|
|
|
|
size_t index_size = inputs_shape_.at(1).size();
|
|
|
|
@ -460,7 +460,7 @@ void GatherV2PInfo::InferOutputsTensorMap() {
|
|
|
|
|
outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferTensorMap() {
|
|
|
|
|
Status GatherPInfo::InferTensorMap() {
|
|
|
|
|
if (manual_split_) {
|
|
|
|
|
inputs_tensor_map_.push_back({1, 0});
|
|
|
|
|
inputs_tensor_map_.push_back({-1, 1});
|
|
|
|
@ -472,7 +472,7 @@ Status GatherV2PInfo::InferTensorMap() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferTensorInfo() {
|
|
|
|
|
Status GatherPInfo::InferTensorInfo() {
|
|
|
|
|
// infer tensor shape
|
|
|
|
|
Shape input_shape = inputs_shape_.at(0);
|
|
|
|
|
Shape input_index_shape = inputs_shape_.at(1);
|
|
|
|
@ -505,7 +505,7 @@ Status GatherV2PInfo::InferTensorInfo() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferBias() {
|
|
|
|
|
Status GatherPInfo::InferBias() {
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
int64_t rank = g_device_manager->rank_index_in_stage();
|
|
|
|
|
auto input_shape = inputs_shape_.at(0);
|
|
|
|
@ -559,7 +559,7 @@ Status GatherV2PInfo::InferBias() {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferOffset() {
|
|
|
|
|
Status GatherPInfo::InferOffset() {
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
size_t rank = g_device_manager->rank_index_in_stage();
|
|
|
|
|
|
|
|
|
@ -580,7 +580,7 @@ Status GatherV2PInfo::InferOffset() {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferGroup() {
|
|
|
|
|
Status GatherPInfo::InferGroup() {
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|
size_t dim = LongToSize(axis_);
|
|
|
|
|
if (param_strategy.at(LongToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
|
|
|
|
@ -610,7 +610,7 @@ Status GatherV2PInfo::InferGroup() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
Status GatherPInfo::InferForwardCommunication() {
|
|
|
|
|
if (manual_split_) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
@ -647,7 +647,7 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|
|
|
|
Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|
|
|
|
GenerateGraph gen_g = GenerateGraph();
|
|
|
|
|
if (gen_g.Init(cnode) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
|
|
|
@ -705,7 +705,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
|
|
|
|
ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
|
|
|
|
|
if (manual_split_ && target_ != CPU) {
|
|
|
|
|
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
|
|
|
@ -724,7 +724,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
|
|
|
|
return replace_graph_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::ComputeReplaceOp() {
|
|
|
|
|
Status GatherPInfo::ComputeReplaceOp() {
|
|
|
|
|
int64_t bias = 0;
|
|
|
|
|
if (manual_split_) {
|
|
|
|
|
if (InferOffset() != SUCCESS) {
|
|
|
|
@ -752,7 +752,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
|
|
|
|
|
Status GatherPInfo::Init(const StrategyPtr &strategy) {
|
|
|
|
|
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Init failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
@ -765,7 +765,7 @@ Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|
|
|
|
Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|
|
|
|
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
|
|
|
|
@ -783,9 +783,9 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
|
|
|
|
Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::GenerateStrategies(int64_t stage_id) {
|
|
|
|
|
Status GatherPInfo::GenerateStrategies(int64_t stage_id) {
|
|
|
|
|
if (GetAttrs() != SUCCESS) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
@ -814,7 +814,7 @@ Status GatherV2PInfo::GenerateStrategies(int64_t stage_id) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
|
|
|
|
|
std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() {
|
|
|
|
|
if (GetAttrs() != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
|
|
|
|
|
}
|
|
|
|
|