|
|
|
@ -123,36 +123,27 @@ __kernel void Winograd4x4To36(__read_only image2d_t input, __write_only image2d_
|
|
|
|
|
|
|
|
|
|
constant FLT *Bt_row = Bt + row * 6;
|
|
|
|
|
FLT4 BtD_row[6] = {0};
|
|
|
|
|
for (int y = 0; y < 6; y++) {
|
|
|
|
|
int ih = tile_y * 4 - PAD + y;
|
|
|
|
|
|
|
|
|
|
// Format_NHWC4
|
|
|
|
|
int y_idx = ih;
|
|
|
|
|
// Format_NC4HW4
|
|
|
|
|
// if (ih < 0 || ih >= IH) { continue;}
|
|
|
|
|
// int y_idx = slice * IH + ih;
|
|
|
|
|
|
|
|
|
|
int ih = tile_y * 4 - PAD;
|
|
|
|
|
int iw = tile_x * 4 - PAD;
|
|
|
|
|
for (int y = 0; y < 6; y++) {
|
|
|
|
|
int x_idx = iw * SLICES + slice;
|
|
|
|
|
for (int x = 0; x < 6; x++) {
|
|
|
|
|
int iw = tile_x * 4 - PAD + x;
|
|
|
|
|
|
|
|
|
|
// Format_NHWC4
|
|
|
|
|
if (iw < 0 || iw >= IW) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
int x_idx = iw * SLICES + slice;
|
|
|
|
|
// Format_NC4HW4
|
|
|
|
|
// int x_idx = iw;
|
|
|
|
|
|
|
|
|
|
BtD_row[x] += Bt_row[y] * READ_IMAGE(input, smp_zero, (int2)(x_idx, y_idx));
|
|
|
|
|
// No need to check iw: slice is in [0, SLICES), so iw < 0 implies x_idx < 0 and iw >= IW implies x_idx >= IW * SLICES.
|
|
|
|
|
// if (iw < 0 || iw >= IW) { continue; }
|
|
|
|
|
BtD_row[x] += Bt_row[y] * READ_IMAGE(input, smp_zero, (int2)(x_idx, ih));
|
|
|
|
|
x_idx += SLICES;
|
|
|
|
|
}
|
|
|
|
|
ih++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int y_idx = slice * 36 + row * 6;
|
|
|
|
|
for (int y = 0; y < 6; y++) {
|
|
|
|
|
FLT4 acc = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
|
|
|
|
|
for (int x = 0; x < 6; x++) {
|
|
|
|
|
acc += BtD_row[x] * Bt[y * 6 + x];
|
|
|
|
|
}
|
|
|
|
|
WRITE_IMAGE(output, (int2)(tile_xy, slice * 36 + (row * 6 + y)), acc); // CH W H=36
|
|
|
|
|
WRITE_IMAGE(output, (int2)(tile_xy, y_idx + y), acc); // CH W H=36
|
|
|
|
|
}
|
|
|
|
|
#undef PAD
|
|
|
|
|
}
|
|
|
|
@ -247,36 +238,36 @@ __kernel void Winograd36To4x4(__read_only image2d_t input, __write_only image2d_
|
|
|
|
|
|
|
|
|
|
constant FLT *At_row = At + row * 6;
|
|
|
|
|
FLT4 AtM_row[6] = {0};
|
|
|
|
|
for (int y = 0; y < 6; y++) {
|
|
|
|
|
for (int x = 0; x < 6; x++) {
|
|
|
|
|
AtM_row[x] += At_row[y] * READ_IMAGE(input, smp_zero, (int2)(tile_xy, slice * 36 + y * 6 + x));
|
|
|
|
|
for (int y = 0, idx = slice * 36; y < 6; y++) {
|
|
|
|
|
for (int x = 0; x < 6; x++, idx++) {
|
|
|
|
|
AtM_row[x] += At_row[y] * READ_IMAGE(input, smp_zero, (int2)(tile_xy, idx));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int TILE_X = UP_DIV(OW, 4);
|
|
|
|
|
for (int x = 0; x < 4; x++) {
|
|
|
|
|
int tile_x = tile_xy % TILE_X;
|
|
|
|
|
int tile_y = tile_xy / TILE_X;
|
|
|
|
|
int oh = tile_y * 4 + row;
|
|
|
|
|
int ow = tile_x * 4;
|
|
|
|
|
int x_idx = ow * SLICES + slice;
|
|
|
|
|
|
|
|
|
|
for (int x = 0, idx = 0; x < 4; x++) {
|
|
|
|
|
FLT4 acc = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
|
|
|
|
|
for (int y = 0; y < 6; y++) {
|
|
|
|
|
acc += AtM_row[y] * At[x * 6 + y];
|
|
|
|
|
for (int y = 0; y < 6; y++, idx++) {
|
|
|
|
|
acc += AtM_row[y] * At[idx];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (bias) {
|
|
|
|
|
acc += bias[slice];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (act_type == ActType_Relu) {
|
|
|
|
|
acc = max(acc, (FLT4)(0.0f));
|
|
|
|
|
} else if (act_type == ActType_Relu6) {
|
|
|
|
|
acc = clamp(acc, (FLT4)(0.0f), (FLT4)(6.0f));
|
|
|
|
|
}
|
|
|
|
|
int tile_x = tile_xy % TILE_X;
|
|
|
|
|
int tile_y = tile_xy / TILE_X;
|
|
|
|
|
int ow = tile_x * 4 + x;
|
|
|
|
|
int oh = tile_y * 4 + row;
|
|
|
|
|
|
|
|
|
|
// Format_NHWC4
|
|
|
|
|
if (ow < OW) {
|
|
|
|
|
WRITE_IMAGE(output, (int2)(ow * SLICES + slice, oh), acc);
|
|
|
|
|
}
|
|
|
|
|
// Format_NC4HW4
|
|
|
|
|
// if (oh < OH) { WRITE_IMAGE(output, (int2)(ow, slice * OH + oh), acc);}
|
|
|
|
|
|
|
|
|
|
WRITE_IMAGE(output, (int2)(x_idx, oh), acc);
|
|
|
|
|
x_idx += SLICES;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|