matmul support act weight

pull/10934/head
chenzupeng 4 years ago
parent 5e666e7c21
commit 3897ca3bbe

@ -2,7 +2,7 @@
#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void MatMul_NHWC4_2d(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,
__kernel void MatMul_2d(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,
int4 in_shape, int4 out_shape) {
int gidx = get_global_id(0); // CO4
int gidz = get_global_id(2); // N
@ -32,7 +32,7 @@ __kernel void MatMul_NHWC4_2d(__read_only image2d_t input, __write_only image2d_
}
}
__kernel void MatMul_NHWC4_4d(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,
__kernel void MatMul_4d(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,
int4 in_shape, int4 out_shape) {
int gidx = get_global_id(0); // CO4
int gidy = get_global_id(1); // N * H * 4
@ -65,3 +65,77 @@ __kernel void MatMul_NHWC4_4d(__read_only image2d_t input, __write_only image2d_
WRITE_IMAGE(output, (int2)(gidz * co4 + gidx, nh_index), result);
}
}
__kernel void MatMulActWeightTransposeB_4d(__read_only image2d_t input, __write_only image2d_t output,
__read_only image2d_t weight, int4 in_shape, int4 out_shape) {
int gidx = get_global_id(0); // CO4
int gidy = get_global_id(1); // N * H * 4
int gidz = get_global_id(2); // W
int lidx = get_local_id(0);
int lidy = get_local_id(1);
int ci4 = UP_DIV(in_shape.w, C4NUM);
int co4 = UP_DIV(out_shape.w, C4NUM);
int n = out_shape.x;
int h = out_shape.y;
int w = out_shape.z;
int nh_index = gidy / 4;
bool inside = gidx < co4 && gidz < w && nh_index < n * h;
FLT4 result = (FLT4)(0.0f);
for (uint i = lidy; i < ci4 && inside; i += 4) {
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(gidz * ci4 + i, nh_index));
FLT4 weight_value0 = READ_IMAGE(weight, smp_zero, (int2)(gidx * 4 * ci4 + i, nh_index));
result.x += dot(v, weight_value0);
FLT4 weight_value1 = READ_IMAGE(weight, smp_zero, (int2)((gidx * 4 + 1) * ci4 + i, nh_index));
result.y += dot(v, weight_value1);
FLT4 weight_value2 = READ_IMAGE(weight, smp_zero, (int2)((gidx * 4 + 2) * ci4 + i, nh_index));
result.z += dot(v, weight_value2);
FLT4 weight_value3 = READ_IMAGE(weight, smp_zero, (int2)((gidx * 4 + 3) * ci4 + i, nh_index));
result.w += dot(v, weight_value3);
}
__local FLT4 temp[32][4];
temp[lidx][lidy] = result;
barrier(CLK_LOCAL_MEM_FENCE);
if (lidy == 0 && inside) {
result += temp[lidx][1];
result += temp[lidx][2];
result += temp[lidx][3];
WRITE_IMAGE(output, (int2)(gidz * co4 + gidx, nh_index), result);
}
}
__kernel void MatMulActWeight_4d(__read_only image2d_t input, __write_only image2d_t output,
__read_only image2d_t weight, int4 in_shape, int4 out_shape) {
int gidx = get_global_id(0); // CO4
int gidy = get_global_id(1); // N * H * 4
int gidz = get_global_id(2); // W
int lidx = get_local_id(0);
int lidy = get_local_id(1);
int ci4 = UP_DIV(in_shape.w, C4NUM);
int co4 = UP_DIV(out_shape.w, C4NUM);
int n = out_shape.x;
int h = out_shape.y;
int w = out_shape.z;
int nh_index = gidy / 4;
bool inside = gidx < co4 && gidz < w && nh_index < n * h;
FLT4 result = (FLT4)(0.0f);
for (uint i = lidy; i < ci4 && inside; i += 4) {
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(gidz * ci4 + i, nh_index));
FLT4 weight_value0 = READ_IMAGE(weight, smp_zero, (int2)(i * 4 * co4 + gidx, nh_index));
result += v.x * weight_value0;
FLT4 weight_value1 = READ_IMAGE(weight, smp_zero, (int2)((i * 4 + 1) * co4 + gidx, nh_index));
result += v.y * weight_value1;
FLT4 weight_value2 = READ_IMAGE(weight, smp_zero, (int2)((i * 4 + 2) * co4 + gidx, nh_index));
result += v.z * weight_value2;
FLT4 weight_value3 = READ_IMAGE(weight, smp_zero, (int2)((i * 4 + 3) * co4 + gidx, nh_index));
result += v.w * weight_value3;
}
__local FLT4 temp[32][4];
temp[lidx][lidy] = result;
barrier(CLK_LOCAL_MEM_FENCE);
if (lidy == 0 && inside) {
result += temp[lidx][1];
result += temp[lidx][2];
result += temp[lidx][3];
WRITE_IMAGE(output, (int2)(gidz * co4 + gidx, nh_index), result);
}
}

