!7202 fix bug for opencl batch_to_space op

Merge pull request !7202 from wandongdong/master
pull/7202/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 16e4442a52

@ -5,13 +5,16 @@ __kernel void batch_to_space_nd_NHWC4(__read_only image2d_t src_data, __write_on
int X = get_global_id(0); // c
int Y = get_global_id(1); // w
int Z = get_global_id(2); // h*n
if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {
if (X >= src_size.x || Y >= src_size.y || Y >= src_size.z) {
return;
}
for (int i = 0; i < block_size.x; ++i) {
for (int j = 0; j < block_size.y; ++j) {
int Y_dst = (Y * block_size.y + j);
int Z_dst = Z * block_size.x + i;
if (Y_dst >= dst_size.y || Z_dst >= dst_size.z) {
continue;
}
int Y_org = (Y_dst + paddings.z) / block_size.y;
int Z_org = (Z_dst + paddings.x) / block_size.x;
int Z_com = (i * block_size.y + j) * src_size.z + Z_org;

@ -101,11 +101,9 @@ int BatchToSpaceNDOpenCLKernel::Run() {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t CI4 = UP_DIV(in_tensors_[0]->Channel(), C4NUM);
cl_int4 src_size = {
(cl_int)CI4, in_tensors_[0]->Width(),
in_tensors_[0]->Height() * in_tensors_[0]->Batch() / param->block_shape_[0] / param->block_shape_[1], 1};
cl_int4 dst_size = {(cl_int)CO4, out_tensors_[0]->Width() / param->block_shape_[1],
out_tensors_[0]->Height() / param->block_shape_[0] * out_tensors_[0]->Batch(), 1};
cl_int4 src_size = {(cl_int)CI4, in_tensors_[0]->Width(), in_tensors_[0]->Height() * out_tensors_[0]->Batch(), 1};
std::vector<int> out_shape = out_tensors_[0]->shape();
cl_int4 dst_size = {(cl_int)CO4, out_shape[2], out_shape[1], out_shape[0]};
cl_int2 block_size = {param->block_shape_[0], param->block_shape_[1]};
cl_int4 paddings = {param->crops_[0], param->crops_[1], param->crops_[2], param->crops_[3]};
std::vector<size_t> local = {1, 1, 1};

@ -95,7 +95,7 @@ void test_main_batch_to_space_nd(void *input_data, void *correct_data, const std
CommonTest::CompareOutputData<T>(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001);
delete sub_graph;
}
TEST_F(TestBatchToSpaceNDOpenCL, NHWC4H2W2Pad2222) {
TEST_F(TestBatchToSpaceNDOpenCL, NHWC4H2W2Pad2020) {
std::vector<int> input_shape{4, 5, 5, 4};
BatchToSpaceParameter *param = std::make_unique<BatchToSpaceParameter>().release();
if (param == nullptr) {
@ -104,9 +104,9 @@ TEST_F(TestBatchToSpaceNDOpenCL, NHWC4H2W2Pad2222) {
param->block_shape_[0] = 2;
param->block_shape_[1] = 2;
param->crops_[0] = 2;
param->crops_[1] = 2;
param->crops_[1] = 0;
param->crops_[2] = 2;
param->crops_[3] = 2;
param->crops_[3] = 0;
float input_data[] = {
172, 47, 117, 192, 67, 251, 195, 103, 9, 211, 21, 242, 36, 87, 70, 216, 88, 140, 58, 193, 230, 39, 87,
174, 88, 81, 165, 25, 77, 72, 9, 148, 115, 208, 243, 197, 254, 79, 175, 192, 82, 99, 216, 177, 243, 29,
@ -139,6 +139,51 @@ TEST_F(TestBatchToSpaceNDOpenCL, NHWC4H2W2Pad2222) {
schema::Format format = schema::Format_NHWC;
test_main_batch_to_space_nd<float>(input_data, correct_data, input_shape, param, data_type, format);
}
TEST_F(TestBatchToSpaceNDOpenCL, NHWC4H3W3Pad0101) {
std::vector<int> input_shape{9, 3, 3, 4};
BatchToSpaceParameter *param = std::make_unique<BatchToSpaceParameter>().release();
if (param == nullptr) {
return;
}
param->block_shape_[0] = 3;
param->block_shape_[1] = 3;
param->crops_[0] = 0;
param->crops_[1] = 1;
param->crops_[2] = 0;
param->crops_[3] = 1;
float input_data[] = {
172, 47, 117, 192, 67, 251, 195, 103, 9, 211, 21, 242, 36, 87, 70, 216, 88, 140, 58, 193, 230, 39,
87, 174, 88, 81, 165, 25, 77, 72, 9, 148, 115, 208, 243, 197, 254, 79, 175, 192, 82, 99, 216, 177,
243, 29, 147, 147, 142, 167, 32, 193, 9, 185, 127, 32, 31, 202, 244, 151, 163, 254, 203, 114, 183, 28,
34, 128, 128, 164, 53, 133, 38, 232, 244, 17, 79, 132, 105, 42, 186, 31, 120, 1, 65, 231, 169, 57,
35, 102, 119, 11, 174, 82, 91, 128, 142, 99, 53, 140, 121, 170, 84, 203, 68, 6, 196, 47, 127, 244,
131, 204, 100, 180, 232, 78, 143, 148, 227, 186, 23, 207, 141, 117, 85, 48, 49, 69, 169, 163, 192, 95,
197, 94, 0, 113, 178, 36, 162, 48, 93, 131, 98, 42, 205, 112, 231, 149, 201, 127, 0, 138, 114, 43,
186, 127, 23, 187, 130, 121, 98, 62, 163, 222, 123, 195, 82, 174, 227, 148, 209, 50, 155, 14, 41, 58,
193, 36, 10, 86, 43, 104, 11, 2, 51, 80, 32, 182, 128, 38, 19, 174, 42, 115, 184, 188, 232, 77,
30, 24, 125, 2, 3, 94, 226, 107, 13, 112, 40, 72, 19, 95, 72, 154, 194, 248, 180, 67, 236, 61,
14, 96, 4, 195, 237, 139, 252, 86, 205, 121, 109, 75, 184, 16, 152, 157, 149, 110, 25, 208, 188, 121,
118, 117, 189, 83, 161, 104, 160, 228, 251, 251, 121, 70, 213, 31, 13, 71, 184, 152, 79, 41, 18, 40,
182, 207, 11, 166, 111, 93, 249, 129, 223, 118, 44, 216, 125, 24, 67, 210, 239, 3, 234, 204, 230, 35,
214, 254, 189, 197, 215, 43, 32, 11, 104, 212, 138, 182, 235, 165, 125, 156, 111, 232, 2, 27, 211, 217,
151, 53, 51, 174, 148, 181, 29, 67, 35, 39, 137, 73, 41, 151, 131, 46};
float correct_data[] = {
172, 47, 117, 192, 254, 79, 175, 192, 38, 232, 244, 17, 67, 251, 195, 103, 82, 99, 216, 177, 79, 132,
105, 42, 9, 211, 21, 242, 243, 29, 147, 147, 127, 244, 131, 204, 205, 112, 231, 149, 43, 104, 11, 2,
100, 180, 232, 78, 201, 127, 0, 138, 51, 80, 32, 182, 143, 148, 227, 186, 114, 43, 186, 127, 180, 67,
236, 61, 121, 70, 213, 31, 189, 197, 215, 43, 14, 96, 4, 195, 13, 71, 184, 152, 32, 11, 104, 212,
237, 139, 252, 86, 79, 41, 18, 40, 36, 87, 70, 216, 142, 167, 32, 193, 65, 231, 169, 57, 88, 140,
58, 193, 9, 185, 127, 32, 35, 102, 119, 11, 230, 39, 87, 174, 31, 202, 244, 151, 23, 207, 141, 117,
23, 187, 130, 121, 42, 115, 184, 188, 85, 48, 49, 69, 98, 62, 163, 222, 232, 77, 30, 24, 169, 163,
192, 95, 123, 195, 82, 174, 205, 121, 109, 75, 182, 207, 11, 166, 125, 156, 111, 232, 184, 16, 152, 157,
111, 93, 249, 129, 2, 27, 211, 217, 149, 110, 25, 208, 223, 118, 44, 216, 88, 81, 165, 25, 163, 254,
203, 114, 142, 99, 53, 140, 77, 72, 9, 148, 183, 28, 34, 128, 121, 170, 84, 203, 115, 208, 243, 197,
128, 164, 53, 133, 197, 94, 0, 113, 227, 148, 209, 50, 226, 107, 13, 112, 178, 36, 162, 48, 155, 14,
41, 58, 40, 72, 19, 95, 93, 131, 98, 42, 193, 36, 10, 86};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_batch_to_space_nd<float>(input_data, correct_data, input_shape, param, data_type, format);
}
TEST_F(TestBatchToSpaceNDOpenCL, NC4HW4H2W2Pad2222) {
std::vector<int> input_shape{4, 5, 5, 4};
BatchToSpaceParameter *param = std::make_unique<BatchToSpaceParameter>().release();

Loading…
Cancel
Save