|
|
|
@ -29,9 +29,8 @@
|
|
|
|
|
namespace ge {
|
|
|
|
|
namespace formats {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr int64_t kCubeN = 16;
|
|
|
|
|
constexpr int64_t kDim = 1;
|
|
|
|
|
|
|
|
|
|
constexpr int64_t kCubeN = 16;
|
|
|
|
|
static int64_t Measure(int64_t x, int64_t y) {
|
|
|
|
|
int64_t z = y;
|
|
|
|
|
while (x % y != 0) {
|
|
|
|
@ -266,7 +265,7 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,
|
|
|
|
|
"groups are %ld %ld %ld",cin_ori, cout_ori, groups);
|
|
|
|
|
return GRAPH_FAILED;
|
|
|
|
|
}
|
|
|
|
|
const int64_t cube_k = GetCubeSizeByDataType(data_type);
|
|
|
|
|
const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type);
|
|
|
|
|
int64_t e_mult = std::min(
|
|
|
|
|
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)),
|
|
|
|
|
groups);
|
|
|
|
@ -277,16 +276,18 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,
|
|
|
|
|
int64_t dim_cin = cin_opt / cube_k;
|
|
|
|
|
int64_t data_size = GetSizeByDataType(args.src_data_type);
|
|
|
|
|
int64_t size_output_data = g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size;
|
|
|
|
|
GE_CHK_BOOL_EXEC_NOLOG(size_output_data != 0, result.length = static_cast<size_t>(size_output_data);
|
|
|
|
|
return SUCCESS;);
|
|
|
|
|
if(size_output_data == 0){
|
|
|
|
|
result.length = static_cast<size_t>(size_output_data);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
errno_t ret = EOK;
|
|
|
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>());
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
|
|
|
|
|
dst == nullptr,
|
|
|
|
|
if (dst == nullptr) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data);
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;);
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(),
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data);
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;
|
|
|
|
|
}
|
|
|
|
|
ret = memset_s(dst.get(), static_cast<size_t>(size_output_data), 0, static_cast<size_t>(size_output_data));
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory, ret is %d", ret);
|
|
|
|
|