|
|
|
@ -1523,7 +1523,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
std::string strategy_key_name = "";
|
|
|
|
|
auto param_names = NodeParameterName(cnode);
|
|
|
|
|
if (!param_names.empty()) {
|
|
|
|
|
strategy_key_name = param_names[0].first;
|
|
|
|
|
strategy_key_name = prim->name() + "_" + param_names[0].first;
|
|
|
|
|
}
|
|
|
|
|
bool load_strategy_from_ckpt =
|
|
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
|
|
|
@ -2214,9 +2214,23 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
|
|
|
|
|
auto input = node_inputs[i];
|
|
|
|
|
if (input->isa<Parameter>()) {
|
|
|
|
|
auto input_parameter = input->cast<ParameterPtr>();
|
|
|
|
|
if (input_parameter->has_default()) {
|
|
|
|
|
if (ParameterRequireGrad(input_parameter)) {
|
|
|
|
|
param_names.push_back({input_parameter->name(), i});
|
|
|
|
|
if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
|
|
|
|
|
param_names.push_back({input_parameter->name(), i});
|
|
|
|
|
}
|
|
|
|
|
} else if (input->isa<CNode>()) {
|
|
|
|
|
CNodePtr cnode = input->cast<CNodePtr>();
|
|
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
return param_names;
|
|
|
|
|
}
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
|
|
|
|
if (prim->name() == CAST && cnode->inputs().size() >= 1) {
|
|
|
|
|
auto cast_input = cnode->inputs()[1];
|
|
|
|
|
if (cast_input->isa<Parameter>()) {
|
|
|
|
|
auto cast_input_parameter = cast_input->cast<ParameterPtr>();
|
|
|
|
|
if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
|
|
|
|
|
param_names.push_back({cast_input_parameter->name(), i});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -2224,14 +2238,11 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
|
|
|
|
|
return param_names;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
|
|
|
|
|
StrategyMap stra_map;
|
|
|
|
|
TensorInfoMap tensor_info_map;
|
|
|
|
|
ManualShapeMap manual_shape_map;
|
|
|
|
|
auto ret = func_graph->get_return();
|
|
|
|
|
auto all_nodes = DeepScopedGraphSearch(ret);
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
@ -2253,7 +2264,8 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
|
|
|
|
|
StrategyPtr strategyPtr = operator_info->strategy();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node->scope());
|
|
|
|
|
stra_map[param_name] = strategyPtr;
|
|
|
|
|
std::string stratey_key_name = prim->name() + "_" + param_name;
|
|
|
|
|
stra_map[stratey_key_name] = strategyPtr;
|
|
|
|
|
for (auto param_name_pair : param_names) {
|
|
|
|
|
if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
|
|
|
|
|
continue;
|
|
|
|
@ -2547,7 +2559,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|
|
|
|
|
|
|
|
|
// save strategy as checkpoint for multi-train
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
|
|
|
|
CheckpointStrategy(root);
|
|
|
|
|
CheckpointStrategy(all_nodes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
HandleSymbolicKeyInstance(root, all_nodes);
|
|
|
|
|