|
|
|
@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
}
|
|
|
|
|
// load strategy checkpoint
|
|
|
|
|
// key of strategy map
|
|
|
|
|
std::string strategy_key_name = NodeParameterName(cnode);
|
|
|
|
|
std::string strategy_key_name = "";
|
|
|
|
|
auto param_names = NodeParameterName(cnode);
|
|
|
|
|
if (!param_names.empty()) {
|
|
|
|
|
strategy_key_name = param_names[0].first;
|
|
|
|
|
}
|
|
|
|
|
bool load_strategy_from_ckpt =
|
|
|
|
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
|
|
|
|
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
|
|
|
|
@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string NodeParameterName(const CNodePtr &node) {
|
|
|
|
|
std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node) {
|
|
|
|
|
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
|
|
|
|
for (auto input : node_inputs) {
|
|
|
|
|
std::vector<std::pair<std::string, int>> param_names;
|
|
|
|
|
for (int i = 0; i < UintToInt(node_inputs.size()); ++i) {
|
|
|
|
|
auto input = node_inputs[i];
|
|
|
|
|
if (input->isa<Parameter>()) {
|
|
|
|
|
auto input_parameter = input->cast<ParameterPtr>();
|
|
|
|
|
if (input_parameter->has_default()) {
|
|
|
|
|
input_parameter->name();
|
|
|
|
|
if (ParameterRequireGrad(input_parameter)) {
|
|
|
|
|
param_names.push_back({input_parameter->name(), i});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return "";
|
|
|
|
|
return param_names;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
|
|
|
|
|
StrategyMap stra_map;
|
|
|
|
|
TensorInfoMap tensor_info_map;
|
|
|
|
|
ManualShapeMap manual_shape_map;
|
|
|
|
|
auto ret = func_graph->get_return();
|
|
|
|
|
auto all_nodes = DeepScopedGraphSearch(ret);
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::string param_name = NodeParameterName(cnode);
|
|
|
|
|
if (param_name.empty()) {
|
|
|
|
|
auto param_names = NodeParameterName(cnode);
|
|
|
|
|
if (param_names.empty()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
string param_name = param_names[0].first;
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
|
|
|
|
@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
|
|
|
|
|
StrategyPtr strategyPtr = operator_info->strategy();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node->scope());
|
|
|
|
|
stra_map[param_name] = strategyPtr;
|
|
|
|
|
for (auto param_name_pair : param_names) {
|
|
|
|
|
if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
|
|
|
|
|
}
|
|
|
|
|
if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos ||
|
|
|
|
|
operator_info->name().find(GATHERV2) != std::string::npos) {
|
|
|
|
|
auto gatherv2_info = std::dynamic_pointer_cast<GatherV2PInfo>(operator_info);
|
|
|
|
|
auto param_split_shapes = gatherv2_info->param_split_shapes();
|
|
|
|
|
auto index_offsets = gatherv2_info->index_offsets();
|
|
|
|
|
if (param_split_shapes.size() != index_offsets.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same.";
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::pair<int32_t, int32_t>> manual_shape;
|
|
|
|
|
for (int i = 0; i < UintToInt(param_split_shapes.size()); ++i) {
|
|
|
|
|
manual_shape.push_back({param_split_shapes[i], index_offsets[i]});
|
|
|
|
|
}
|
|
|
|
|
manual_shape_map[param_name] = manual_shape;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|