|
|
|
@ -95,6 +95,9 @@ uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map<std::string, g
|
|
|
|
|
for (auto &item : params_list) {
|
|
|
|
|
std::string name = item.first;
|
|
|
|
|
std::shared_ptr<ge::Tensor> ge_tensor_ptr = std::make_shared<ge::Tensor>(item.second);
|
|
|
|
|
if (name.size() > 5 && name.compare(name.size() - 5, 5, "_temp") == 0) {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
TensorPtr tensor_ptr = GetMeTensorTransformed(graph_id, name, ge_tensor_ptr);
|
|
|
|
|
if (tensor_ptr == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Transform ge tensor to me tensor failed";
|
|
|
|
@ -104,6 +107,7 @@ uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map<std::string, g
|
|
|
|
|
param_dict["data"] = tensor_ptr;
|
|
|
|
|
parameter_list.append(param_dict);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
py::bool_ ret =
|
|
|
|
|
parse::python_adapter::CallPyFn(PYTHON_MOD_CALLBACK_MODULE, PYTHON_FUN_PROCESS_CHECKPOINT, parameter_list);
|
|
|
|
|
auto bool_ret = py::cast<bool>(ret);
|
|
|
|
|