|
|
@ -1512,7 +1512,87 @@ Status ValidStageCheck(const std::vector<int32_t> &stages, int32_t strategy_stag
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
// find previous parallel care node.
|
|
|
|
|
|
|
|
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(unique_ids);
|
|
|
|
|
|
|
|
// if previous node is a parameter, handle it in the outsize.
|
|
|
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
CNodePtr cnode = node->cast<CNodePtr>();
|
|
|
|
|
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
|
|
|
|
|
|
|
if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
|
|
|
|
|
|
|
|
unique_ids->push_back(cnode->UniqueId());
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
bool find = false;
|
|
|
|
|
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
|
|
|
|
|
|
|
|
if (prim->name() == DEPEND && index != 1) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (FindPreNodes(cnode->inputs()[index], unique_ids)) {
|
|
|
|
|
|
|
|
find = true;
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return find;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, vector<std::string> *unique_ids) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(unique_ids);
|
|
|
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
|
|
|
|
|
|
|
if (prim->name() == RETURN) {
|
|
|
|
|
|
|
|
if (!FindPreNodes(cnode, unique_ids)) {
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(operator_);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
|
|
|
StrategyPtr strategyPtr;
|
|
|
|
|
|
|
|
std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(strategy_v_ptr);
|
|
|
|
|
|
|
|
strategyPtr = NewStrategy(0, *strategy_v_ptr);
|
|
|
|
|
|
|
|
std::vector<ValuePtr> elements;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
|
|
|
|
|
|
|
|
elements.push_back(MakeValue((*strategy_v_ptr)[i]));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
|
|
|
|
|
|
|
|
// display the strategy generated by batch parallel
|
|
|
|
|
|
|
|
auto attrs = prim->attrs();
|
|
|
|
|
|
|
|
attrs[GEN_STRATEGY] = strategy;
|
|
|
|
|
|
|
|
(void)prim->SetAttrs(attrs);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "prim " << prim->name() << " batch parallel strategy is " << attrs[GEN_STRATEGY]->ToString();
|
|
|
|
|
|
|
|
return strategyPtr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
|
|
|
|
|
|
|
|
auto strategys = strategyPtr->GetInputDim();
|
|
|
|
|
|
|
|
for (size_t i = 0; i < strategys.size(); ++i) {
|
|
|
|
|
|
|
|
for (size_t j = 0; j < strategys[i].size(); ++j) {
|
|
|
|
|
|
|
|
strategys[i][j] = 1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
strategyPtr->ResetInputs(strategys);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
|
|
|
|
// load strategy map from checkpoint
|
|
|
|
// load strategy map from checkpoint
|
|
|
|
StrategyMap stra_map;
|
|
|
|
StrategyMap stra_map;
|
|
|
|
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
|
|
|
|
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
|
|
|
@ -1520,7 +1600,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
vector<std::string> last_forward_node_ids;
|
|
|
|
|
|
|
|
if (!is_training) {
|
|
|
|
|
|
|
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
|
|
|
|
|
|
|
|
}
|
|
|
|
// Get global rank after the checkpoint?
|
|
|
|
// Get global rank after the checkpoint?
|
|
|
|
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
|
|
|
|
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
|
|
|
|
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
|
|
|
|
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
|
|
|
@ -1572,30 +1656,22 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
bool load_strategy_from_ckpt =
|
|
|
|
bool load_strategy_from_ckpt =
|
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
|
|
|
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
|
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
|
|
|
|
|
|
|
|
last_forward_node_ids.end();
|
|
|
|
|
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
|
|
|
|
|
|
|
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
|
|
|
|
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
|
|
|
|
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
|
|
|
|
<< " is empty, using batch parallel";
|
|
|
|
<< " is empty, using batch parallel";
|
|
|
|
std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
|
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
|
|
|
|
if (strategy_v_ptr == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<ValuePtr> elements;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
|
|
|
|
|
|
|
|
elements.push_back(MakeValue((*strategy_v_ptr)[i]));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
|
|
|
|
|
|
|
|
// display the strategy generated by batch parallel
|
|
|
|
|
|
|
|
attrs[GEN_STRATEGY] = strategy;
|
|
|
|
|
|
|
|
(void)prim->SetAttrs(attrs);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
|
|
|
|
|
|
|
|
<< attrs[GEN_STRATEGY]->ToString();
|
|
|
|
|
|
|
|
strategyPtr = NewStrategy(0, *strategy_v_ptr);
|
|
|
|
|
|
|
|
} else if (load_strategy_from_ckpt) {
|
|
|
|
} else if (load_strategy_from_ckpt) {
|
|
|
|
strategyPtr = stra_map[strategy_key_name];
|
|
|
|
strategyPtr = stra_map[strategy_key_name];
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
strategyPtr = ExtractStrategy(attrs);
|
|
|
|
strategyPtr = ExtractStrategy(attrs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (strategyPtr != nullptr) {
|
|
|
|
if (strategyPtr != nullptr) {
|
|
|
|
|
|
|
|
if (is_last_nodes && full_batch) {
|
|
|
|
|
|
|
|
SetLastNodeStrategy(strategyPtr);
|
|
|
|
|
|
|
|
}
|
|
|
|
(*operator_).set_stage_id(strategyPtr->GetInputStage());
|
|
|
|
(*operator_).set_stage_id(strategyPtr->GetInputStage());
|
|
|
|
MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id();
|
|
|
|
MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id();
|
|
|
|
if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
|
|
|
|
if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
|
|
|
@ -2854,7 +2930,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// extract shape and strategy, set operator_info
|
|
|
|
// extract shape and strategy, set operator_info
|
|
|
|
ExtractInformation(all_nodes);
|
|
|
|
ExtractInformation(all_nodes, root->has_flag(TRAINING));
|
|
|
|
ReshapeInit(all_nodes);
|
|
|
|
ReshapeInit(all_nodes);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|