From b1aa1a1d177fb8f40f66c7c89efbdf4ff418ae92 Mon Sep 17 00:00:00 2001 From: chenzupeng Date: Wed, 14 Oct 2020 16:02:01 +0800 Subject: [PATCH] conv2d transpose support 4x4 8x8 and fullconnection support c%4!=0 --- ...2d_transpose2x2.cl => conv2d_transpose.cl} | 90 ++++++++++++++++++- .../kernel/opencl/kernel/conv2d_transpose.cc | 31 +++---- .../kernel/opencl/kernel/fullconnection.cc | 28 +++--- 3 files changed, 114 insertions(+), 35 deletions(-) rename mindspore/lite/src/runtime/kernel/opencl/cl/{conv2d_transpose2x2.cl => conv2d_transpose.cl} (56%) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl similarity index 56% rename from mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl rename to mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl index ac76b8742f..722c7d564f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl @@ -3,7 +3,7 @@ __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, int4 src_size, int4 dst_size) { - int h = get_global_id(0); + int h = get_global_id(2); int kh = h % 2; int src_h = h / 2; src_h = src_h * 2; @@ -11,7 +11,7 @@ __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global int kw = w % 2; int src_w = w / 2; src_w = src_w * 2; - int co = get_global_id(2); + int co = get_global_id(0); if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; FLT4 r0 = (FLT4)(0.f); FLT4 r1 = (FLT4)(0.f); @@ -59,7 +59,7 @@ __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, int4 src_size, int4 dst_size) { - int h = get_global_id(0); + int h = get_global_id(2); int kh = h % 2; int src_h = h / 2; src_h = src_h * 2; @@ -67,7 +67,7 @@ __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __globa int kw = w % 2; int src_w = w / 2; src_w = src_w * 2; - int co = get_global_id(2); + int co = get_global_id(0); if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; FLT4 r0 = (FLT4)(0.f); FLT4 r1 = (FLT4)(0.f); @@ -111,3 +111,85 @@ __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __globa WRITE_IMAGE(dst_data, (int2)(2 * src_w + kw + 2, co * dst_size.x + 2 * src_h + kh), r2); WRITE_IMAGE(dst_data, (int2)(2 * src_w + kw + 2, co * dst_size.x + 2 * src_h + kh + 2), r3); } + +__kernel void conv2d_transpose_NHWC4(__read_only image2d_t src_data, __global FLT16 *weight, + __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, + int2 stride, int2 padding, int4 src_size, int4 dst_size) { + int dst_h = get_global_id(2); + int rem_h = dst_h % stride.x; + int ceil_h = dst_h / stride.x; + dst_h = ceil_h * stride.x * 2 + rem_h; + int dst_w = get_global_id(1); + int rem_w = dst_w % stride.y; + int ceil_w = dst_w / stride.y; + dst_w = ceil_w * stride.y * 2 + rem_w; + int dst_c = get_global_id(0); + if (dst_h >= dst_size.x || dst_w >= dst_size.y || dst_c >= dst_size.z) return; + int weight_base = dst_c * src_size.z * kernel_size.x * kernel_size.y; + FLT4 r0 = (FLT4)(0.f); + FLT4 r1 = (FLT4)(0.f); + FLT4 r2 = (FLT4)(0.f); + FLT4 r3 = (FLT4)(0.f); + int kh_start = dst_h + padding.x; + int kw_start = dst_w + padding.y; + int kh_end = kh_start - kernel_size.x; + int kw_end = kw_start - kernel_size.y; + int src_h = kh_start / stride.x; + int kh = src_h * stride.x; + int src_w = kw_start / stride.y; + int kw = src_w * stride.y; + for (; kh > kh_end; src_h -= 1, kh -= stride.x) { + int out0_src_h = src_h; + int out1_src_h = src_h + 1; + int kernel_h = kh_start - kh; + int src_w_copy = src_w; + int kw_copy = kw; + for (; kw_copy > kw_end; src_w_copy -= 1, kw_copy -= stride.y) { + int out0_src_w = src_w_copy; + int out1_src_w = src_w_copy + 1; + int kernel_w = kw_start - kw_copy; + int weight_offset = weight_base + (kernel_h * kernel_size.y + kernel_w) * src_size.z; + for (int ci = 0; ci < src_size.z; ++ci) { + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(out0_src_w * src_size.z + ci, out0_src_h)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(out0_src_w * src_size.z + ci, out1_src_h)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(out1_src_w * src_size.z + ci, out0_src_h)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(out1_src_w * src_size.z + ci, out1_src_h)); + FLT16 weight_cache = weight[weight_offset++]; + r0 += x0.x * weight_cache.s0123; + r0 += x0.y * weight_cache.s4567; + r0 += x0.z * weight_cache.s89ab; + r0 += x0.w * weight_cache.scdef; + + r1 += x1.x * weight_cache.s0123; + r1 += x1.y * weight_cache.s4567; + r1 += x1.z * weight_cache.s89ab; + r1 += x1.w * weight_cache.scdef; + + r2 += x2.x * weight_cache.s0123; + r2 += x2.y * weight_cache.s4567; + r2 += x2.z * weight_cache.s89ab; + r2 += x2.w * weight_cache.scdef; + + r3 += x3.x * weight_cache.s0123; + r3 += x3.y * weight_cache.s4567; + r3 += x3.z * weight_cache.s89ab; + r3 += x3.w * weight_cache.scdef; + } + } + } + FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(dst_c, 0)); + r0 += bias_val; + r1 += bias_val; + r2 += bias_val; + r3 += bias_val; + WRITE_IMAGE(dst_data, (int2)(dst_w * dst_size.z + dst_c, dst_h), r0); + if (dst_h + stride.x < dst_size.x && dst_w < dst_size.y) { + WRITE_IMAGE(dst_data, (int2)(dst_w * dst_size.z + dst_c, dst_h + stride.x), r1); + } + if (dst_h < dst_size.x && dst_w + stride.y < dst_size.y) { + WRITE_IMAGE(dst_data, (int2)((dst_w + stride.y) * dst_size.z + dst_c, dst_h), r2); + } + if (dst_h + stride.x < dst_size.x && dst_w + stride.y < dst_size.y) { + WRITE_IMAGE(dst_data, (int2)((dst_w + stride.y) * dst_size.z + dst_c, dst_h + stride.x), r3); + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index 5dbc6ed016..e954a4ee0f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -20,7 +20,7 @@ #include "nnacl/fp32/common_func.h" #include "src/kernel_registry.h" #ifndef PROGRAM_WITH_IL -#include "src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl.inc" +#include "src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc" #endif using mindspore::kernel::KERNEL_ARCH::kGPU; @@ -31,22 +31,20 @@ namespace mindspore::kernel { int Conv2dTransposeOpenCLKernel::Init() { ConvParameter *param = reinterpret_cast(op_parameter_); - if (param->kernel_h_ != 2 || param->kernel_w_ != 2 || param->stride_h_ != 2 || param->stride_w_ != 2) { - MS_LOG(ERROR) << "only support kh=kw=2 and stride_h=stride_w=2."; + if (param->pad_l_ != param->pad_r_ || param->kernel_h_ - param->stride_h_ != 2 * param->pad_l_ || + param->pad_u_ != param->pad_d_ || param->kernel_w_ - param->stride_w_ != 2 * param->pad_u_) { + MS_LOG(ERROR) << "only support kernel - stride == 2 * pad"; return RET_ERROR; } - if (param->pad_u_ != 0 || param->pad_l_ != 0) { - MS_LOG(ERROR) << "only support pad =0."; - return RET_ERROR; - } - std::string kernel_name = "conv2d_transpose2x2_" + std::string(EnumNameFormat(op_format_)); + std::string kernel_name = "conv2d_transpose"; + kernel_name += "_" + std::string(EnumNameFormat(op_format_)); enable_fp16_ = ocl_runtime_->GetFp16Enable(); #ifdef PROGRAM_WITH_IL kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name); #else - std::string source = conv2d_transpose2x2_source; + std::string source = conv2d_transpose_source; std::set build_options; - std::string program_name = "conv2d_transpose2x2"; + std::string program_name = "conv2d_transpose"; ocl_runtime_->LoadSource(program_name, source); ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); #endif @@ -181,19 +179,22 @@ int Conv2dTransposeOpenCLKernel::Run() { int co4 = UP_DIV(co, C4NUM); int kh = param->kernel_h_; int kw = param->kernel_w_; - int pad = param->pad_u_; + int pad_h = param->pad_l_; + int pad_w = param->pad_u_; + int stride_h = param->stride_h_; + int stride_w = param->stride_w_; int oh = out_tensors_[0]->shape()[1]; int ow = out_tensors_[0]->shape()[2]; int h = in_tensors_[0]->shape()[1]; int w = in_tensors_[0]->shape()[2]; // local size should less than MAX_GROUP_SIZE std::vector local = {16, 1, 16}; - std::vector global = {UP_ROUND((size_t)UP_ROUND(oh / 2, 2), local[0]), - UP_ROUND((size_t)UP_ROUND(ow / 2, 2), local[1]), UP_ROUND(co4, local[2])}; + std::vector global = {UP_ROUND(co4, local[0]), UP_ROUND((size_t)UP_ROUND(ow / 2, stride_w), local[1]), + UP_ROUND((size_t)UP_ROUND(oh / 2, stride_h), local[2])}; cl_int2 kernel_size = {kh, kw}; - cl_int2 stride = {2, 2}; - cl_int2 padding = {pad, pad}; + cl_int2 stride = {stride_h, stride_w}; + cl_int2 padding = {pad_h, pad_w}; cl_int4 src_size = {h, w, UP_DIV(ci, C4NUM), 1}; cl_int4 dst_size = {oh, ow, UP_DIV(co, C4NUM), 1}; int arg_cnt = 0; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc index 9b1013fe6b..191cdfd7d2 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc @@ -47,10 +47,6 @@ int FullConnectionOpenCLKernel::Init() { return RET_ERROR; } if (in_tensors_[0]->shape().size() == 4) { - if (in_tensors_[0]->shape()[3] % C4NUM != 0) { - MS_LOG(ERROR) << "fullconnection only support input shape channel % 4 = 0 if input shape size = 4"; - return RET_ERROR; - } inShape = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], in_tensors_[0]->shape()[3]}; } else { @@ -92,30 +88,30 @@ int FullConnectionOpenCLKernel::ReSize() { return RET_OK; } void FullConnectionOpenCLKernel::PadWeight() { // ABMCI @ ABCICO = ABMCO auto allocator = ocl_runtime_->GetAllocator(); - int ci = inShape[1] * inShape[2] * inShape[3]; + int ci = inShape[3]; int ci4 = UP_DIV(ci, C4NUM); int co = outShape[1]; int co4 = UP_DIV(co, C4NUM); - int a = 1; - int b = 1; + int h = inShape[1]; + int w = inShape[2]; size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float); - padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size); + padWeight_ = allocator->Malloc(h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size); padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true); auto padWeightFp32 = reinterpret_cast(padWeight_); auto padWeightFp16 = reinterpret_cast(padWeight_); - memset(padWeight_, 0x00, a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size); + memset(padWeight_, 0x00, h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size); auto originWeightFp32 = reinterpret_cast(in_tensors_.at(kWeightIndex)->data_c()); auto originWeightFp16 = reinterpret_cast(in_tensors_.at(kWeightIndex)->data_c()); bool isModelFp16 = in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16; // pad weight - // ABCICO -> AB(CI4)(CO4)(4 from CO)(4 from CI) - // if tranposeB, ABCOCI -> AB(CI4)(CO4)(4 from CO)(4 from CI) + // HWCICO -> (HWCI4)(CO4)(4 from CO)(4 from CI) + // if tranposeB, COHWCI -> (HWCI4)(CO4)(4 from CO)(4 from CI) int index = 0; - for (int aa = 0; aa < a; aa++) { - for (int bb = 0; bb < b; bb++) { - int baseAB = (aa * b + bb) * ci * co; + for (int hh = 0; hh < h; hh++) { + for (int ww = 0; ww < w; ww++) { + int baseHW = hh * w + ww; for (int i = 0; i < ci4; ++i) { for (int j = 0; j < co4; ++j) { for (int k = 0; k < C4NUM; ++k) { @@ -123,9 +119,9 @@ void FullConnectionOpenCLKernel::PadWeight() { int src_ci = i * C4NUM + l; int src_co = j * C4NUM + k; if (src_ci < ci && src_co < co) { - int originId = baseAB + src_ci * co + src_co; + int originId = baseHW * ci * co + src_ci * co + src_co; if (transposeB) { - originId = baseAB + src_co * ci + src_ci; + originId = src_co * ci * h * w + baseHW * ci + src_ci; } if (enable_fp16_) { if (!isModelFp16) {