!4139 add image2d format for concat op
Merge pull request !4139 from pengyongrong/concat_debug_prpull/4139/MERGE
commit
641601f7b8
@ -1,54 +1,44 @@
|
||||
// NOTE(review): this span looks like a rendered diff hunk (see the "@ -1,54 +1,44 @" header above):
// lines of a buffer-based Concat kernel and of an image2d-based Concat kernel are interleaved without
// +/- markers and separated by "|" / "||||" residue, so it is not compilable as shown.
// Presumably the __global float* kernel is the removed side and the image2d_t kernel is the added side
// -- confirm against the real file before editing any code here.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
// Buffer-based Concat of two inputs: one work-item per (oh, ow, oc) output element.
__kernel void Concat(__global float *input0, __global float *input1, __global float *output, const int4 input_shape0,
|
||||
const int4 input_shape1, const int4 output_shape, const int axis) {
|
||||
uint oh = get_global_id(0);
|
||||
uint ow = get_global_id(1);
|
||||
uint oc = get_global_id(2);
|
||||
uint index_output;
|
||||
uint input_idx;
|
||||
// Bounds guard against output_shape.y/.z/.w. oh/ow/oc are uint, so the `< 0` tests can never be true.
if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) {
|
||||
return;
|
||||
// Image2d-based Concat starts here. FLT4 is the 4-component pixel type returned by read_imagef.
#define FLT4 float4
|
||||
// Sampler used for all image reads: unnormalized coordinates, CLK_ADDRESS_NONE, nearest filtering.
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
|
||||
// Concat of two image2d inputs: each work-item (X, Y) copies shared_int0.x FLT4 pixels from input0 and
// then shared_int0.y FLT4 pixels from input1 into consecutive x positions Y * shared_out.w + S of
// output row X.
// NOTE(review): shared_int0.x/.y are presumably the per-position FLT4 slice counts of the two inputs and
// shared_out the output shape -- confirm against the host code that sets the kernel arguments.
__kernel void Concat(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d,
|
||||
__read_only image2d_t input1_image2d, int2 shared_int0, int4 shared_out) {
|
||||
int X = get_global_id(0); // H
|
||||
int Y = get_global_id(1); // W
|
||||
// S counts the pixels already written for this (X, Y); it is the x offset of the next write.
int S = 0;
|
||||
// Work-items outside the shared_out.y x shared_out.z range do nothing.
if (X >= shared_out.y || Y >= shared_out.z) return;
|
||||
// Copy input0: source x = Y * shared_int0.x + i, destination x = Y * shared_out.w + S, both in row X.
for (int i = 0; i < shared_int0.x; i++) {
|
||||
FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X)));
|
||||
write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0);
|
||||
S++;
|
||||
}
|
||||
// Continuation of the buffer-based kernel: only the axis == 3 case appears in the visible lines.
if (axis == 3) {
|
||||
// Flat output index built from (oh, ow, oc) with output_shape.z and output_shape.w as inner extents.
index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc;
|
||||
// oc below input_shape0.w: element comes from input0 at the same (oh, ow, oc).
if (oc < input_shape0.w) {
|
||||
input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc;
|
||||
output[index_output] = input0[input_idx];
|
||||
// oc in [input_shape0.w, input_shape0.w + input_shape1.w): element comes from input1, with oc
// rebased by input_shape0.w.
} else if ((input_shape0.w <= oc) && oc < (input_shape0.w + input_shape1.w)) {
|
||||
input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w);
|
||||
output[index_output] = input1[input_idx];
|
||||
// oc beyond both inputs' w extents: write zero.
} else {
|
||||
output[index_output] = 0;
|
||||
}
|
||||
// Continuation of the image2d kernel: append input1's shared_int0.y pixels after those of input0
// (S keeps counting from where the first loop stopped).
for (int i = 0; i < shared_int0.y; i++) {
|
||||
FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X)));
|
||||
write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1);
|
||||
S++;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): this span looks like a rendered diff hunk: lines of a buffer-based Concat3input kernel
// and of an image2d-based Concat3input kernel are interleaved without +/- markers and separated by
// "|" / "||||" residue, so it is not compilable as shown. Presumably the __global float* kernel is the
// removed side and the image2d_t kernel is the added side -- confirm against the real file.
//
// Buffer-based Concat of three inputs: one work-item per (oh, ow, oc) output element.
// The `axis` parameter is not referenced in any of the visible lines of this kernel.
__kernel void Concat3input(__global float *input0, __global float *input1, __global float *input2,
|
||||
__global float *output, const int4 input_shape0, const int4 input_shape1,
|
||||
const int4 input_shape2, const int4 output_shape, const int axis) {
|
||||
uint oh = get_global_id(0);
|
||||
uint ow = get_global_id(1);
|
||||
uint oc = get_global_id(2);
|
||||
uint index_output;
|
||||
uint input_idx;
|
||||
// Bounds guard against output_shape.y/.z/.w. oh/ow/oc are uint, so the `< 0` tests can never be true.
if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) {
|
||||
return;
|
||||
// Image2d-based Concat of three inputs: each work-item (X, Y) copies shared_int0.x, then shared_int0.y,
// then shared_int0.z FLT4 pixels from input0, input1 and input2 into consecutive x positions
// Y * shared_out.w + S of output row X.
// NOTE(review): shared_int0.x/.y/.z are presumably the per-position FLT4 slice counts of the three
// inputs and shared_out the output shape -- confirm against the host code that sets the arguments.
__kernel void Concat3input(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d,
|
||||
__read_only image2d_t input1_image2d, __read_only image2d_t input2_image2d, int3 shared_int0,
|
||||
int4 shared_out) {
|
||||
int X = get_global_id(0); // H
|
||||
int Y = get_global_id(1); // W
|
||||
// S counts the pixels already written for this (X, Y); it is the x offset of the next write.
int S = 0;
|
||||
// Work-items outside the shared_out.y x shared_out.z range do nothing.
if (X >= shared_out.y || Y >= shared_out.z) return;
|
||||
// Copy input0: source x = Y * shared_int0.x + i, destination x = Y * shared_out.w + S, both in row X.
for (int i = 0; i < shared_int0.x; i++) {
|
||||
FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X)));
|
||||
write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0);
|
||||
S++;
|
||||
}
|
||||
// Continuation of the buffer-based kernel.
// Flat output index built from (oh, ow, oc) with output_shape.z and output_shape.w as inner extents.
index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc;
|
||||
// oc below input_shape0.w + input_shape1.w: the element comes from input0 or input1.
if (oc < (input_shape0.w + input_shape1.w)) {
|
||||
if (oc < input_shape0.w) {
|
||||
input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc;
|
||||
output[index_output] = input0[input_idx];
|
||||
} else {
|
||||
// Source is input1; oc is rebased by input_shape0.w.
input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w);
|
||||
output[index_output] = input1[input_idx];
|
||||
}
|
||||
} else {
|
||||
// oc at or beyond the summed w extents of all three inputs: write zero.
if ((input_shape0.w + input_shape1.w + input_shape2.w) <= oc) {
|
||||
output[index_output] = 0;
|
||||
} else {
|
||||
// Source is input2; oc is rebased by input_shape0.w + input_shape1.w.
input_idx = (input_shape2.z * oh + ow) * input_shape2.w + (oc - input_shape0.w - input_shape1.w);
|
||||
output[index_output] = input2[input_idx];
|
||||
}
|
||||
// Continuation of the image2d kernel: append input1's shared_int0.y pixels after those of input0
// (S keeps counting from where the first loop stopped).
for (int i = 0; i < shared_int0.y; i++) {
|
||||
FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X)));
|
||||
write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1);
|
||||
S++;
|
||||
}
|
||||
// Then append input2's shared_int0.z pixels.
for (int i = 0; i < shared_int0.z; i++) {
|
||||
FLT4 result2 = read_imagef(input2_image2d, smp_none, (int2)((Y)*shared_int0.z + (i), (X)));
|
||||
write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result2);
|
||||
S++;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in new issue