parent
45ae76e86a
commit
39e4900e4b
@ -1,52 +1,59 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#define FLT half
|
||||
#define FLT4 half4
|
||||
#define FLT16 half16
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

/*
 * 2x2 transposed convolution, image2d variant (FLT4 is half4 in this file).
 *
 * Work-item mapping, as implemented below:
 *   get_global_id(0) -> h : h % 2 selects the kernel row kh, (h / 2) * 2 is the first
 *                           of two consecutive source rows handled by this work-item
 *   get_global_id(1) -> w : same scheme for the kernel column kw and the source column
 *   get_global_id(2) -> co: index into the output slice / bias texel
 *
 * For every ci in [0, src_size.z) the work-item reads the four source texels of the
 * 2x2 block starting at (src_w, src_h), multiplies each one by the same FLT16 weight
 * element (used as a 4x4 matrix) and accumulates into r0..r3. After adding the bias
 * texel (co, 0) the four results are stored to dst_data.
 *
 * src_data    source image, addressed as (w * src_size.z + ci, h)
 * weight      FLT16 buffer, addressed as (co * 4 + kh + kw * 2) * src_size.z + ci
 * biases      bias image, read at (co, 0)
 * dst_data    destination image, addressed as (w * dst_size.z + co, h)
 * kernel_size, stride, padding   not referenced by this kernel (2x2 is hard-coded)
 * src_size    .y is not referenced; .z is the reduction length
 *             (presumably the number of 4-channel input slices -- confirm with host code)
 * dst_size    .x/.y/.z bound h * 2, w * 2 and co in the early-exit guard
 *
 * NOTE: read_imagef() returns float4 and write_imagef() takes float4. OpenCL C has no
 * implicit conversion between vector types, so with FLT4 == half4 every image access
 * is converted explicitly with convert_half4() / convert_float4().
 */
__kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases,
                                  __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding,
                                  int4 src_size, int4 dst_size) {
  int h = get_global_id(0);
  int kh = h % 2;
  int src_h = h / 2;
  src_h = src_h * 2;
  int w = get_global_id(1);
  int kw = w % 2;
  int src_w = w / 2;
  src_w = src_w * 2;
  int co = get_global_id(2);
  if (h * 2 >= dst_size.x || w * 2 >= dst_size.y || co >= dst_size.z) return;

  FLT4 r0 = (FLT4)(0.f);
  FLT4 r1 = (FLT4)(0.f);
  FLT4 r2 = (FLT4)(0.f);
  FLT4 r3 = (FLT4)(0.f);

  // One FLT16 weight per (co, kernel position, ci); kernel position is kh + kw * 2.
  int base_w = (co * 4 + kh + kw * 2) * src_size.z;
  for (int ci = 0; ci < src_size.z; ++ci) {
    // 2x2 block of source texels: x0 = (w, h), x1 = (w, h + 1), x2 = (w + 1, h), x3 = (w + 1, h + 1).
    FLT4 x0 = convert_half4(read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h)));
    FLT4 x1 = convert_half4(read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1)));
    FLT4 x2 = convert_half4(read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h)));
    FLT4 x3 = convert_half4(read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1)));
    FLT16 weight_cache = weight[base_w++];

    r0 += x0.x * weight_cache.s0123;
    r0 += x0.y * weight_cache.s4567;
    r0 += x0.z * weight_cache.s89ab;
    r0 += x0.w * weight_cache.scdef;

    r1 += x1.x * weight_cache.s0123;
    r1 += x1.y * weight_cache.s4567;
    r1 += x1.z * weight_cache.s89ab;
    r1 += x1.w * weight_cache.scdef;

    r2 += x2.x * weight_cache.s0123;
    r2 += x2.y * weight_cache.s4567;
    r2 += x2.z * weight_cache.s89ab;
    r2 += x2.w * weight_cache.scdef;

    r3 += x3.x * weight_cache.s0123;
    r3 += x3.y * weight_cache.s4567;
    r3 += x3.z * weight_cache.s89ab;
    r3 += x3.w * weight_cache.scdef;
  }

  FLT4 bias_val = convert_half4(read_imagef(biases, smp_zero, (int2)(co, 0)));
  r0 += bias_val;
  r1 += bias_val;
  r2 += bias_val;
  r3 += bias_val;

  // Source (row, col) maps to output (2 * row + kh, 2 * col + kw); the second source
  // row/column of the block therefore lands two output rows/columns further on.
  write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), convert_float4(r0));
  write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), convert_float4(r1));
  write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), convert_float4(r2));
  write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), convert_float4(r3));
}
|
||||
|
@ -1,51 +1,59 @@
|
||||
#define FLT float
|
||||
#define FLT4 float4
|
||||
#define FLT16 float16
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

