From 5d82670532382af298616a04d41fc959de6afaa2 Mon Sep 17 00:00:00 2001
From: chenzupeng
Date: Wed, 30 Sep 2020 11:06:38 +0800
Subject: [PATCH] fix bug in arithmetic

---
Notes (ignored by git am):
  * arithmetic.cl: adds ElementAddReLU_IMG, which writes max(a + b, 0) per
    texel and uses the same bounds check as ElementAdd_IMG.
  * arithmetic.cc: Init() appends "ReLU" to kernel_name when the activation
    is RELU on an element-wise Add; any other RELU use, and any other
    activation type, logs an error and returns RET_ERROR.
  * scale.cc: for a 1-D second input with axis_ == 3, the UP_DIV(..., C4NUM)
    extent now goes to img_size[0] and 1 goes to img_size[1] (previously the
    other way round).
  * NOTE(review): line breaks and indentation were reconstructed from a
    whitespace-collapsed copy of this patch, and the <size_t> template
    argument on the std::vector context line in scale.cc was restored by
    inference. Verify both against the target tree before applying
    (git apply --ignore-whitespace may be needed).

 .../src/runtime/kernel/opencl/cl/arithmetic.cl | 13 +++++++++++++
 .../runtime/kernel/opencl/kernel/arithmetic.cc | 16 ++++++++++++++++
 .../src/runtime/kernel/opencl/kernel/scale.cc  |  4 ++--
 3 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl
index 30357ba516..13247d95c2 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl
@@ -15,6 +15,19 @@ __kernel void ElementAdd_IMG(__read_only image2d_t input_a, __read_only image2d_
   WRITE_IMAGE(output, (int2)(X, Y), a + b);
 }
 
+__kernel void ElementAddReLU_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
+                                 __write_only image2d_t output, const int2 output_shape) {
+  int X = get_global_id(0);
+  int Y = get_global_id(1);
+  if (X >= output_shape.x || Y >= output_shape.y) {
+    return;
+  }
+
+  FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));
+  FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));
+  WRITE_IMAGE(output, (int2)(X, Y), max(a + b, (FLT4)(0.f)));
+}
+
 __kernel void ElementSub_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
                              __write_only image2d_t output, const int2 output_shape) {
   int X = get_global_id(0);
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
index 59a520d67b..cd06b8666b 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
@@ -262,6 +262,22 @@ int ArithmeticOpenCLKernel::Init() {
     return RET_ERROR;
   }
 
+  switch (arithmetic_parameter->activation_type_) {
+    case schema::ActivationType_NO_ACTIVATION:
+      break;
+    case schema::ActivationType_RELU:
+      if (op_parameter_->type_ == PrimitiveType_Add && element_flag_) {
+        kernel_name += "ReLU";
+      } else {
+        MS_LOG(ERROR) << "Only support ElementAdd + ReLU";
+        return RET_ERROR;
+      }
+      break;
+    default:
+      MS_LOG(ERROR) << "Error activation type " << arithmetic_parameter->activation_type_;
+      return RET_ERROR;
+  }
+
   lite::STATUS error_code = RET_OK;
 #ifdef PROGRAM_WITH_IL
   kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc
index 97ccf03f68..efe3e4fcf8 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc
@@ -115,8 +115,8 @@ int ScaleOpenCLKernel::InitBuffer() {
     std::vector<size_t> img_size;
     GetImageSize(0, &img_size);
     if (in_tensors_[1]->shape().size() == 1 && axis_ == 3) {
-      img_size[0] = 1;
-      img_size[1] = UP_DIV(in_tensors_[1]->shape()[0], C4NUM);
+      img_size[1] = 1;
+      img_size[0] = UP_DIV(in_tensors_[1]->shape()[0], C4NUM);
       scale_ptr_ = allocator->CreateImageFromHost(in_tensors_[1]->data_c(), in_tensors_[1]->ElementsNum(), img_size);
       offset_ptr_ = allocator->CreateImageFromHost(in_tensors_[2]->data_c(), in_tensors_[2]->ElementsNum(), img_size);
       return RET_OK;