From ee8995b6ed89071dd8db95574eb1669faaed5077 Mon Sep 17 00:00:00 2001 From: chenzupeng Date: Thu, 29 Oct 2020 19:11:37 +0800 Subject: [PATCH] reduce support WC --- .../src/runtime/kernel/opencl/cl/reduce.cl | 91 +++++++++++++++++++ .../runtime/kernel/opencl/kernel/reduce.cc | 33 +++++-- .../src/runtime/kernel/opencl/kernel/reduce.h | 1 + .../kernel/opencl/subgraph_opencl_kernel.cc | 1 + .../src/runtime/kernel/opencl/reduce_tests.cc | 36 +++++++- 5 files changed, 150 insertions(+), 12 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/reduce.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/reduce.cl index 2bfb66fe7b..a69feab9d6 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/reduce.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/reduce.cl @@ -84,3 +84,94 @@ __kernel void sum_local_NHWC4(__read_only image2d_t src_data, __write_only image } WRITE_IMAGE(dst_data, (int2)(X, 0), TO_FLT4(result)); } + +__kernel void mean_WC_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { + int X = get_global_id(0); // H + if (X >= size.x) { + return; + } + float4 result = (float4)0.f; + for (int w = 0; w < size.y; w++) { + for (int c = 0; c < size.z; c++) { + result += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X))); + } + } + + result /= size.y * size.w; + FLT4 result2 = (FLT4)(0.f); + result2.x = dot(TO_FLT4(result), (FLT4)(1.f)); + WRITE_IMAGE(dst_data, (int2)(0, X), result2); +} + +__kernel void mean_WC_local_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { + int X = get_global_id(0); // H + int localy = get_local_id(1); + int localz = get_local_id(2); + if (X >= size.x) return; + __local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD]; + temp[localy][localz] = (float4)0.f; + for (int w = localy; w < size.y; w += LOCAL_CACHE_THREAD) { + for (int c = localz; c < size.z; c += LOCAL_CACHE_THREAD) { + temp[localy][localz] += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X))); + } + } + barrier(CLK_LOCAL_MEM_FENCE); + if (localz == 0) { + for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { + temp[localy][0] += temp[localy][i]; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + float4 result = temp[0][0]; + for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { + result += temp[i][0]; + } + result /= size.y * size.w; + FLT4 result2 = (FLT4)(0.f); + result2.x = dot(TO_FLT4(result), (FLT4)(1.f)); + WRITE_IMAGE(dst_data, (int2)(0, X), result2); +} + +__kernel void sum_WC_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { + int X = get_global_id(0); // H + if (X >= size.x) { + return; + } + FLT4 result = (FLT4)0.f; + for (int w = 0; w < size.y; w++) { + for (int c = 0; c < size.z; c++) { + result += READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X)); + } + } + FLT4 result2 = (FLT4)(0.f); + result2.x = dot(TO_FLT4(result), (FLT4)(1.f)); + WRITE_IMAGE(dst_data, (int2)(0, X), result2); +} + +__kernel void sum_WC_local_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { + int X = get_global_id(0); // H + int localy = get_local_id(1); + int localz = get_local_id(2); + if (X >= size.x) return; + __local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD]; + temp[localy][localz] = (float4)0.f; + for (int w = localy; w < size.y; w += LOCAL_CACHE_THREAD) { + for (int c = localz; c < size.z; c += LOCAL_CACHE_THREAD) { + temp[localy][localz] += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X))); + } + } + barrier(CLK_LOCAL_MEM_FENCE); + if (localz == 0) { + for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { + temp[localy][0] += temp[localy][i]; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + float4 result = temp[0][0]; + for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { + result += temp[i][0]; + } + FLT4 result2 = (FLT4)(0.f); + result2.x = dot(TO_FLT4(result), (FLT4)(1.f)); + WRITE_IMAGE(dst_data, (int2)(0, X), result2); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc index 214f5b5691..28e14005b7 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc @@ -51,24 +51,34 @@ int ReduceOpenCLKernel::Init() { MS_LOG(ERROR) << "not supported reduce type:" << reduce_param->mode_; return RET_PARAM_INVALID; } - if (reduce_param->num_axes_ != 2 || ((reduce_param->axes_[0] != 1 || reduce_param->axes_[1] != 2) && - (reduce_param->axes_[0] != 2 || reduce_param->axes_[1] != 1))) { - MS_LOG(ERROR) << "reduce op only support axes HW"; + if (reduce_param->num_axes_ != 2) { + MS_LOG(ERROR) << "reduce op only support axes=2"; + return RET_PARAM_INVALID; + } + bool hw_reduce = (reduce_param->axes_[0] == 1 && reduce_param->axes_[1] == 2) || + (reduce_param->axes_[0] == 2 && reduce_param->axes_[1] == 1); + wc_reduce_ = (reduce_param->axes_[0] == 2 && reduce_param->axes_[1] == 3) || + (reduce_param->axes_[0] == 3 && reduce_param->axes_[1] == 2); + if (!hw_reduce && !wc_reduce_) { + MS_LOG(ERROR) << "reduce op only support axis (1,2) or (2,3)"; + return RET_PARAM_INVALID; + } + if (wc_reduce_ && reduce_param->keep_dims_ == false) { + MS_LOG(ERROR) << "reduce axis (2,3) should keep dims"; return RET_PARAM_INVALID; } std::string kernel_name = reduce_type2str.at(reduce_param->mode_); - if (in_tensors_[0]->shape()[1] >= LOCAL_CACHE_THREAD || in_tensors_[0]->shape()[2] >= LOCAL_CACHE_THREAD) { + if (wc_reduce_) { + kernel_name += "_WC"; + } + if (in_tensors_[0]->shape()[reduce_param->axes_[0]] >= LOCAL_CACHE_THREAD || + in_tensors_[0]->shape()[reduce_param->axes_[1]] >= LOCAL_CACHE_THREAD) { use_local_ = true; kernel_name += "_local"; } kernel_name += "_NHWC4"; enable_fp16_ = ocl_runtime_->GetFp16Enable(); - if (in_tensors_[0]->shape().back() != out_tensors_[0]->shape().back()) { - MS_LOG(ERROR) << "Reduce input channel " << in_tensors_[0]->shape().back() << " should equal output channel" - << out_tensors_[0]->shape().back(); - return mindspore::lite::RET_ERROR; - } #ifdef PROGRAM_WITH_IL kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name); #else @@ -109,7 +119,10 @@ int ReduceOpenCLKernel::Run() { local = {1, LOCAL_CACHE_THREAD, LOCAL_CACHE_THREAD}; } std::vector global = {static_cast(c4), 1, 1}; - cl_int4 size = {h, w, c4, 1}; + if (wc_reduce_) { + global = {static_cast(h), 1, 1}; + } + cl_int4 size = {h, w, c4, c}; int arg_idx = 0; ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c()); ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h index 83ffdc453d..56b73b1463 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h @@ -40,6 +40,7 @@ class ReduceOpenCLKernel : public OpenCLKernel { bool enable_fp16_{false}; std::vector nhwc_shape_; bool use_local_{false}; + bool wc_reduce_{false}; static const size_t LOCAL_CACHE_THREAD{16}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc index c02eff91a1..8ce5ca1c13 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc @@ -79,6 +79,7 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector &in_te new_tensor = nullptr; return RET_ERROR; } + parameter->op_parameter.type_ = mindspore::schema::PrimitiveType_ToFormat; parameter->src_format = src_format; parameter->dst_format = dst_format; parameter->out_mem_type = mem_type; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc index 33eb04c555..5a3c0e216e 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc @@ -30,7 +30,7 @@ class TestReduceOpenCL : public mindspore::CommonTest { }; void RunTestCaseReduce(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16, - int reduce_mode) { + int reduce_mode, bool WC = false) { auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); @@ -43,6 +43,11 @@ void RunTestCaseReduce(const std::vector &shape, void *input_data, void *ou } param->axes_[0] = 1; param->axes_[1] = 2; + if (WC) { + param->axes_[0] = 2; + param->axes_[1] = 3; + param->keep_dims_ = true; + } param->num_axes_ = 2; param->mode_ = reduce_mode; int n = shape[0]; @@ -58,8 +63,11 @@ void RunTestCaseReduce(const std::vector &shape, void *input_data, void *ou return; } std::vector out_shape = {n, c}; + if (WC) { + out_shape = {n, h, 1, 1}; + } auto tensor_out_ptr = std::make_unique(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), - out_shape, schema::Format_NC); + out_shape, WC ? schema::Format_NHWC : schema::Format_NC); auto tensor_out = tensor_out_ptr.get(); if (tensor_out == nullptr) { MS_LOG(ERROR) << "tensor_out create error."; @@ -152,4 +160,28 @@ TEST_F(TestReduceOpenCL, ReduceSumFp16) { RunTestCaseReduce(shape, input_data.data(), output_data.data(), true, schema::ReduceMode_ReduceSum); } + +TEST_F(TestReduceOpenCL, ReduceMeanWCFp32) { + int n = 1; + int h = 3; + int w = 2; + int c = 2; + std::vector shape = {n, h, w, c}; + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; + std::vector output_data = {1.5f, 5.5f, 9.5f}; + + RunTestCaseReduce(shape, input_data.data(), output_data.data(), false, schema::ReduceMode_ReduceMean, true); +} + +TEST_F(TestReduceOpenCL, ReduceSumWCFp32) { + int n = 1; + int h = 3; + int w = 2; + int c = 2; + std::vector shape = {n, h, w, c}; + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; + std::vector output_data = {6.0f, 22.0f, 38.0f}; + + RunTestCaseReduce(shape, input_data.data(), output_data.data(), false, schema::ReduceMode_ReduceSum, true); +} } // namespace mindspore