@ -42,21 +42,25 @@ int MatMulOpenCLKernel::CheckSpecs() {
return mindspore::lite::RET_ERROR;
}
transposeB = param->b_transpose_;
act_weight_ = !in_tensors_[1]->IsConst();
enable_fp16_ = ocl_runtime_->GetFp16Enable();
if (in_tensors_[0]->shape().size() != out_tensors_[0]->shape().size() || in_tensors_[0]->shape().size() < 2 ||
in_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << "matmul only support input shape size= 2, 3 or 4.";
return mindspore::lite::RET_ERROR;
}
if (!in_tensors_.at(kWeightIndex)->IsConst()) {
MS_LOG(ERROR) << "Matmul don't support non-constant filter yet.";
return RET_ERROR;
}
return RET_OK;
}
int MatMulOpenCLKernel::Prepare() {
std::string kernel_name = "MatMul_NHWC4";
std::string kernel_name = "MatMul";
if (act_weight_) {
if (transposeB) {
kernel_name = "MatMulActWeightTransposeB";
} else {
kernel_name = "MatMulActWeight";
}
}
dims = in_tensors_[0]->shape().size();
for (int i = 0; i < dims; i++) {
inShape[MAX_DIMS - dims + i] = in_tensors_[0]->shape()[i];
@ -83,6 +87,9 @@ int MatMulOpenCLKernel::Prepare() {
}
int MatMulOpenCLKernel::InitWeights() {
if (act_weight_) {
return RET_OK;
}
// ABMCI @ ABCICO = ABMCO
auto ret = DequantWeight();
if (ret != RET_OK) {
@ -164,7 +171,11 @@ void MatMulOpenCLKernel::SetConstArgs() {
int arg_count = 2;
cl_int4 in_shape = {inShape[0], inShape[1], inShape[2], inShape[3]};
cl_int4 out_shape = {outShape[0], outShape[1], outShape[2], outShape[3]};
if (act_weight_) {
arg_count++;
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF);
}
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_shape);
}
@ -174,6 +185,9 @@ int MatMulOpenCLKernel::Run() {
int arg_count = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->data_c());
if (act_weight_) {
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[1]->data_c());
}
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return mindspore::lite::RET_OK;
}

@ -48,6 +48,7 @@ class MatMulOpenCLKernel : public OpenCLKernel {
static constexpr int MAX_DIMS{4}; // max supported matmul dims
std::vector<int> inShape{std::vector<int>(MAX_DIMS, 1)};
std::vector<int> outShape{std::vector<int>(MAX_DIMS, 1)};
bool act_weight_{false};
};
} // namespace mindspore::kernel

@ -89,4 +89,44 @@ TEST_F(TestOpenCL_MatMul, 3D) {
param, fp16_enable);
}
}
TEST_F(TestOpenCL_MatMul, ActWeightTransposeB3D) {
int a = 2;
int m = 2;
int ci = 5;
int co = 3;
std::vector<int> input_shape = {a, m, ci};
std::vector<int> output_shape = {a, m, co};
std::vector<int> weight_shape = {a, co, ci};
float input_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float weight_data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30};
float output_data[] = {15, 40, 65, 15, 40, 65, 90, 115, 140, 90, 115, 140};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter();
TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, VAR}}, {output_shape, output_data}, param,
fp16_enable);
}
}
TEST_F(TestOpenCL_MatMul, ActWeight3D) {
int a = 2;
int m = 2;
int ci = 5;
int co = 3;
std::vector<int> input_shape = {a, m, ci};
std::vector<int> output_shape = {a, m, co};
std::vector<int> weight_shape = {a, ci, co};
float input_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float weight_data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30};
float output_data[] = {35, 40, 45, 35, 40, 45, 110, 115, 120, 110, 115, 120};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(false, false);
TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, VAR}}, {output_shape, output_data}, param,
fp16_enable);
}
}
} // namespace mindspore::lite::opencl::test

Loading…
Cancel
Save