|
|
@ -24,6 +24,7 @@
|
|
|
|
#include "common/fp16_t.h"
|
|
|
|
#include "common/fp16_t.h"
|
|
|
|
#include "common/ge/ge_util.h"
|
|
|
|
#include "common/ge/ge_util.h"
|
|
|
|
#include "framework/common/debug/ge_log.h"
|
|
|
|
#include "framework/common/debug/ge_log.h"
|
|
|
|
|
|
|
|
#include "framework/common/debug/log.h"
|
|
|
|
#include "graph/utils/type_utils.h"
|
|
|
|
#include "graph/utils/type_utils.h"
|
|
|
|
#include "securec.h"
|
|
|
|
#include "securec.h"
|
|
|
|
|
|
|
|
|
|
|
@ -123,9 +124,9 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
|
|
|
|
std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type);
|
|
|
|
std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type);
|
|
|
|
auto iter = trans_mode_map.find(trans_info);
|
|
|
|
auto iter = trans_mode_map.find(trans_info);
|
|
|
|
if (iter == trans_mode_map.end()) {
|
|
|
|
if (iter == trans_mode_map.end()) {
|
|
|
|
std::string error = "Failed to trans data from datatype [" +
|
|
|
|
std::string error = "Failed to trans data from datatype " +
|
|
|
|
TypeUtils::FormatToSerialString(args.src_data_type) + "] to " + "[" +
|
|
|
|
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " +
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_data_type) + "], it is not supported.";
|
|
|
|
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported.";
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
|
|
|
|
return UNSUPPORTED;
|
|
|
|
return UNSUPPORTED;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -133,14 +134,14 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
|
|
|
|
|
|
|
|
|
|
|
|
int size = GetSizeByDataType(args.dst_data_type);
|
|
|
|
int size = GetSizeByDataType(args.dst_data_type);
|
|
|
|
if (size <= 0) {
|
|
|
|
if (size <= 0) {
|
|
|
|
std::string error = "Failed to calc size from data type[" +
|
|
|
|
std::string error = "Failed to calc size from data type" +
|
|
|
|
TypeUtils::DataTypeToSerialString(args.dst_data_type) + "], it is not supported.";
|
|
|
|
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported.";
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
return PARAM_INVALID;
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) {
|
|
|
|
if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) {
|
|
|
|
std::string error = "args.src_data_size[" + std::to_string(args.src_data_size) +
|
|
|
|
std::string error = "args.src_data_size" + FmtToStr(args.src_data_size) +
|
|
|
|
"] or data type size[" + std::to_string(size) + "is too big";
|
|
|
|
" or data type size" + FmtToStr(std::to_string(size)) + " is too big";
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
return PARAM_INVALID;
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -158,9 +159,10 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) {
|
|
|
|
if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) {
|
|
|
|
std::string error = "Failed to cast data from datatype [" +
|
|
|
|
std::string error = "Failed to cast data from datatype " +
|
|
|
|
TypeUtils::FormatToSerialString(args.src_data_type) + "] to " + "[" +
|
|
|
|
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " +
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_data_type) + "], data size is " + std::to_string(args.src_data_size);
|
|
|
|
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " +
|
|
|
|
|
|
|
|
FmtToStr(std::to_string(args.src_data_size));
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str());
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str());
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|