From: @zhukun2020
Reviewed-by: 
Signed-off-by:
pull/1281/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 48ee563ebb

@ -29,6 +29,25 @@
namespace ge {
namespace formats {
namespace {
constexpr int64_t kDim = 1;
static int64_t Measure(int64_t x, int64_t y) {
int64_t z = y;
while (x % y != 0) {
z = x % y;
x = y;
y = z;
}
return z;
}
// least common multiple
static int64_t Lcm(int64_t a, int64_t b) {
if (b == 0) {
return -1;
}
int64_t temp = (a * b) / (Measure(a, b));
return temp;
}
Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; }
/**
@ -61,6 +80,35 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_
return SUCCESS;
}
Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape,
int64_t groups) {
auto c0 = GetCubeSizeByDataType(data_type);
if (c0 < 0) {
return ACL_ERROR_GE_DATATYPE_INVALID;
}
int64_t cin_ori = c;
int64_t cout_ori = n / groups;
int64_t cube_k = GetCubeSizeByDataType(data_type);
int64_t e_mult = std::min(
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)),
groups);
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = Ceil(groups, e_mult);
auto n1 = Ceil(cout_ori * e_mult, static_cast<int64_t>(kCubeSize));
dst_shape.clear();
dst_shape.push_back(g_dim * c1_dim * h * w);
dst_shape.push_back(n1);
dst_shape.push_back(16);
dst_shape.push_back(cube_k);
if (!IsShapeValid(dst_shape)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
return SUCCESS;
}
Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) {
if (!CheckShapeValid(src_shape, kNchwDimsNum)) {
return ACL_ERROR_GE_SHAPE_INVALID;
@ -86,6 +134,21 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t
return TransShapeToFz(n, c, h, w, data_type, dst_shape);
}
Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape
, int64_t groups){
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
return ACL_ERROR_GE_SHAPE_INVALID;
}
auto h = src_shape.at(kHwcnH);
auto w = src_shape.at(kHwcnW);
auto c = src_shape.at(kHwcnC);
auto n = src_shape.at(kHwcnN);
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups);
}
Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) {
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) {
return ACL_ERROR_GE_SHAPE_INVALID;
@ -189,6 +252,80 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
return SUCCESS;
}
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, int64_t groups){
int64_t h_dim = args.src_shape[kHwcnH];
int64_t w_dim = args.src_shape[kHwcnW];
int64_t c_dim = args.src_shape[kHwcnC];
int64_t n_dim = args.src_shape[kHwcnN];
int64_t cin_ori = c_dim;
int64_t cout_ori = n_dim / groups;
if (cin_ori == 0 || cout_ori == 0) {
GELOGE(GRAPH_FAILED, "Cin_ori, cout_ori must not be equal 0, and current cin_ori, cout_ori,"
"groups are %ld %ld %ld",cin_ori, cout_ori, groups);
return GRAPH_FAILED;
}
const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type);
int64_t e_mult = std::min(
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)),
groups);
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
int64_t cout_opt = Ceil(e_mult * cout_ori, static_cast<int64_t>(kCubeSize)) * static_cast<int64_t>(kCubeSize);
int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = Ceil(groups, e_mult);
int64_t dim_cin = cin_opt / cube_k;
int64_t data_size = GetSizeByDataType(args.src_data_type);
int64_t size_output_data = g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size;
if (size_output_data == 0) {
result.length = static_cast<size_t>(size_output_data);
return SUCCESS;
}
errno_t ret = EOK;
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data);
return ACL_ERROR_GE_MEMORY_ALLOCATION;
}
ret = memset_s(dst.get(), static_cast<size_t>(size_output_data), 0, static_cast<size_t>(size_output_data));
if (ret != EOK) {
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory, ret is %d", ret);
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}
for (int64_t g = 0; g < groups; g++) {
for (int64_t d = 0; d < kDim; d++) {
for (int64_t c = 0; c < c_dim; c++) {
for (int64_t h = 0; h < h_dim; h++) {
for (int64_t w = 0; w < w_dim; w++) {
for (int64_t n = 0; n < cout_ori; n++) {
int64_t e_val = g % e_mult;
int64_t dst_ci = e_val * cin_ori + c;
int64_t dst_co = e_val * cout_ori + n;
int64_t src_co = g * cout_ori + n;
int64_t tempory = dst_ci % cube_k;
int64_t srx_inx = 0;
int64_t dst_inx = (g / e_mult) * kDim * c1_dim * h_dim * w_dim * cout_opt * cube_k +
d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
(dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k +
h * w_dim * cout_opt * cube_k + w * cout_opt * cube_k +
dst_co * cube_k + tempory;
srx_inx = d * h_dim * w_dim * c_dim * n_dim + h * w_dim * c_dim * n_dim +
w * c_dim * n_dim + c * n_dim + src_co;
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_inx * data_size);
const char *src_data = reinterpret_cast<const char *>(args.data + srx_inx * data_size);
for (int64_t index = 0; index < data_size; index++) {
*dst_data++ = *src_data++;
}
}
}
}
}
}
}
result.data = dst;
result.length = static_cast<size_t>(size_output_data);
return SUCCESS;
}
Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
int64_t h = args.src_shape[kHwcnH];
int64_t w = args.src_shape[kHwcnW];
@ -363,15 +500,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatNhwcToFz(args, result);
}
if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) {
if ((args.src_format == FORMAT_HWCN) && (GetPrimaryFormat(args.dst_format) == FORMAT_FRACTAL_Z)) {
if (GetSubFormat(args.dst_format) > 1) {
return TransFormatHwcnToFzWithGroups(args, result, GetSubFormat(args.dst_format));
}
return TransFormatHwcnToFz(args, result);
}
if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatFromNchwToFz(args, result);
}
return ACL_ERROR_GE_FORMAT_INVALID;
}
@ -384,7 +522,10 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
if ((src_format == FORMAT_HWCN) && (GetPrimaryFormat(dst_format) == FORMAT_FRACTAL_Z)) {
if (GetSubFormat(dst_format) > 1) {
return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, GetSubFormat(dst_format));
}
return TransShapeHwcnToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {

@ -34427,6 +34427,240 @@ TEST_F(UtestFormatTransferHwcnFz, fp32_2c_2n_pad) {
}
}
TEST_F(UtestFormatTransferHwcnFz, fp16_1c_1n_with_groups) {
uint16_t data[1 * 1 * 1 * 2] = {19, 88};
uint16_t ret[1 * 1 * 16 * 16] ={19 , 0, 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 88, 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0};
FormatTransferFractalZ transfer;
ge::Format old_format = FORMAT_FRACTAL_Z;
int32_t groups = 2;
ge::Format new_format = static_cast<ge::Format>(ge::GetFormatFromSub(old_format, groups));
TransArgs args{
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, new_format, std::vector<int64_t>({1, 1, 1, 2}),
std::vector<int64_t>({1, 1, 16, 16}), DT_FLOAT16};
TransResult result;
EXPECT_EQ(transfer.TransFormat(args, result), SUCCESS);
EXPECT_EQ(result.length, sizeof(ret) / sizeof(ret[0]) * 2);
for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
EXPECT_EQ((reinterpret_cast<uint16_t *>(result.data.get()))[i], ret[i]);
}
}
TEST_F(UtestFormatTransferHwcnFz, fp16_4c_8n_with_groups_02) {
uint16_t data[3 * 3 * 4 * 8] = {
11 , 99 , 68 , 2 , 14 , 59 , 24 , 100,
4 , 65 , 11 , 7 , 74 , 28 , 71 , 81,
94 , 63 , 80 , 7 , 95 , 29 , 92 , 76,
88 , 68 , 67 , 98 , 82 , 11 , 20 , 68,
36 , 17 , 15 , 89 , 31 , 8 , 51 , 49,
49 , 89 , 79 , 97 , 7 , 91 , 14 , 34,
55 , 40 , 85 , 59 , 31 , 35 , 41 , 89,
4 , 82 , 90 , 48 , 44 , 19 , 9 , 84,
100 , 43 , 7 , 94 , 4 , 91 , 67 , 16,
63 , 79 , 20 , 62 , 55 , 38 , 13 , 61,
98 , 99 , 44 , 0 , 97 , 42 , 65 , 80,
78 , 56 , 26 , 17 , 23 , 22 , 76 , 84,
34 , 88 , 38 , 57 , 37 , 77 , 46 , 28,
48 , 11 , 6 , 18 , 8 , 66 , 24 , 29,
7 , 72 , 34 , 79 , 99 , 14 , 75 , 62,
44 , 98 , 11 , 31 , 4 , 79 , 51 , 37,
84 , 3 , 89 , 74 , 68 , 85 , 17 , 93,
81 , 88 , 38 , 8 , 69 , 82 , 91 , 91,
45 , 42 , 7 , 96 , 81 , 96 , 39 , 35,
93 , 46 , 73 , 7 , 9 , 81 , 5 , 63,
35 , 30 , 27 , 42 , 20 , 52 , 36 , 91,
87 , 1 , 8 , 7 , 78 , 21 , 76 , 97,
52 , 18 , 55 , 57 , 95 , 67 , 3 , 69,
98 , 85 , 75 , 75 , 38 , 3 , 94 , 66,
92 , 27 , 9 , 39 , 5 , 21 , 4 , 48,
55 , 38 , 58 , 84 , 23 , 13 , 71 , 91,
99 , 58 , 58 , 16 , 86 , 45 , 63 , 97,
30 , 10 , 21 , 37 , 78 , 94 , 8 , 49,
18 , 52 , 67 , 65 , 78 , 82 , 74 , 35,
97 , 15 , 43 , 22 , 30 , 87 , 98 , 91,
22 , 88 , 83 , 63 , 79 , 63 , 42 , 74,
29 , 62 , 2 , 97 , 65 , 45 , 76 , 57,
71 , 65 , 0 , 69 , 76 , 41 , 58 , 98,
90 , 3 , 75 , 56 , 41 , 66 , 41 , 96,
44 , 87 , 61 , 26 , 62 , 57 , 49 , 29,
49 , 94 , 90 , 96 , 33 , 32 , 10 , 25};
uint16_t ret[9 * 1 * 16 * 16] ={
11 , 4 , 94 , 88 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
99 , 65 , 63 , 68 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
68 , 11 , 80 , 67 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2 , 7 , 7 , 98 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 14 , 74, 95, 82, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 59 , 28, 29, 11, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 24 , 71, 92, 20, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 ,100 , 81, 76, 68, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
36 , 49 , 55 , 4 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
17 , 89 , 40 , 82 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
15 , 79 , 85 , 90 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
89 , 97 , 59 , 48 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 31 , 7, 31, 44, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 8 , 91, 35, 19, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 51 , 14, 41, 9, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 49 , 34, 89, 84, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
100 , 63 , 98 , 78 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
43 , 79 , 99 , 56 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7 , 20 , 44 , 26 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
94 , 62 , 0 , 17 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 4 , 55, 97, 23, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 91 , 38, 42, 22, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 67 , 13, 65, 76, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 16 , 61, 80, 84, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
34 , 48 , 7 , 44 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
88 , 11 , 72 , 98 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
38 , 6 , 34 , 11 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
57 , 18 , 79 , 31 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 37 , 8, 99, 4, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 77 , 66, 14, 79, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 46 , 24, 75, 51, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 28 , 29, 62, 37, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
84 , 81 , 45 , 93 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
3 , 88 , 42 , 46 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
89 , 38 , 7 , 73 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
74 , 8 , 96 , 7 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 68 , 69, 81, 9, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 85 , 82, 96, 81, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 17 , 91, 39, 5, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 93 , 91, 35, 63, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
35 , 87 , 52 , 98 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
30 , 1 , 18 , 85 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
27 , 8 , 55 , 75 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
42 , 7 , 57 , 75 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 20 , 78, 95, 38, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 52 , 21, 67, 3, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 36 , 76, 3, 94, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 91 , 97, 69, 66, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
92 , 55 , 99 , 30 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
27 , 38 , 58 , 10 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
9 , 58 , 58 , 21 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
39 , 84 , 16 , 37 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 5 , 23, 86, 78, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 21 , 13, 45, 94, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 4 , 71, 63, 8, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 48 , 91, 97, 49, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
18 , 97 , 22 , 29 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
52 , 15 , 88 , 62 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
67 , 43 , 83 , 2 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
65 , 22 , 63 , 97 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 78 , 30, 79, 65, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 82 , 87, 63, 45, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 74 , 98, 42, 76, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 35 , 91, 74, 57, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
71 , 90 , 44 , 49 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
65 , 3 , 87 , 94 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 75 , 61 , 90 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
69 , 56 , 26 , 96 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 76 , 41, 62, 33, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 41 , 66, 57, 32, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 58 , 41, 49, 10, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 98 , 96, 29, 25, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0 , 0 , 0 , 0 , 0 , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
FormatTransferFractalZ transfer;
ge::Format old_format = FORMAT_FRACTAL_Z;
int32_t groups = 2;
ge::Format new_format = static_cast<ge::Format>(ge::GetFormatFromSub(old_format, groups));
TransArgs args{
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, new_format, std::vector<int64_t>({3, 3, 4, 8}),
std::vector<int64_t>({9, 1, 16, 16}), DT_FLOAT16};
TransResult result;
EXPECT_EQ(transfer.TransFormat(args, result), SUCCESS);
EXPECT_EQ(result.length, sizeof(ret) / sizeof(ret[0]) * 2);
for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
EXPECT_EQ((reinterpret_cast<uint16_t *>(result.data.get()))[i], ret[i]);
}
}
TEST_F(UtestFormatTransferHwcnFz, build_transfer_fp32) {
float data[5 * 5 * 31 * 17];
TransArgs args{

Loading…
Cancel
Save