|
|
|
@ -35,18 +35,24 @@ class RtContextSwitchGuard {
|
|
|
|
|
RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) {
|
|
|
|
|
auto ret = rtCtxGetCurrent(&last_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Call rtCtxGetCurrent failed, device_id:%u, ret:0x%X, when %s",
|
|
|
|
|
device_id, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ret = rtCtxCreate(¤t_, mode, static_cast<int32_t>(device_id));
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, device_id:%u, ret:0x%X, when %s",
|
|
|
|
|
device_id, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ret = rtCtxSetCurrent(current_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, device_id:%u, ret:0x%X, when %s",
|
|
|
|
|
device_id, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -72,6 +78,8 @@ class RtContextSwitchGuard {
|
|
|
|
|
int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
|
|
|
|
|
int64_t var_size = GetSizeByDataType(desc.GetDataType());
|
|
|
|
|
if (var_size <= 0) {
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Data type:%s in desc, it's size:%ld < 0, check invalid when %s",
|
|
|
|
|
TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str(), var_size, __FUNCTION__);
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s",
|
|
|
|
|
TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str());
|
|
|
|
|
return -1;
|
|
|
|
@ -89,6 +97,8 @@ Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_res
|
|
|
|
|
auto ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
|
|
|
|
|
trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, op:%s(%s), size:%lu, ret:0x%X, when %s", var->GetName().c_str(),
|
|
|
|
|
var->GetType().c_str(), trans_result.length, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length);
|
|
|
|
|
return RT_FAILED;
|
|
|
|
|
}
|
|
|
|
@ -110,6 +120,8 @@ Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_pt
|
|
|
|
|
|
|
|
|
|
uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
|
|
|
|
|
if (var_addr == nullptr) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, op:%s(%s), session_id:%lu, when %s",
|
|
|
|
|
RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id, __FUNCTION__);
|
|
|
|
|
GELOGE(INTERNAL_ERROR,
|
|
|
|
|
"Failed to copy var %s from device, cant not get "
|
|
|
|
|
"var addr from logic addr %p",
|
|
|
|
@ -124,6 +136,8 @@ Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_pt
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<uint8_t[]> var_host(new(std::nothrow) uint8_t[var_size_bytes]);
|
|
|
|
|
if (var_host == nullptr) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "New host memory failed, size:%ld, op:%s(%s), session_id:%lu, when %s",
|
|
|
|
|
var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, __FUNCTION__);
|
|
|
|
|
GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes);
|
|
|
|
|
return OUT_OF_MEMORY;
|
|
|
|
|
}
|
|
|
|
@ -131,6 +145,8 @@ Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_pt
|
|
|
|
|
ret = rtMemcpy(reinterpret_cast<void *>(var_host.get()), var_size_bytes, reinterpret_cast<void *>(var_addr),
|
|
|
|
|
var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%ld, op:%s(%s), session_id:%lu, ret:0x%X when %s",
|
|
|
|
|
var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(RT_FAILED,
|
|
|
|
|
"Failed to copy var memory from device, var %s, size %ld,"
|
|
|
|
|
" rt-error-code %u",
|
|
|
|
@ -175,6 +191,12 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats
|
|
|
|
|
TypeUtils::DataTypeToSerialString(data_type).c_str());
|
|
|
|
|
auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Trans format from %s to %s, shape %s to %s failed, data type:%s, ret:%u, when %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(dst_format).c_str(),
|
|
|
|
|
formats::ShapeToString(src_shape).c_str(),
|
|
|
|
|
formats::ShapeToString(dst_shape).c_str(),
|
|
|
|
|
TypeUtils::DataTypeToSerialString(data_type).c_str(), ret, __FUNCTION__);
|
|
|
|
|
GELOGE(INTERNAL_ERROR,
|
|
|
|
|
"Failed to trans format from %s to %s, shape %s to %s, "
|
|
|
|
|
"data type %s error code %u",
|
|
|
|
@ -195,6 +217,10 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats
|
|
|
|
|
auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
|
|
|
|
|
tmp_result);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Trans data type from %s to %s failed, input shape %s, data size %ld, ret:%u, "
|
|
|
|
|
"when %s", TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
|
|
|
|
|
TypeUtils::DataTypeToSerialString(dst_data_type).c_str(),
|
|
|
|
|
formats::ShapeToString(input_shape).c_str(), src_data_size, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u",
|
|
|
|
|
TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
|
|
|
|
|
TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
|
|
|
|
@ -202,6 +228,8 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Trans var data failed, the trans type %s does not supported, check invalid when %s",
|
|
|
|
|
trans_info.node_type.c_str(), __FUNCTION__);
|
|
|
|
|
GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported",
|
|
|
|
|
trans_info.node_type.c_str());
|
|
|
|
|
return UNSUPPORTED;
|
|
|
|
@ -236,6 +264,8 @@ Status ReAssignVarAddr(uint64_t session_id,
|
|
|
|
|
|
|
|
|
|
uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
|
|
|
|
|
if (var_addr == nullptr) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, var_name:%s, session_id:%lu, when %s",
|
|
|
|
|
RT_MEMORY_HBM, var_name.c_str(), session_id, __FUNCTION__);
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str());
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
}
|
|
|
|
@ -263,6 +293,8 @@ Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t
|
|
|
|
|
// Sync var data from device
|
|
|
|
|
std::unique_ptr<uint8_t[]> var_data;
|
|
|
|
|
if (trans_road.empty()) {
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Param trans_road is empty, session_id:%lu, check invalid when %s",
|
|
|
|
|
session_id, __FUNCTION__);
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty.");
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
}
|
|
|
|
@ -314,6 +346,10 @@ Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var
|
|
|
|
|
auto ret = formats::TransDataType(
|
|
|
|
|
{var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Trans data type from %s to %s failed, data size %ld, ret:%u, "
|
|
|
|
|
"when %s", TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(),
|
|
|
|
|
TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(),
|
|
|
|
|
src_data_shape_size, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "trans var data on host failed");
|
|
|
|
|
return ret;
|
|
|
|
|
});
|
|
|
|
@ -329,7 +365,10 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src,
|
|
|
|
|
/// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node.
|
|
|
|
|
/// need to copy the value from var_fp32 to var_fp16.
|
|
|
|
|
/// [opdesc of var_src and var_dst are checked before being passed in, so there is no need to check whether they are nullptr]
|
|
|
|
|
GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, GELOGE(FAILED, "node var is nullptr"); return FAILED);
|
|
|
|
|
GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr,
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Param var_src or var_dst is empty, session_id:%lu, device_id:%u, "
|
|
|
|
|
"check invalid when %s", session_id, device_id, __FUNCTION__);
|
|
|
|
|
GELOGE(FAILED, "node var is nullptr"); return FAILED);
|
|
|
|
|
// src_node output_desc (fp32)
|
|
|
|
|
GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0);
|
|
|
|
|
auto src_data_type = output_desc.GetDataType();
|
|
|
|
@ -447,15 +486,21 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::future<Status> f = executor.commit(
|
|
|
|
|
[](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id) -> Status {
|
|
|
|
|
[](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id,
|
|
|
|
|
const struct ErrorMessage::Context &error_context) -> Status {
|
|
|
|
|
ErrorManager::GetInstance().SetErrorContext(error_context);
|
|
|
|
|
rtError_t rt_ret = rtCtxSetCurrent(ctx);
|
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, session_id:%lu, graph_id:%u, ret:0x%X, when %s",
|
|
|
|
|
session_id, graph_id, rt_ret, __FUNCTION__);
|
|
|
|
|
GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret);
|
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
|
}
|
|
|
|
|
uint32_t allocated_graph_id = 0;
|
|
|
|
|
Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Get allocated GraphId failed, session_id:%lu, graph_id:%u, ret:0x%X, when %s",
|
|
|
|
|
session_id, graph_id, ret, __FUNCTION__);
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(),
|
|
|
|
|
graph_id);
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
@ -480,7 +525,7 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
},
|
|
|
|
|
node, session_id, context, graph_id);
|
|
|
|
|
node, session_id, context, graph_id, ErrorManager::GetInstance().GetErrorContext());
|
|
|
|
|
if (!f.valid()) {
|
|
|
|
|
GELOGE(FAILED, "Future is invalid");
|
|
|
|
|
return FAILED;
|
|
|
|
@ -503,6 +548,8 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
|
|
|
|
|
Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id) {
|
|
|
|
|
GELOGD("CopyVarData start: session_id:%lu.", session_id);
|
|
|
|
|
if (compute_graph == nullptr) {
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, session_id:%lu, device_id:%u, check invalid when %s",
|
|
|
|
|
session_id, device_id, __FUNCTION__);
|
|
|
|
|
GELOGE(FAILED, "compute_graph is nullptr");
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|