|
|
|
@ -48,28 +48,31 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{
|
|
|
|
|
|
|
|
|
|
bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {
|
|
|
|
|
if (src_shape.empty()) {
|
|
|
|
|
std::string error = "Failed to transpose, empty src shape";
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape");
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto dim : src_shape) {
|
|
|
|
|
if (dim < 0) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to transpose, negative dim in src shape %s", ShapeToString(src_shape).c_str());
|
|
|
|
|
std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape));
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (perm_arg.size() != src_shape.size()) {
|
|
|
|
|
GELOGE(PARAM_INVALID,
|
|
|
|
|
"Failed to transpose, the size of src shape(%zu) and"
|
|
|
|
|
" perm arg(%zu) are different",
|
|
|
|
|
src_shape.size(), perm_arg.size());
|
|
|
|
|
std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) +
|
|
|
|
|
" and perm arg" + FmtToStr(perm_arg.size()) + " are different";
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> exists(perm_arg.size());
|
|
|
|
|
for (auto perm : perm_arg) {
|
|
|
|
|
if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to transpose, duplicated perm arg %ld, perm arg %s", perm,
|
|
|
|
|
JoinToString(perm_arg).c_str());
|
|
|
|
|
std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) +
|
|
|
|
|
", perm arg " + FmtToStr(JoinToString(perm_arg));
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -192,9 +195,10 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> &
|
|
|
|
|
}
|
|
|
|
|
auto expected_shape = TransShapeByPerm(src_shape, perm_arg);
|
|
|
|
|
if (dst_shape != expected_shape) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to trans axis for perm_arg %s, invalid dst shape %s, expect %s",
|
|
|
|
|
ShapeToString(perm_arg).c_str(), ShapeToString(dst_shape).c_str(), ShapeToString(expected_shape).c_str());
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
std::string error = "Failed to trans axis for perm_arg" +
|
|
|
|
|
FmtToStr(ShapeToString(perm_arg)) + ", invalid dst shape" +
|
|
|
|
|
FmtToStr(ShapeToString(dst_shape)) + ", expect" + FmtToStr(ShapeToString(expected_shape));
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return Transpose(data, src_shape, src_data_type, perm_arg, result);
|
|
|
|
@ -203,14 +207,18 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> &
|
|
|
|
|
Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) {
|
|
|
|
|
auto dst_iter = perm_args.find(src_format);
|
|
|
|
|
if (dst_iter == perm_args.end()) {
|
|
|
|
|
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str());
|
|
|
|
|
std::string error = "Failed to trans shape, do not support transpose from format " +
|
|
|
|
|
FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " +
|
|
|
|
|
FmtToStr(TypeUtils::FormatToSerialString(dst_format));
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
|
|
|
|
|
return UNSUPPORTED;
|
|
|
|
|
}
|
|
|
|
|
auto iter = dst_iter->second.find(dst_format);
|
|
|
|
|
if (iter == dst_iter->second.end()) {
|
|
|
|
|
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str());
|
|
|
|
|
std::string error = "Failed to trans shape, do not support transpose from format " +
|
|
|
|
|
FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " +
|
|
|
|
|
FmtToStr(TypeUtils::FormatToSerialString(dst_format));
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
|
|
|
|
|
return UNSUPPORTED;
|
|
|
|
|
}
|
|
|
|
|
perm = iter->second;
|
|
|
|
@ -223,11 +231,7 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult &
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
if (args.dst_shape != expected_shape) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, invalid dst shape %s, expect %s",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(),
|
|
|
|
|
ShapeToString(expected_shape).c_str());
|
|
|
|
|
if (!IsTransShapeDstCorrect(args, expected_shape)) {
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|