|
|
|
@ -86,9 +86,9 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
|
|
|
|
|
hw_shape.push_back(DIM_DEFAULT_VALUE);
|
|
|
|
|
hw_shape.push_back(src_shape[kNdDimIndexN]);
|
|
|
|
|
if (!IsShapeValid(dst_shape)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s",
|
|
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
|
|
|
|
|
ShapeToString(dst_shape).c_str());
|
|
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
default:
|
|
|
|
@ -106,9 +106,9 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
|
|
|
|
|
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]);
|
|
|
|
|
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]);
|
|
|
|
|
if (!IsShapeValid(dst_shape)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s",
|
|
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
|
|
|
|
|
ShapeToString(dst_shape).c_str());
|
|
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
@ -118,14 +118,14 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) {
|
|
|
|
|
ShapeVector expect_src_shape;
|
|
|
|
|
auto ret = TransShapeToFracZz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Trans shape from %s to %s, shape %s to %s, data type %s failed",
|
|
|
|
|
GELOGE(ret, "Trans shape from %s to %s, shape %s to %s, data type %s failed",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), ShapeToString(args.dst_shape).c_str(),
|
|
|
|
|
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
if (!IsTransShapeSrcCorrect(args, expect_src_shape)) {
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
@ -140,10 +140,10 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>());
|
|
|
|
|
if (dst == nullptr) {
|
|
|
|
|
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
|
|
|
|
|
return OUT_OF_MEMORY;
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;
|
|
|
|
|
}
|
|
|
|
|
// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D
|
|
|
|
|
auto times = hw_shape.at(kNdDimIndexN);
|
|
|
|
@ -179,8 +179,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
|
|
|
|
|
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
|
|
|
|
static_cast<size_t>(size * w0));
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto w1_head = num_w1 * w0;
|
|
|
|
@ -195,8 +195,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
|
|
|
|
|
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
|
|
|
|
static_cast<size_t>(size));
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -217,10 +217,10 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>());
|
|
|
|
|
if (dst == nullptr) {
|
|
|
|
|
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
|
|
|
|
|
return OUT_OF_MEMORY;
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D
|
|
|
|
@ -257,8 +257,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
|
|
|
|
|
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
|
|
|
|
static_cast<size_t>(size * w0));
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto w1_head = num_w1 * w0;
|
|
|
|
@ -273,8 +273,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
|
|
|
|
|
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
|
|
|
|
static_cast<size_t>(size));
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret);
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -287,13 +287,19 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult &result) {
|
|
|
|
|
if (!IsDataTypeSupport(args.src_data_type) || !CheckShape(args.src_format, args.src_shape) ||
|
|
|
|
|
!IsShapeValid(args.dst_shape)) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
if (!IsDataTypeSupport(args.src_data_type)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
|
|
|
|
|
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
|
|
|
|
|
return ACL_ERROR_GE_DATATYPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
|
|
|
|
|
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
@ -306,7 +312,7 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult &
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
if (!IsTransShapeDstCorrect(args, expect_shape)) {
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
return TransFormatFromNdToFracZz(args, result, hw_shape);
|
|
|
|
|
}
|
|
|
|
@ -314,31 +320,38 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult &
|
|
|
|
|
Status FormatTransferFractalZz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type,
|
|
|
|
|
Format dst_format, ShapeVector &dst_shape) {
|
|
|
|
|
if (!IsDataTypeSupport(data_type)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID,
|
|
|
|
|
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID,
|
|
|
|
|
"Not support trans format from %s to %s, src shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
|
|
|
|
|
ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
|
|
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_DATATYPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
if (!CheckShape(src_format, src_shape)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID,
|
|
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID,
|
|
|
|
|
"Not support trans format from %s to %s, src shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
|
|
|
|
|
ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
|
|
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
ShapeVector hw_shape;
|
|
|
|
|
return TransShapeToFracZz(src_shape, data_type, dst_shape, hw_shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult &result) {
|
|
|
|
|
if (!IsDataTypeSupport(args.src_data_type) || !IsShapeValid(args.src_shape) ||
|
|
|
|
|
!CheckShape(args.dst_format, args.dst_shape)) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
if (!IsDataTypeSupport(args.src_data_type)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
|
|
|
|
|
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
|
|
|
|
|
return ACL_ERROR_GE_DATATYPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
|
|
|
|
|
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID;
|
|
|
|
|
}
|
|
|
|
|
GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
@ -346,8 +359,9 @@ Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult
|
|
|
|
|
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
|
|
|
|
|
|
|
|
|
|
ShapeVector hw_shape;
|
|
|
|
|
if (CheckShapeRelation(args, hw_shape) != SUCCESS) {
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
Status ret = CheckShapeRelation(args, hw_shape);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
return TransFormatFromFracZzToNd(args, result, hw_shape);
|
|
|
|
|
}
|
|
|
|
@ -356,7 +370,7 @@ Status FormatTransferFractalZzND::TransShape(Format src_format, const ShapeVecto
|
|
|
|
|
Format dst_format, ShapeVector &dst_shape) {
|
|
|
|
|
GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported",
|
|
|
|
|
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str());
|
|
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID;
|
|
|
|
|
return ACL_ERROR_GE_FORMAT_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZz, FORMAT_ND, FORMAT_FRACTAL_ZZ)
|
|
|
|
|