|
|
|
@ -50,21 +50,21 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{
|
|
|
|
|
bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {
|
|
|
|
|
if (src_shape.empty()) {
|
|
|
|
|
std::string error = "Failed to transpose, empty src shape";
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape");
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str());
|
|
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to transpose, empty src shape");
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto dim : src_shape) {
|
|
|
|
|
if (dim < 0) {
|
|
|
|
|
std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape));
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (perm_arg.size() != src_shape.size()) {
|
|
|
|
|
std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) +
|
|
|
|
|
" and perm arg" + FmtToStr(perm_arg.size()) + " are different";
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -73,7 +73,7 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in
|
|
|
|
|
if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) {
|
|
|
|
|
std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) +
|
|
|
|
|
", perm arg " + FmtToStr(JoinToString(perm_arg));
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
|
|
|
|
|
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -82,11 +82,11 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in
|
|
|
|
|
bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type,
|
|
|
|
|
const std::vector<int64_t> &perm_arg) {
|
|
|
|
|
if (src == nullptr) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to transpose, the src is null");
|
|
|
|
|
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to transpose, the src is null");
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (GetSizeByDataType(src_data_type) < 0) {
|
|
|
|
|
GELOGE(UNSUPPORTED, "Failed to transpose, the data type %s is not support",
|
|
|
|
|
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to transpose, the data type %s is not support",
|
|
|
|
|
TypeUtils::DataTypeToSerialString(src_data_type).c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|