|
|
|
@ -14,20 +14,35 @@ __kernel void SoftMaxAxis3_NHWC4(__read_only image2d_t input, __write_only image
|
|
|
|
|
|
|
|
|
|
if (X >= H || Y >= W) return;
|
|
|
|
|
|
|
|
|
|
// get max
|
|
|
|
|
float4 last = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X)));
|
|
|
|
|
float input_max = last.x;
|
|
|
|
|
if (mask.y > 0.5f) input_max = max(input_max, last.y);
|
|
|
|
|
if (mask.z > 0.5f) input_max = max(input_max, last.z);
|
|
|
|
|
if (mask.w > 0.5f) input_max = max(input_max, last.w);
|
|
|
|
|
for (int d = 0; d < C4 - 1; ++d) {
|
|
|
|
|
float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X)));
|
|
|
|
|
input_max = max(input_max, t.x);
|
|
|
|
|
input_max = max(input_max, t.y);
|
|
|
|
|
input_max = max(input_max, t.z);
|
|
|
|
|
input_max = max(input_max, t.w);
|
|
|
|
|
}
|
|
|
|
|
float4 input_max_f4 = (float4)(input_max, input_max, input_max, input_max);
|
|
|
|
|
|
|
|
|
|
float sum = 0.0f;
|
|
|
|
|
for (int d = 0; d < C4 - 1; ++d) {
|
|
|
|
|
float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X)));
|
|
|
|
|
sum += dot(exp(t), (float4)(1.f));
|
|
|
|
|
sum += dot(exp(t - input_max_f4), (float4)(1.f));
|
|
|
|
|
}
|
|
|
|
|
float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X)));
|
|
|
|
|
sum += dot(exp(t), mask);
|
|
|
|
|
sum += dot(exp(min(t - input_max_f4, 0)), mask);
|
|
|
|
|
for (int d = 0; d < C4 - 1; ++d) {
|
|
|
|
|
float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X)));
|
|
|
|
|
result = exp(result) / sum;
|
|
|
|
|
result = exp(result - input_max_f4) / sum;
|
|
|
|
|
WRITE_IMAGE(output, (int2)(Y * C4 + d, X), TO_FLT4(result));
|
|
|
|
|
}
|
|
|
|
|
float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X)));
|
|
|
|
|
result = exp(result) / sum;
|
|
|
|
|
result = exp(min(result - input_max_f4, 0)) / sum;
|
|
|
|
|
result = result * mask;
|
|
|
|
|
WRITE_IMAGE(output, (int2)(Y * C4 + C4 - 1, X), TO_FLT4(result));
|
|
|
|
|
}
|
|
|
|
|