/*
 * 2x2 transposed convolution, image2d variant (FLT4 is float4 in this file).
 *
 * Each work-item is identified by (get_global_id(0), get_global_id(1), get_global_id(2)).
 * The parity of the first two ids picks the kernel tap (kh, kw); the ids rounded down
 * to an even number give the top-left corner of a 2x2 block of source texels. The
 * third id, co, selects the output slice and the bias texel.
 *
 * For every ci in [0, src_size.z) the four source texels are multiplied by one FLT16
 * weight element (used as a 4x4 matrix) and accumulated; the bias texel (co, 0) is
 * then added and the four results are written to dst_data.
 *
 * kernel_size, stride and padding are not referenced by the body, and neither is
 * src_size.y; src_size.z is the reduction length.
 */
__kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases,
                                  __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding,
                                  int4 src_size, int4 dst_size) {
  const int gid_h = get_global_id(0);
  const int gid_w = get_global_id(1);
  const int co = get_global_id(2);
  if (gid_h * 2 >= dst_size.x || gid_w * 2 >= dst_size.y || co >= dst_size.z) return;

  // Kernel tap handled by this work-item and the even-aligned origin of its source block.
  const int kh = gid_h % 2;
  const int kw = gid_w % 2;
  const int src_h = (gid_h / 2) * 2;
  const int src_w = (gid_w / 2) * 2;

  // Image x coordinate of slice 0 for the two source columns of the block.
  const int in_col0 = src_w * src_size.z;
  const int in_col1 = (src_w + 1) * src_size.z;
  // First weight element for this (co, tap) pair; one element per ci follows.
  const int w_base = (co * 4 + kh + kw * 2) * src_size.z;

  // Accumulators named after the source texel they follow: h0/h1 = row, w0/w1 = column.
  FLT4 acc_h0w0 = (FLT4)(0.f);
  FLT4 acc_h1w0 = (FLT4)(0.f);
  FLT4 acc_h0w1 = (FLT4)(0.f);
  FLT4 acc_h1w1 = (FLT4)(0.f);

  for (int ci = 0; ci < src_size.z; ++ci) {
    const FLT4 in_h0w0 = read_imagef(src_data, smp_zero, (int2)(in_col0 + ci, src_h));
    const FLT4 in_h1w0 = read_imagef(src_data, smp_zero, (int2)(in_col0 + ci, src_h + 1));
    const FLT4 in_h0w1 = read_imagef(src_data, smp_zero, (int2)(in_col1 + ci, src_h));
    const FLT4 in_h1w1 = read_imagef(src_data, smp_zero, (int2)(in_col1 + ci, src_h + 1));
    const FLT16 wt = weight[w_base + ci];

    acc_h0w0 += in_h0w0.x * wt.s0123;
    acc_h0w0 += in_h0w0.y * wt.s4567;
    acc_h0w0 += in_h0w0.z * wt.s89ab;
    acc_h0w0 += in_h0w0.w * wt.scdef;

    acc_h1w0 += in_h1w0.x * wt.s0123;
    acc_h1w0 += in_h1w0.y * wt.s4567;
    acc_h1w0 += in_h1w0.z * wt.s89ab;
    acc_h1w0 += in_h1w0.w * wt.scdef;

    acc_h0w1 += in_h0w1.x * wt.s0123;
    acc_h0w1 += in_h0w1.y * wt.s4567;
    acc_h0w1 += in_h0w1.z * wt.s89ab;
    acc_h0w1 += in_h0w1.w * wt.scdef;

    acc_h1w1 += in_h1w1.x * wt.s0123;
    acc_h1w1 += in_h1w1.y * wt.s4567;
    acc_h1w1 += in_h1w1.z * wt.s89ab;
    acc_h1w1 += in_h1w1.w * wt.scdef;
  }

  const FLT4 bias_val = read_imagef(biases, smp_zero, (int2)(co, 0));
  acc_h0w0 += bias_val;
  acc_h1w0 += bias_val;
  acc_h0w1 += bias_val;
  acc_h1w1 += bias_val;

  // Output coordinates: the second source row/column lands two output rows/columns later.
  const int out_row0 = 2 * src_h + kh;
  const int out_row1 = 2 * src_h + kh + 2;
  const int out_col0 = (2 * src_w + kw) * dst_size.z + co;
  const int out_col1 = (2 * src_w + kw + 2) * dst_size.z + co;

  write_imagef(dst_data, (int2)(out_col0, out_row0), acc_h0w0);
  write_imagef(dst_data, (int2)(out_col0, out_row1), acc_h1w0);
  write_imagef(dst_data, (int2)(out_col1, out_row0), acc_h0w1);
  write_imagef(dst_data, (int2)(out_col1, out_row1), acc_h1w1);
}
|
||||
|
Loading…
Reference in new issue