conv2d transpose: support 4x4 and 8x8 kernels; fullconnection: support channel % 4 != 0

pull/7352/head
chenzupeng 4 years ago
parent aa605e23d5
commit b1aa1a1d17

@@ -3,7 +3,7 @@ __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP
__kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global FLT16 *weight,
__read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size,
int2 stride, int2 padding, int4 src_size, int4 dst_size) {
int h = get_global_id(0);
int h = get_global_id(2);
int kh = h % 2;
int src_h = h / 2;
src_h = src_h * 2;
@@ -11,7 +11,7 @@ __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global
int kw = w % 2;
int src_w = w / 2;
src_w = src_w * 2;
int co = get_global_id(2);
int co = get_global_id(0);
if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return;
FLT4 r0 = (FLT4)(0.f);
FLT4 r1 = (FLT4)(0.f);
@@ -59,7 +59,7 @@ __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global
__kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __global FLT16 *weight,
__read_only image2d_t biases, __write_only image2d_t dst_data,
int2 kernel_size, int2 stride, int2 padding, int4 src_size, int4 dst_size) {
int h = get_global_id(0);
int h = get_global_id(2);
int kh = h % 2;
int src_h = h / 2;
src_h = src_h * 2;
@@ -67,7 +67,7 @@ __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __globa
int kw = w % 2;
int src_w = w / 2;
src_w = src_w * 2;
int co = get_global_id(2);
int co = get_global_id(0);
if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return;
FLT4 r0 = (FLT4)(0.f);
FLT4 r1 = (FLT4)(0.f);
@@ -111,3 +111,85 @@ __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __globa
WRITE_IMAGE(dst_data, (int2)(2 * src_w + kw + 2, co * dst_size.x + 2 * src_h + kh), r2);
WRITE_IMAGE(dst_data, (int2)(2 * src_w + kw + 2, co * dst_size.x + 2 * src_h + kh + 2), r3);
}
__kernel void conv2d_transpose_NHWC4(__read_only image2d_t src_data, __global FLT16 *weight,
__read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size,
int2 stride, int2 padding, int4 src_size, int4 dst_size) {
int dst_h = get_global_id(2);
int rem_h = dst_h % stride.x;
int ceil_h = dst_h / stride.x;
dst_h = ceil_h * stride.x * 2 + rem_h;
int dst_w = get_global_id(1);
int rem_w = dst_w % stride.y;
int ceil_w = dst_w / stride.y;
dst_w = ceil_w * stride.y * 2 + rem_w;
int dst_c = get_global_id(0);
if (dst_h >= dst_size.x || dst_w >= dst_size.y || dst_c >= dst_size.z) return;
int weight_base = dst_c * src_size.z * kernel_size.x * kernel_size.y;
FLT4 r0 = (FLT4)(0.f);
FLT4 r1 = (FLT4)(0.f);
FLT4 r2 = (FLT4)(0.f);
FLT4 r3 = (FLT4)(0.f);
int kh_start = dst_h + padding.x;
int kw_start = dst_w + padding.y;
int kh_end = kh_start - kernel_size.x;
int kw_end = kw_start - kernel_size.y;
int src_h = kh_start / stride.x;
int kh = src_h * stride.x;
int src_w = kw_start / stride.y;
int kw = src_w * stride.y;
for (; kh > kh_end; src_h -= 1, kh -= stride.x) {
int out0_src_h = src_h;
int out1_src_h = src_h + 1;
int kernel_h = kh_start - kh;
int src_w_copy = src_w;
int kw_copy = kw;
for (; kw_copy > kw_end; src_w_copy -= 1, kw_copy -= stride.y) {
int out0_src_w = src_w_copy;
int out1_src_w = src_w_copy + 1;
int kernel_w = kw_start - kw_copy;
int weight_offset = weight_base + (kernel_h * kernel_size.y + kernel_w) * src_size.z;
for (int ci = 0; ci < src_size.z; ++ci) {
FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(out0_src_w * src_size.z + ci, out0_src_h));
FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(out0_src_w * src_size.z + ci, out1_src_h));
FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(out1_src_w * src_size.z + ci, out0_src_h));
FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(out1_src_w * src_size.z + ci, out1_src_h));
FLT16 weight_cache = weight[weight_offset++];
r0 += x0.x * weight_cache.s0123;
r0 += x0.y * weight_cache.s4567;
r0 += x0.z * weight_cache.s89ab;
r0 += x0.w * weight_cache.scdef;
r1 += x1.x * weight_cache.s0123;
r1 += x1.y * weight_cache.s4567;
r1 += x1.z * weight_cache.s89ab;
r1 += x1.w * weight_cache.scdef;
r2 += x2.x * weight_cache.s0123;
r2 += x2.y * weight_cache.s4567;
r2 += x2.z * weight_cache.s89ab;
r2 += x2.w * weight_cache.scdef;
r3 += x3.x * weight_cache.s0123;
r3 += x3.y * weight_cache.s4567;
r3 += x3.z * weight_cache.s89ab;
r3 += x3.w * weight_cache.scdef;
}
}
}
FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(dst_c, 0));
r0 += bias_val;
r1 += bias_val;
r2 += bias_val;
r3 += bias_val;
WRITE_IMAGE(dst_data, (int2)(dst_w * dst_size.z + dst_c, dst_h), r0);
if (dst_h + stride.x < dst_size.x && dst_w < dst_size.y) {
WRITE_IMAGE(dst_data, (int2)(dst_w * dst_size.z + dst_c, dst_h + stride.x), r1);
}
if (dst_h < dst_size.x && dst_w + stride.y < dst_size.y) {
WRITE_IMAGE(dst_data, (int2)((dst_w + stride.y) * dst_size.z + dst_c, dst_h), r2);
}
if (dst_h + stride.x < dst_size.x && dst_w + stride.y < dst_size.y) {
WRITE_IMAGE(dst_data, (int2)((dst_w + stride.y) * dst_size.z + dst_c, dst_h + stride.x), r3);
}
}
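As a reading aid, the gather loops above invert the standard transposed-convolution scatter relation oh = ih * stride + kh - pad (the kernel's kernel_h = kh_start - kh is exactly dst_h + pad - src_h * stride). Below is a minimal scalar reference of that relation, assuming plain NHWC activations and HWCICO weights; the function name and layouts are illustrative and not part of this commit.

// Reference sketch only: scatter-form ConvTranspose2d, NHWC activations, HWCICO weights.
// out is assumed zero-initialized before the call.
void ConvTranspose2dRef(const float *in, const float *w, const float *bias, float *out,
                        int H, int W, int CI, int OH, int OW, int CO,
                        int KH, int KW, int stride_h, int stride_w, int pad_h, int pad_w) {
  for (int ih = 0; ih < H; ++ih) {
    for (int iw = 0; iw < W; ++iw) {
      for (int kh = 0; kh < KH; ++kh) {
        for (int kw = 0; kw < KW; ++kw) {
          // every input pixel contributes to a KHxKW window of outputs
          int oh = ih * stride_h + kh - pad_h;
          int ow = iw * stride_w + kw - pad_w;
          if (oh < 0 || oh >= OH || ow < 0 || ow >= OW) continue;
          for (int ci = 0; ci < CI; ++ci) {
            for (int co = 0; co < CO; ++co) {
              out[(oh * OW + ow) * CO + co] +=
                  in[(ih * W + iw) * CI + ci] * w[((kh * KW + kw) * CI + ci) * CO + co];
            }
          }
        }
      }
    }
  }
  // bias is added once per output element, matching the READ_IMAGE(biases, ...) step above
  for (int i = 0; i < OH * OW; ++i) {
    for (int co = 0; co < CO; ++co) out[i * CO + co] += bias[co];
  }
}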

@@ -20,7 +20,7 @@
#include "nnacl/fp32/common_func.h"
#include "src/kernel_registry.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl.inc"
#include "src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc"
#endif
using mindspore::kernel::KERNEL_ARCH::kGPU;
@@ -31,22 +31,20 @@ namespace mindspore::kernel {
int Conv2dTransposeOpenCLKernel::Init() {
ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_);
if (param->kernel_h_ != 2 || param->kernel_w_ != 2 || param->stride_h_ != 2 || param->stride_w_ != 2) {
MS_LOG(ERROR) << "only support kh=kw=2 and stride_h=stride_w=2.";
if (param->pad_l_ != param->pad_r_ || param->kernel_h_ - param->stride_h_ != 2 * param->pad_l_ ||
param->pad_u_ != param->pad_d_ || param->kernel_w_ - param->stride_w_ != 2 * param->pad_u_) {
MS_LOG(ERROR) << "only support kernel - stride == 2 * pad";
return RET_ERROR;
}
if (param->pad_u_ != 0 || param->pad_l_ != 0) {
MS_LOG(ERROR) << "only support pad =0.";
return RET_ERROR;
}
std::string kernel_name = "conv2d_transpose2x2_" + std::string(EnumNameFormat(op_format_));
std::string kernel_name = "conv2d_transpose";
kernel_name += "_" + std::string(EnumNameFormat(op_format_));
enable_fp16_ = ocl_runtime_->GetFp16Enable();
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
std::string source = conv2d_transpose2x2_source;
std::string source = conv2d_transpose_source;
std::set<std::string> build_options;
std::string program_name = "conv2d_transpose2x2";
std::string program_name = "conv2d_transpose";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
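The replaced check means Init() now accepts any deconvolution whose geometry satisfies kernel - stride == 2 * pad, which covers the old 2x2/stride-2/pad-0 case as well as the 4x4 and 8x8 kernels from the commit title (for example 4x4 with stride 2 and pad 1, or 8x8 with stride 4 and pad 2). A tiny standalone sketch of that predicate, with an illustrative helper name:

// Mirrors the condition in Init(): 2x2 s2 p0 (2-2 == 0), 4x4 s2 p1 (4-2 == 2),
// 8x8 s4 p2 (8-4 == 4) all pass; the helper itself is not part of this commit.
bool DeconvGeometrySupported(int kernel, int stride, int pad) {
  return kernel - stride == 2 * pad;
}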
@@ -181,19 +179,22 @@ int Conv2dTransposeOpenCLKernel::Run() {
int co4 = UP_DIV(co, C4NUM);
int kh = param->kernel_h_;
int kw = param->kernel_w_;
int pad = param->pad_u_;
int pad_h = param->pad_l_;
int pad_w = param->pad_u_;
int stride_h = param->stride_h_;
int stride_w = param->stride_w_;
int oh = out_tensors_[0]->shape()[1];
int ow = out_tensors_[0]->shape()[2];
int h = in_tensors_[0]->shape()[1];
int w = in_tensors_[0]->shape()[2];
// local size should be less than MAX_GROUP_SIZE
std::vector<size_t> local = {16, 1, 16};
std::vector<size_t> global = {UP_ROUND((size_t)UP_ROUND(oh / 2, 2), local[0]),
UP_ROUND((size_t)UP_ROUND(ow / 2, 2), local[1]), UP_ROUND(co4, local[2])};
std::vector<size_t> global = {UP_ROUND(co4, local[0]), UP_ROUND((size_t)UP_ROUND(ow / 2, stride_w), local[1]),
UP_ROUND((size_t)UP_ROUND(oh / 2, stride_h), local[2])};
cl_int2 kernel_size = {kh, kw};
cl_int2 stride = {2, 2};
cl_int2 padding = {pad, pad};
cl_int2 stride = {stride_h, stride_w};
cl_int2 padding = {pad_h, pad_w};
cl_int4 src_size = {h, w, UP_DIV(ci, C4NUM), 1};
cl_int4 dst_size = {oh, ow, UP_DIV(co, C4NUM), 1};
int arg_cnt = 0;
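The new global size maps co4 to dimension 0 and halves oh/ow because each work-item writes a 2x2 group of stride-separated outputs. Assuming the usual nnacl helpers UP_DIV(x, y) = ceil(x / y) and UP_ROUND(x, y) = UP_DIV(x, y) * y, a worked example (the concrete numbers are illustrative only):

// oh = ow = 64, co = 16, stride_h = stride_w = 2, local = {16, 1, 16}
// co4  = UP_DIV(16, 4)                      = 4
// dim0 = UP_ROUND(co4, 16)                  = 16
// dim1 = UP_ROUND(UP_ROUND(64 / 2, 2), 1)   = 32
// dim2 = UP_ROUND(UP_ROUND(64 / 2, 2), 16)  = 32
// global = {16, 32, 32}; each work-item produces up to 4 output pixels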

@@ -47,10 +47,6 @@ int FullConnectionOpenCLKernel::Init() {
return RET_ERROR;
}
if (in_tensors_[0]->shape().size() == 4) {
if (in_tensors_[0]->shape()[3] % C4NUM != 0) {
MS_LOG(ERROR) << "fullconnection only support input shape channel % 4 = 0 if input shape size = 4";
return RET_ERROR;
}
inShape = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2],
in_tensors_[0]->shape()[3]};
} else {
@@ -92,30 +88,30 @@ int FullConnectionOpenCLKernel::ReSize() { return RET_OK; }
void FullConnectionOpenCLKernel::PadWeight() {
// ABMCI @ ABCICO = ABMCO
auto allocator = ocl_runtime_->GetAllocator();
int ci = inShape[1] * inShape[2] * inShape[3];
int ci = inShape[3];
int ci4 = UP_DIV(ci, C4NUM);
int co = outShape[1];
int co4 = UP_DIV(co, C4NUM);
int a = 1;
int b = 1;
int h = inShape[1];
int w = inShape[2];
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size);
padWeight_ = allocator->Malloc(h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
auto padWeightFp16 = reinterpret_cast<float16_t *>(padWeight_);
memset(padWeight_, 0x00, a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size);
memset(padWeight_, 0x00, h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
auto originWeightFp32 = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c());
auto originWeightFp16 = reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->data_c());
bool isModelFp16 = in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16;
// pad weight
// ABCICO -> AB(CI4)(CO4)(4 from CO)(4 from CI)
// if transposeB, ABCOCI -> AB(CI4)(CO4)(4 from CO)(4 from CI)
// HWCICO -> (HWCI4)(CO4)(4 from CO)(4 from CI)
// if transposeB, COHWCI -> (HWCI4)(CO4)(4 from CO)(4 from CI)
int index = 0;
for (int aa = 0; aa < a; aa++) {
for (int bb = 0; bb < b; bb++) {
int baseAB = (aa * b + bb) * ci * co;
for (int hh = 0; hh < h; hh++) {
for (int ww = 0; ww < w; ww++) {
int baseHW = hh * w + ww;
for (int i = 0; i < ci4; ++i) {
for (int j = 0; j < co4; ++j) {
for (int k = 0; k < C4NUM; ++k) {
@@ -123,9 +119,9 @@ void FullConnectionOpenCLKernel::PadWeight() {
int src_ci = i * C4NUM + l;
int src_co = j * C4NUM + k;
if (src_ci < ci && src_co < co) {
int originId = baseAB + src_ci * co + src_co;
int originId = baseHW * ci * co + src_ci * co + src_co;
if (transposeB) {
originId = baseAB + src_co * ci + src_ci;
originId = src_co * ci * h * w + baseHW * ci + src_ci;
}
if (enable_fp16_) {
if (!isModelFp16) {
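Read together, the new loops pack the HWCICO weight (or COHWCI when transposeB) into blocks ordered [h][w][ci4][co4][4 co lanes][4 ci lanes]. A small sketch of the destination offset implied by the sequential index++ in the loop nest above; the helper name and the i/l/j/k decomposition are illustrative, not code from this commit:

int PaddedWeightOffset(int hh, int ww, int w, int ci4, int co4, int src_ci, int src_co) {
  int i = src_ci / 4, l = src_ci % 4;  // ci4 block and lane inside one FLT4
  int j = src_co / 4, k = src_co % 4;  // co4 block and lane
  return ((((hh * w + ww) * ci4 + i) * co4 + j) * 4 + k) * 4 + l;  // element index into padWeight_
}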
