diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl index fdb5e39bc0..25a3f176cc 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl @@ -10,15 +10,7 @@ __kernel void Concat2input_NHWC4(__read_only image2d_t input0, __read_only image if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { return; } - if (axis == 0) { - if (X < input_shape0.x * input_shape0.y) { - FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); - WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); - } else { - FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.x * input_shape0.y))); - WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); - } - } else if (axis == 1) { + if (axis == 1) { if (X < input_shape0.y) { FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); @@ -54,21 +46,7 @@ __kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { return; } - if (axis == 0) { - if (X < input_shape0.x * input_shape0.y) { - FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); - WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0); - } else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) { - FLT4 result1 = - READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.x * input_shape0.y))); - WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1); - } else { - FLT4 result2 = READ_IMAGE( - input2, smp_none, - (int2)((Y)*input_shape2.w + Z, (X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y))); - WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2); - } - } else if (axis == 1) { + if (axis == 1) { if (X < input_shape0.y) { FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0); @@ -121,18 +99,7 @@ __kernel void Concat2input_NC4HW4(__read_only image2d_t input0, __read_only imag } int in_postion_x; int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; - if (axis == 0) { - if (X < (input_shape0.x * input_shape0.y)) { - in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; - FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); - WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); - } else { - in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + - Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y); - FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); - WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); - } - } else if (axis == 1) { + if (axis == 1) { if (X < input_shape0.y) { in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); @@ -181,25 +148,7 @@ __kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only imag } int in_postion_x; int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; - if (axis == 0) { - if (X < (input_shape0.x * input_shape0.y)) { - in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; - FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); - WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); - } else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) { - in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + - Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y); - FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); - WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); - } else { - in_postion_x = ((X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) / input_shape2.y) * - input_shape2.w * input_shape2.y + - Z * input_shape2.y + - (X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) % input_shape2.y; - FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); - WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); - } - } else if (axis == 1) { + if (axis == 1) { if (X < input_shape0.y) { in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl index 836f8a2d00..e6f3c9da2e 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl @@ -59,10 +59,10 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_float(__global float4 *src_data, __wr int X = get_global_id(0); int Y = get_global_id(1); int Z = get_global_id(2); - if (X >= size.x || Y >= size.y || Z >= size.z) { + if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { return; } - int offset = (X * shape.z + Y) * shape.w + Z * 4; + int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; __global float *src_addr = (__global float *)src_data; src_addr += offset; FLT4 data = (FLT4)(0.f); @@ -79,17 +79,18 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_float(__global float4 *src_data, __wr data.z = (FLT)src_addr[2]; } } - WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data); + int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; + WRITE_IMAGE(dst_data, (int2)(Y, pos_ix), data); } __kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size, int4 shape) { int X = get_global_id(0); int Y = get_global_id(1); int Z = get_global_id(2); - if (X >= size.x || Y >= size.y || Z >= size.z) { + if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { return; } - int offset = (X * shape.z + Y) * shape.w + Z * 4; + int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; __global half *src_addr = (__global half *)src_data; src_addr += offset; FLT4 data = (FLT4)(0.f); @@ -106,7 +107,8 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __writ data.z = (FLT)src_addr[2]; } } - WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data); + int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; + WRITE_IMAGE(dst_data, (int2)(Y, pos_ix), data); } __kernel void to_format_NHWC4_to_NHWC4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size, int4 shape) { @@ -227,11 +229,12 @@ __kernel void to_format_NC4HW4_to_NHWC_BUF_float(__read_only image2d_t src_data, int X = get_global_id(0); int Y = get_global_id(1); int Z = get_global_id(2); - if (X >= size.x || Y >= size.y || Z >= size.z) { + if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { return; } - float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X))); - int offset = (X * shape.z + Y) * shape.w + Z * 4; + int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; + float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, pos_ix))); + int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; __global float *dst_addr = (__global float *)dst_data; dst_addr += offset; if ((Z + 1) * 4 <= shape.w) { @@ -253,11 +256,12 @@ __kernel void to_format_NC4HW4_to_NHWC_BUF_half(__read_only image2d_t src_data, int X = get_global_id(0); int Y = get_global_id(1); int Z = get_global_id(2); - if (X >= size.x || Y >= size.y || Z >= size.z) { + if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { return; } - half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X))); - int offset = (X * shape.z + Y) * shape.w + Z * 4; + int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; + half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, pos_ix))); + int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; __global half *dst_addr = (__global half *)dst_data; dst_addr += offset; if ((Z + 1) * 4 <= shape.w) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc index e123007f4e..5495d02404 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc @@ -49,6 +49,26 @@ int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) *img_size = vec; return RET_OK; } + +int ConcatOpenCLKernel::RunAxis0() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator_ = ocl_runtime->GetAllocator(); + std::vector img_size; + auto dst_data = out_tensors_[0]->MutableData(); + auto dst_origin = cl::array{0, 0, 0}; + cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(dst_data)); + for (int i = 0; i < in_tensors_.size(); i++) { + auto src_data = in_tensors_[i]->MutableData(); + allocator_->GetImageSize(src_data, &img_size); + auto src_origin = cl::array{0, 0, 0}; + auto region = cl::array{img_size[0], img_size[1], 1}; + cl::Image2D *input_image = reinterpret_cast(allocator_->GetImage(src_data)); + ocl_runtime->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region); + dst_origin[1] += region[1]; + } + return RET_OK; +} + int ConcatOpenCLKernel::Init() { if (in_tensors_[0]->shape().size() != 4) { MS_LOG(ERROR) << " only support dim = 4 "; @@ -98,6 +118,19 @@ int ConcatOpenCLKernel::Init() { int ConcatOpenCLKernel::ReSize() { return RET_OK; } +int ConcatOpenCLKernel::GetSumShape(std::vector *sum_shape, std::vector *in_shape) { + std::vector temp_sum = {0, 0, 0, 0}; + for (int i = 0; i < in_tensors_.size(); ++i) { + auto temp = in_tensors_[i]->shape(); + for (int j = 0; j < temp.size(); ++j) { + in_shape->push_back(temp[j]); + temp_sum.at(j) += temp[j]; + sum_shape->push_back(temp_sum.at(j)); + } + } + return RET_OK; +} + int ConcatGetBiggestDividerWithPriority(int number, int max_divider) { if (number % 8 == 0 && max_divider >= 8) { return number / 8; @@ -133,6 +166,9 @@ void ConcatGetWorkGroup(const std::vector &global, std::vector * int ConcatOpenCLKernel::Run() { MS_LOG(DEBUG) << this->name() << " Running! "; auto param = reinterpret_cast(this->op_parameter_); + if (param->axis_ == 0) { + return RunAxis0(); + } auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); auto input1_shape = in_tensors_[0]->shape(); @@ -151,6 +187,7 @@ int ConcatOpenCLKernel::Run() { std::vector local = {1, 1, 1}; // init local std::vector global = {OH, OW, OC}; ConcatGetWorkGroup(global, &local, max_global[0]); + GetSumShape(&sum_shape, &in_shape); int arg_cn = 0; if (in_tensors_.size() == 2) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h index 8832330d49..5f08f21102 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h @@ -38,10 +38,17 @@ class ConcatOpenCLKernel : public OpenCLKernel { int ReSize() override; int Run() override; + + int RunAxis0(); + int GetImageSize(size_t idx, std::vector *img_size) override; + int GetSumShape(std::vector *sum_shape, std::vector *in_shape); + private: cl::Kernel kernel_; + std::vector sum_shape; + std::vector in_shape; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc index c8cefb6338..74822fd0cd 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -62,19 +62,19 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { constexpr int INPUT_NUM = 2; std::array, INPUT_NUM> input_shapes = {std::vector{1, 19, 19, 96}, std::vector{1, 19, 19, 96}}; - std::vector output_shape = {2, 19, 19, 96}; + std::vector output_shape = {1, 19, 19, 192}; auto data_type = kNumberTypeFloat16; auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); std::vector inputs; for (auto &shape : input_shapes) { - auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC4, tensor_type); + auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type); inputs.push_back(input_temp); if (input_temp == nullptr) { MS_LOG(INFO) << " new input_tensor failed "; return; } } - auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type); + auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); if (output_tensor == nullptr) { MS_LOG(INFO) << " new output_tensor failed "; for (auto tensor : inputs) { @@ -97,7 +97,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { } return; } - param->axis_ = 0; + param->axis_ = 3; auto *concat_kernel = new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast(param), inputs, outputs); if (concat_kernel == nullptr) { @@ -111,6 +111,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { delete param; return; } + concat_kernel->SetFormatType(schema::Format_NC4HW4); concat_kernel->Init(); // to do allocate memory for inputs and outputs for (auto &input_tensor : inputs) { @@ -229,8 +230,9 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) { delete param; return; } + concat_kernel->SetFormatType(schema::Format_NC4HW4); concat_kernel->Init(); - // to do allocate memory for inputs and outputs + // to do allocate memory for inputs for (auto &input_tensor : inputs) { input_tensor->MallocData(allocator); }