|
|
|
@ -1378,6 +1378,13 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Walks `all_nodes` and, in the lines visible here, chooses `strategyPtr` for a node from
// one of three sources: batch strategies generated by the operator, the strategy map
// loaded from the strategy checkpoint, or ExtractStrategy(attrs).
// NOTE(review): this text is a rendered diff, not a contiguous function body; each
// `@ ... @` line below marks a place where lines of the function are not shown.
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
// load strategy map from checkpoint
|
|
|
|
|
StrategyMap stra_map;
|
|
|
|
|
// The map is only loaded when LoadCheckPointOn() is true; a non-SUCCESS Load() is
// reported through MS_LOG(EXCEPTION).
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
// NOTE(review): the body of this check is cut off by the hunk boundary that follows.
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
@ -1414,7 +1421,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
(void)cnode->set_operator_info(operator_);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// NOTE(review): this condition and the later
// `!StrategyFound(attrs) && !load_strategy_from_ckpt` look like the before/after
// versions of the same `if` in the diff — confirm which one the real file keeps.
if (!StrategyFound(attrs)) {
|
|
|
|
|
// load strategy checkpoint
|
|
|
|
|
// key of strategy map
|
|
|
|
|
std::string instance_name = prim->instance_name();
|
|
|
|
|
// Checkpoint key: scope name + CONNSYMBOL + primitive instance name.
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
|
|
|
|
|
// True only when loading is enabled and the loaded map has an entry for this key.
bool load_strategy_from_ckpt =
|
|
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
|
|
|
|
|
|
|
|
|
// Batch-parallel fallback: taken only when neither attrs nor the checkpoint supply a strategy.
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
|
|
|
|
|
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
|
|
|
|
|
<< " is empty, using batch parallel";
|
|
|
|
|
std::shared_ptr<std::vector<Dimensions>> strategy_v_ptr = operator_->GenerateBatchStrategies();
|
|
|
|
@ -1432,6 +1446,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
|
|
|
|
|
<< attrs[GEN_STRATEGY]->ToString();
|
|
|
|
|
strategyPtr = NewStrategy(0, *strategy_v_ptr);
|
|
|
|
|
// Tested before the attrs branch, so a checkpoint entry is used even when
// StrategyFound(attrs) is true.
} else if (load_strategy_from_ckpt) {
|
|
|
|
|
strategyPtr = stra_map[strategy_key_name];
|
|
|
|
|
// Remaining case: attrs carry a strategy and no checkpoint entry applies.
} else {
|
|
|
|
|
strategyPtr = ExtractStrategy(attrs);
|
|
|
|
|
}
|
|
|
|
@ -2022,53 +2038,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE(review): from here to the closing brace after `return false;` the diff interleaves
// two definitions: CheckpointStrategy (starting here) and NodeWithParameter (starting
// further down). Presumably one is removed text and the other added text — confirm
// against the real file before relying on the combined text.
//
// CheckpointStrategy, as far as visible: searches the graph from func_graph's return node,
// collects the strategy of every primitive CNode that has an operator_info and a non-empty
// instance name into straMap, and saves the map through StrategyCheckpoint.
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
MS_LOG(INFO) << "Save strategy to checkpoint begin";
|
|
|
|
|
StrategyMap straMap;
|
|
|
|
|
auto ret = func_graph->get_return();
|
|
|
|
|
auto all_nodes = DeepScopedGraphSearch(ret);
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
// Skip anything that is not a CNode whose first input is a Primitive value node.
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
OperatorInfoPtr operator_info = cnode->operator_info();
|
|
|
|
|
if (operator_info) {
|
|
|
|
|
// Primitives with an empty instance name are skipped; the map key below is built from that name.
if (prim->instance_name().empty()) {
|
|
|
|
|
continue;
|
|
|
|
|
// NodeWithParameter: scans node->inputs(); for the first input that is a Parameter with a
// default, returns that default's Python "requires_grad" attribute cast to bool.
// The trailing `return false;` covers the case where no such input is found.
bool NodeWithParameter(const CNodePtr &node) {
|
|
|
|
|
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
|
|
|
|
for (auto input : node_inputs) {
|
|
|
|
|
if (input->isa<Parameter>()) {
|
|
|
|
|
auto input_parameter = input->cast<ParameterPtr>();
|
|
|
|
|
if (input_parameter->has_default()) {
|
|
|
|
|
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
|
|
|
|
|
}
|
|
|
|
|
// CheckpointStrategy lines: record operator_info->strategy() in straMap under
// scope name + CONNSYMBOL + instance name.
std::string instance_name = prim->instance_name();
|
|
|
|
|
StrategyPtr strategyPtr = operator_info->strategy();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node->scope());
|
|
|
|
|
std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
|
|
|
|
|
straMap[node_name] = strategyPtr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// A non-SUCCESS Save() is reported through MS_LOG(EXCEPTION).
if (StrategyCheckpoint::GetInstance().Save(straMap) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE(review): this span interleaves two definitions from a diff: RestoreStrategy (the
// lines that load `straMap` and call set_strategy) and CheckpointStrategy (the lines that
// fill `stra_map` and call Save). Two signatures appear back to back, and the `@ ... @`
// line further down marks lines that are not shown. Presumably one function is the
// removed text and the other the added text — confirm against the real file.
void RestoreStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
MS_LOG(INFO) << "Extract strategy from checkpoint begin";
|
|
|
|
|
StrategyMap straMap;
|
|
|
|
|
// RestoreStrategy lines: load the map, then remove the checkpoint; a non-SUCCESS result
// from either call is reported through MS_LOG(EXCEPTION).
if (StrategyCheckpoint::GetInstance().Load(&straMap) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().RemoveCheckPoint() != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Remove strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
// CheckpointStrategy lines: `stra_map` collects the strategies that are saved at the end.
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
|
|
|
|
|
StrategyMap stra_map;
|
|
|
|
|
auto ret = func_graph->get_return();
|
|
|
|
|
auto all_nodes = DeepScopedGraphSearch(ret);
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
// NOTE(review): two versions of the node filter follow; the second additionally
// requires NodeWithParameter(cnode) to be true for the node to be processed.
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
@ -2076,18 +2068,18 @@ void RestoreStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
OperatorInfoPtr operator_info = cnode->operator_info();
|
|
|
|
|
if (operator_info) {
|
|
|
|
|
// Empty instance name: one version skips the node (`continue`), the other reports an
// error through MS_LOG(EXCEPTION).
if (prim->instance_name().empty()) {
|
|
|
|
|
continue;
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
|
|
|
|
|
}
|
|
|
|
|
std::string instance_name = prim->instance_name();
|
|
|
|
|
StrategyPtr strategyPtr = operator_info->strategy();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node->scope());
|
|
|
|
|
// Map key: scope name + CONNSYMBOL + primitive instance name.
std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
|
|
|
|
|
MS_LOG(INFO) << "Node name is " << node_name;
|
|
|
|
|
// RestoreStrategy lines: if the loaded map has this node name, its strategy is set on
// the operator.
if (straMap.find(node_name) != straMap.end()) {
|
|
|
|
|
StrategyPtr strategyPtr = straMap[node_name];
|
|
|
|
|
operator_info->set_strategy(strategyPtr);
|
|
|
|
|
}
|
|
|
|
|
// CheckpointStrategy line: store the operator's strategy under node_name for saving.
stra_map[node_name] = strategyPtr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// A non-SUCCESS Save() is reported through MS_LOG(EXCEPTION).
if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
@ -2264,14 +2256,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|
|
|
|
// extract shape and strategy, set operator_info
|
|
|
|
|
ExtractInformation(all_nodes);
|
|
|
|
|
ReshapeInit(all_nodes);
|
|
|
|
|
// extract strategy from checkpoint for multi-train
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().CheckPointOn() && StrategyCheckpoint::GetInstance().CheckPointExit()) {
|
|
|
|
|
RestoreStrategy(root);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// save strategy as checkpoint for multi-train
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().CheckPointOn() &&
|
|
|
|
|
StrategyCheckpoint::GetInstance().GetCurrentTrainTime() < StrategyCheckpoint::GetInstance().GetTrainTimes()) {
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
|
|
|
|
CheckpointStrategy(root);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|