From 56d32b4e77c9edca3c3c978761d1ae1e1306db5a Mon Sep 17 00:00:00 2001 From: liuzhongkai Date: Thu, 13 Aug 2020 01:47:15 -0700 Subject: [PATCH] modify relu sigmoid leaky_relu in activation --- .../kernel/opencl/cl/fp32/activation.cl | 70 +++++++ .../kernel/opencl/cl/fp32/leaky_relu.cl | 28 --- .../kernel/opencl/kernel/activation.cc | 146 ++++++++++++++ .../kernel/{leaky_relu.h => activation.h} | 29 +-- .../kernel/opencl/kernel/arithmetic.cc | 3 +- .../kernel/opencl/kernel/conv2d_transpose.cc | 3 +- .../kernel/opencl/kernel/depthwise_conv2d.cc | 3 +- .../kernel/opencl/kernel/leaky_relu.cc | 122 ------------ .../runtime/kernel/opencl/kernel/matmul.cc | 3 +- .../runtime/kernel/opencl/kernel/pooling2d.cc | 2 +- .../runtime/kernel/opencl/kernel/softmax.cc | 2 +- .../runtime/kernel/opencl/kernel/transpose.cc | 2 +- mindspore/lite/test/CMakeLists.txt | 6 +- .../runtime/kernel/opencl/activation_tests.cc | 185 ++++++++++++++++++ .../runtime/kernel/opencl/leakyrelu_tests.cc | 110 ----------- 15 files changed, 431 insertions(+), 283 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/activation.cl delete mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc rename mindspore/lite/src/runtime/kernel/opencl/kernel/{leaky_relu.h => activation.h} (54%) delete mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc delete mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/activation.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/activation.cl new file mode 100644 index 0000000000..e9c1f2519f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/activation.cl @@ -0,0 +1,70 @@ +#pragma OPENCL EXTENSION cl_arm_printf : enable + +#define SLICES 4 +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define FLT4 float4 +#define MIN(X, Y) (X < Y ? X : Y) +#define READ_FLT4 read_imagef +#define WRITE_FLT4 write_imagef +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void ReluScalar(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, + const float alpha) { + int C = input_shape.w; // channel size + int Y = get_global_id(0); // height id + int X = get_global_id(1); // weight id + for (int num = 0; num < UP_DIV(C, SLICES); ++num) { + FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC + FLT4 tmp; + tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha; + tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha; + tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha; + tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha; + WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC + } +} + +__kernel void Relu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { + int C = input_shape.w; // channel size + int Y = get_global_id(0); // height id + int X = get_global_id(1); // weight id + for (int num = 0; num < UP_DIV(C, SLICES); ++num) { + FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC + FLT4 tmp; + tmp.x = in_c4.x >= 0 ? in_c4.x : 0; + tmp.y = in_c4.y >= 0 ? in_c4.y : 0; + tmp.z = in_c4.z >= 0 ? in_c4.z : 0; + tmp.w = in_c4.w >= 0 ? 
in_c4.w : 0; + WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC + } +} + +__kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { + int C = input_shape.w; // channel size + int Y = get_global_id(0); // height id + int X = get_global_id(1); // weight id + for (int num = 0; num < UP_DIV(C, SLICES); ++num) { + FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC + FLT4 tmp; + tmp.x = in_c4.x >= 0 ? MIN(in_c4.x, 6) : 0; + tmp.y = in_c4.y >= 0 ? MIN(in_c4.y, 6) : 0; + tmp.z = in_c4.z >= 0 ? MIN(in_c4.z, 6) : 0; + tmp.w = in_c4.w >= 0 ? MIN(in_c4.w, 6) : 0; + WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC + } +} + +__kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { + int C = input_shape.w; // channel size + int Y = get_global_id(0); // height id + int X = get_global_id(1); // weight id + for (int num = 0; num < UP_DIV(C, SLICES); ++num) { + FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC + FLT4 tmp; + tmp.x = 1 / (1 + exp(-in_c4.x)); + tmp.y = 1 / (1 + exp(-in_c4.y)); + tmp.z = 1 / (1 + exp(-in_c4.z)); + tmp.w = 1 / (1 + exp(-in_c4.w)); + WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl deleted file mode 100644 index 388f4c983a..0000000000 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl +++ /dev/null @@ -1,28 +0,0 @@ -#pragma OPENCL EXTENSION cl_arm_printf : enable - -#define SLICES 4 -#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -#define FLT4 float4 -#define READ_FLT4 read_imagef -#define WRITE_FLT4 write_imagef -__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - -__kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, - const float alpha) { - // int B = input_shape.x; // size - // int H = input_shape.y; // - // int W = input_shape.z; - int C = input_shape.w; - - int Y = get_global_id(0); // height id - int X = get_global_id(1); // weight id - for (int num = 0; num < UP_DIV(C, SLICES); ++num) { - FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC - FLT4 tmp; - tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha; - tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha; - tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha; - tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha; - WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC - } -} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc new file mode 100644 index 0000000000..49a99d1150 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include "src/runtime/kernel/opencl/kernel/activation.h"
+#include "schema/model_generated.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/runtime_api.h"
+#include "include/errorcode.h"
+#include "src/ops/ops.h"
+#include "src/runtime/kernel/opencl/cl/fp32/activation.cl.inc"
+
+using mindspore::kernel::KERNEL_ARCH::kGPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::ActivationType_LEAKY_RELU;
+using mindspore::schema::ActivationType_RELU;
+using mindspore::schema::ActivationType_RELU6;
+using mindspore::schema::ActivationType_SIGMOID;
+using mindspore::schema::PrimitiveType_Activation;
+
+namespace mindspore::kernel {
+
+int ActivationOpenClKernel::Init() {
+  const int max_shape_dim = 4;
+  if (in_tensors_[0]->shape().size() != max_shape_dim) {
+    MS_LOG(ERROR) << "Activation function only supports dim=4, but the input dim=" << in_tensors_[0]->shape().size();
+    return RET_ERROR;
+  }
+  std::string program_name = "";
+  std::string kernel_name = "";
+  std::string source = activation_source_fp32;
+  if (type_ == ActivationType_RELU) {
+    program_name = "RELU";
+    kernel_name = "Relu";
+  } else if (type_ == ActivationType_RELU6) {
+    program_name = "RELU6";
+    kernel_name = "Relu6";
+  } else if (type_ == ActivationType_LEAKY_RELU) {
+    program_name = "LEAKY_RELU";
+    kernel_name = "ReluScalar";
+  } else if (type_ == ActivationType_SIGMOID) {
+    program_name = "SIGMOID";
+    kernel_name = "Sigmoid";
+  } else {
+    MS_LOG(ERROR) << "Activation type error";
+    return RET_ERROR;
+  }
+  std::set<std::string> build_options;
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  ocl_runtime->LoadSource(program_name, source);
+  ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
+  MS_LOG(DEBUG) << op_parameter_->name_ << " init Done!";
+  return RET_OK;
+}
+
+int ActivationOpenClKernel::Run() {
+  MS_LOG(DEBUG) << op_parameter_->name_ << " begin running!";
+  int N = in_tensors_[0]->shape()[0];
+  int H = in_tensors_[0]->shape()[1];
+  int W = in_tensors_[0]->shape()[2];
+  int C = in_tensors_[0]->shape()[3];
+  cl_int4 input_shape = {N, H, W, C};
+
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  int arg_idx = 0;
+  ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
+  ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
+  ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
+  if (type_ == ActivationType_LEAKY_RELU) {
+    ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_);
+  }
+  std::vector<size_t> local = {1, 1};
+  std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
+  std::cout << type_ << " " << std::endl;
+  auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Run kernel:" << op_parameter_->name_ << " fail.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int ActivationOpenClKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
+  int H = in_tensors_[0]->shape()[1];
+  int W = in_tensors_[0]->shape()[2];
+  int C = in_tensors_[0]->shape()[3];
+
+#ifdef ENABLE_FP16
+  size_t img_dtype = CL_HALF_FLOAT;
+#else
+  size_t img_dtype = CL_FLOAT;
+#endif
+
+  img_size->clear();
+  img_size->push_back(W * UP_DIV(C, C4NUM));
+  img_size->push_back(H);
+  img_size->push_back(img_dtype);
+  return RET_OK;
+}
+
+kernel::LiteKernel *OpenClActivationFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+                                                      const std::vector<lite::tensor::Tensor *> &outputs,
+                                                      OpParameter *opParameter, const lite::Context *ctx,
+                                                      const kernel::KernelKey &desc, const lite::Primitive *primitive) {
+  if (inputs.size() == 0) {
+    MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
+    return nullptr;
+  }
+  if (inputs[0]->shape()[0] > 1) {
+    MS_LOG(ERROR) << "Activation kernel:" << opParameter->name_ << " failed: Unsupported multi-batch.";
+    return nullptr;
+  }
+  auto *kernel =
+      new (std::nothrow) ActivationOpenClKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "New kernel:" << opParameter->name_ << "is nullptr.";
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init activation kernel:" << opParameter->name_ << " failed!";
+    delete kernel;
+    return nullptr;
+  }
+  return kernel;
+}
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Activation, OpenClActivationFp32KernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h
similarity index 54%
rename from mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h
rename to mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h
index 935e1a1e90..910cf3f2a5 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h
@@ -14,24 +14,26 @@
  * limitations under the License.
 */
 
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ACTIVATION_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ACTIVATION_H_
 
 #include <vector>
-#include <string>
-#include "src/ir/tensor.h"
-#include "src/runtime/kernel/opencl/opencl_kernel.h"
-#include "schema/model_generated.h"
+
 #include "src/runtime/opencl/opencl_runtime.h"
+#include "src/runtime/kernel/opencl/opencl_kernel.h"
+#include "src/runtime/kernel/arm/nnacl/fp32/activation.h"
 
 namespace mindspore::kernel {
 
-class LeakyReluOpenCLKernel : public OpenCLKernel {
+class ActivationOpenClKernel : public OpenCLKernel {
  public:
-  explicit LeakyReluOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                 const std::vector<lite::tensor::Tensor *> &outputs)
-    : OpenCLKernel(parameter, inputs, outputs) {}
-  ~LeakyReluOpenCLKernel() override{};
+  explicit ActivationOpenClKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
+                                  const std::vector<lite::tensor::Tensor *> &outputs)
+    : OpenCLKernel(parameter, inputs, outputs) {
+    type_ = (reinterpret_cast<ActivationParameter *>(parameter))->type_;
+    alpha_ = (reinterpret_cast<ActivationParameter *>(parameter))->alpha_;
+  }
+  ~ActivationOpenClKernel() override{};
 
   int Init() override;
   int Run() override;
@@ -39,8 +41,9 @@ class LeakyReluOpenCLKernel : public OpenCLKernel {
 
  private:
   cl::Kernel kernel_;
+  int type_;
+  float alpha_;
 };
 
 }  // namespace mindspore::kernel
-
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ACTIVATION_H_
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
index 81b5da24e1..cca40e2db9 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
@@ -161,7 +161,8 @@ kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
                                                   const kernel::KernelKey &desc, const lite::Primitive *primitive) {
-  auto *kernel = new ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx);
+  auto *kernel =
+      new (std::nothrow) ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
index 5835cb972a..4036397549 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
@@ -174,7 +174,8 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-  auto *kernel = new Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
+  auto *kernel =
+      new (std::nothrow) Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
index 077d03c5db..76445d0744 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
@@ -193,7 +193,8 @@ kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-  auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), 
inputs, outputs); + auto *kernel = + new (std::nothrow) DepthwiseConv2dOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc deleted file mode 100644 index c804d0629c..0000000000 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "src/kernel_registry.h" -#include "include/errorcode.h" -#include "src/runtime/kernel/opencl/kernel/leaky_relu.h" -#include "src/runtime/opencl/opencl_runtime.h" -#include "src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl.inc" -#include "src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h" - -using mindspore::kernel::KERNEL_ARCH::kGPU; -using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_LeakyReLU; - -namespace mindspore::kernel { - -int LeakyReluOpenCLKernel::Init() { - if (in_tensors_[0]->shape().size() != 4) { - MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << in_tensors_[0]->shape().size(); - return RET_ERROR; - } - std::set build_options; - std::string source = leaky_relu_source_fp32; - std::string program_name = "LeakyRelu"; - std::string kernel_name = "LeakyRelu"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - ocl_runtime->LoadSource(program_name, source); - ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); - - MS_LOG(DEBUG) << kernel_name << " Init Done!"; - return RET_OK; -} - -int LeakyReluOpenCLKernel::Run() { - auto param = reinterpret_cast(op_parameter_); - MS_LOG(DEBUG) << " Running!"; - int N = in_tensors_[0]->shape()[0]; - int H = in_tensors_[0]->shape()[1]; - int W = in_tensors_[0]->shape()[2]; - int C = in_tensors_[0]->shape()[3]; - cl_int4 input_shape = {N, H, W, C}; - - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - int arg_idx = 0; - ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha); - - std::vector local = {1, 1}; - std::vector global = {static_cast(H), static_cast(W)}; - ocl_runtime->RunKernel(kernel_, global, local, nullptr); - return RET_OK; -} - -int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { - int H = in_tensors_[0]->shape()[1]; - int W = in_tensors_[0]->shape()[2]; - int C = in_tensors_[0]->shape()[3]; - -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else - size_t img_dtype = CL_FLOAT; -#endif - - img_size->clear(); - img_size->push_back(W * 
UP_DIV(C, C4NUM)); - img_size->push_back(H); - img_size->push_back(img_dtype); - return RET_OK; -} - -kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc, const lite::Primitive *primitive) { - if (inputs.size() == 0) { - MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); - return nullptr; - } - if (inputs[0]->shape()[0] > 1) { - MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch."; - return nullptr; - } - auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!"; - delete kernel; - return nullptr; - } - return kernel; -} - -REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator) -} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc index ca297b50cb..a6ae60ea72 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -160,7 +160,8 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vectortype_ == PrimitiveType_FullConnection) { hasBias = (reinterpret_cast(opParameter))->has_bias_; } - auto *kernel = new MatMulOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs, hasBias); + auto *kernel = + new (std::nothrow) MatMulOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs, hasBias); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc index 5d5c0cfb4d..184dd5cd9e 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc @@ -145,7 +145,7 @@ kernel::LiteKernel *OpenCLPooling2dKernelCreator(const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, const kernel::KernelKey &desc, const lite::Primitive *primitive) { - auto *kernel = new PoolingOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto *kernel = new (std::nothrow)PoolingOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); if (kernel == nullptr) { MS_LOG(ERROR) << "Create OpenCL Pooling kernel failed!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc index 0148ae35fb..09c8d8b31b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -86,7 +86,7 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, const kernel::KernelKey &desc, const lite::Primitive *primitive) { - auto *kernel = new SoftmaxOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto *kernel = new (std::nothrow)SoftmaxOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; return nullptr; diff --git 
a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc index 4a5e59be31..b0052e0f5b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc @@ -109,7 +109,7 @@ kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, const kernel::KernelKey &desc, const lite::Primitive *primitive) { - auto *kernel = new TransposeOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto *kernel = new (std::nothrow)TransposeOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; return nullptr; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index f37a10d18f..dbfbb37300 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -142,7 +142,7 @@ if (SUPPORT_GPU) ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc - ${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc ) @@ -320,14 +320,14 @@ if (SUPPORT_GPU) ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc - ${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/activation_tests.cc ) endif() if (ENABLE_FP16) set(TEST_SRC ${TEST_SRC} - ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc) +) endif () diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc new file mode 100644 index 0000000000..9f79c0a877 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <iostream>
+
+#include "utils/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/common/file_utils.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_allocator.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h"
+
+using mindspore::kernel::LiteKernel;
+using mindspore::kernel::SubGraphOpenCLKernel;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::ActivationType_LEAKY_RELU;
+using mindspore::schema::ActivationType_RELU;
+using mindspore::schema::ActivationType_RELU6;
+using mindspore::schema::ActivationType_SIGMOID;
+using mindspore::schema::PrimitiveType_Activation;
+
+namespace mindspore {
+class TestActivationOpenCL : public mindspore::CommonTest {};
+
+void LoadActivationData(void *dst, size_t dst_size, const std::string &file_path) {
+  if (file_path.empty()) {
+    memset(dst, 0x00, dst_size);
+  } else {
+    auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size));
+    memcpy(dst, src_data, dst_size);
+  }
+}
+
+void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
+  auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
+  size_t output_size = output_tensor->Size();
+  auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
+  constexpr float atol = 0.0002;
+  for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
+    if (std::fabs(output_data[i] - expect_data[i]) > atol) {
+      printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
+      printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
+      printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
+      return;
+    }
+  }
+  printf("compare success!\n");
+  printf("compare success!\n");
+  printf("compare success!\n\n\n");
+}
+
+void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
+  auto input_data = reinterpret_cast<float *>(in_data->Data());
+  for (int i = 0; i < in_data->ElementsNum(); ++i) {
+    printf("%f ", input_data[i]);
+  }
+  printf("\n");
+  MS_LOG(INFO) << "Print tensor done";
+}
+
+kernel::ActivationOpenClKernel *create_kernel(lite::opencl::OpenCLAllocator *allocator,
+                                              const std::vector<lite::tensor::Tensor *> &inputs,
+                                              const std::vector<lite::tensor::Tensor *> &outputs, std::string test_name,
+                                              int type, std::string in_file, float alpha = 0.2) {
+  auto *param = new (std::nothrow) ActivationParameter();
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "New ActivationParameter fail.";
+    return nullptr;
+  }
+  memcpy(param->op_parameter_.name_, test_name.c_str(), test_name.size());
+  param->alpha_ = alpha;
+  param->type_ = type;
+  auto *kernel =
+      new (std::nothrow) kernel::ActivationOpenClKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "Kernel:" << test_name << " create fail.";
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init " << test_name << " fail.";
+    return nullptr;
+  }
+  MS_LOG(INFO) << "Initialize input data";
+  LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
+  MS_LOG(INFO) << "==================input data================";
+  printf_tensor(inputs[0]);
+  return kernel;
+}
+
+int RunSubGraphOpenCLKernel(const std::vector<lite::tensor::Tensor *> &inputs,
+                            const 
std::vector<lite::tensor::Tensor *> &outputs,
+                            kernel::ActivationOpenClKernel *kernel) {
+  MS_LOG(INFO) << "Create kernel SubGraphOpenCLKernel.";
+  std::vector<kernel::LiteKernel *> kernels{kernel};
+  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  if (sub_graph == nullptr) {
+    MS_LOG(ERROR) << "Kernel SubGraphOpenCLKernel create fail.";
+    return RET_ERROR;
+  }
+  MS_LOG(INFO) << "Initialize sub_graph.";
+  auto ret = sub_graph->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init sub_graph error.";
+    return RET_ERROR;
+  }
+  MS_LOG(INFO) << "Run SubGraphOpenCLKernel.";
+  ret = sub_graph->Run();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) {
+  MS_LOG(INFO) << "Begin test:";
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  ocl_runtime->Init();
+  auto allocator = ocl_runtime->GetAllocator();
+
+  MS_LOG(INFO) << "Init tensors.";
+  std::vector<int> input_shape = {1, 4, 3, 8};
+
+  auto data_type = kNumberTypeFloat32;
+  auto tensor_type = schema::NodeType_ValueNode;
+  auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
+  auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
+  std::vector<lite::tensor::Tensor *> inputs{input_tensor};
+  std::vector<lite::tensor::Tensor *> outputs{output_tensor};
+  // framework to do: allocate memory by hand
+  inputs[0]->MallocData(allocator);
+
+  std::map<std::string, int> Test_Activation_Type;
+  std::map<std::string, std::string> Test_Res_File;
+  Test_Activation_Type["Relu"] = ActivationType_RELU;
+  Test_Activation_Type["Leaky_Relu"] = ActivationType_LEAKY_RELU;
+  Test_Activation_Type["Relu6"] = ActivationType_RELU6;
+  Test_Activation_Type["Sigmoid"] = ActivationType_SIGMOID;
+  Test_Res_File["Leaky_Relu"] = "/data/local/tmp/leaky_relu.bin";
+  Test_Res_File["Relu"] = "/data/local/tmp/relu.bin";
+  Test_Res_File["Relu6"] = "/data/local/tmp/relu6.bin";
+  Test_Res_File["Sigmoid"] = "/data/local/tmp/sigmoid.bin";
+  std::string in_file = "/data/local/tmp/in_data.bin";
+
+  std::map<std::string, int>::iterator it = Test_Activation_Type.begin();
+  while (it != Test_Activation_Type.end()) {
+    auto kernel = create_kernel(allocator, inputs, outputs, it->first, it->second, in_file, 0.3);
+    if (kernel == nullptr) {
+      MS_LOG(ERROR) << "Create kernel:" << it->first << " error.";
+      return;
+    }
+
+    auto ret = RunSubGraphOpenCLKernel(inputs, outputs, kernel);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << it->first << " RunSubGraphOpenCLKernel error.";
+      return;
+    }
+    MS_LOG(INFO) << "==================output data================";
+    printf_tensor(outputs[0]);
+    CompareRes(output_tensor, Test_Res_File[it->first]);
+    delete kernel;
+    it++;
+  }
+
+  delete input_tensor;
+  delete output_tensor;
+  return;
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc
deleted file mode 100644
index 8123138272..0000000000
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc
+++ /dev/null
@@ -1,110 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "utils/log_adapter.h" -#include "common/common_test.h" -#include "mindspore/lite/src/common/file_utils.h" -#include "src/runtime/kernel/arm/nnacl/pack.h" -#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" -#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" -#include "mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h" -#include "mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h" - -using mindspore::kernel::LeakyReluOpenCLKernel; -using mindspore::kernel::LiteKernel; -using mindspore::kernel::SubGraphOpenCLKernel; - -namespace mindspore { -class TestLeakyReluOpenCL : public mindspore::CommonTest {}; - -void LoadDataLeakyRelu(void *dst, size_t dst_size, const std::string &file_path) { - if (file_path.empty()) { - memset(dst, 0x00, dst_size); - } else { - auto src_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); - memcpy(dst, src_data, dst_size); - } -} - -void CompareOutLeakyRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { - auto *output_data = reinterpret_cast(output_tensor->Data()); - size_t output_size = output_tensor->Size(); - auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); - constexpr float atol = 0.0002; - for (int i = 0; i < output_tensor->ElementsNum(); ++i) { - if (std::fabs(output_data[i] - expect_data[i]) > atol) { - printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); - printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); - printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); - return; - } - } - printf("compare success!\n"); - printf("compare success!\n"); - printf("compare success!\n\n\n"); -} - -void printf_tensor(mindspore::lite::tensor::Tensor *in_data) { - auto input_data = reinterpret_cast(in_data->Data()); - for (int i = 0; i < in_data->ElementsNum(); ++i) { - printf("%f ", input_data[i]); - } - printf("\n"); - MS_LOG(INFO) << "Print tensor done"; -} - -TEST_F(TestLeakyReluOpenCL, LeakyReluFp32_dim4) { - std::string in_file = "/data/local/tmp/in_data.bin"; - std::string standard_answer_file = "/data/local/tmp/out_data.bin"; - MS_LOG(INFO) << "Begin test:"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - ocl_runtime->Init(); - auto allocator = ocl_runtime->GetAllocator(); - - MS_LOG(INFO) << "Init tensors."; - std::vector input_shape = {1, 4, 3, 8}; - - auto data_type = kNumberTypeFloat32; - auto tensor_type = schema::NodeType_ValueNode; - auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); - auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); - std::vector inputs{input_tensor}; - std::vector outputs{output_tensor}; - - // freamework to do!!! 
allocate memory by hand - inputs[0]->MallocData(allocator); - - auto param = new LeakyReluParameter(); - param->alpha = 0.3; - auto *leakyrelu_kernel = new kernel::LeakyReluOpenCLKernel(reinterpret_cast(param), inputs, outputs); - leakyrelu_kernel->Init(); - - MS_LOG(INFO) << "initialize sub_graph"; - std::vector kernels{leakyrelu_kernel}; - auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); - sub_graph->Init(); - - MS_LOG(INFO) << "initialize input data"; - LoadDataLeakyRelu(input_tensor->Data(), input_tensor->Size(), in_file); - MS_LOG(INFO) << "==================input data================"; - printf_tensor(inputs[0]); - sub_graph->Run(); - - MS_LOG(INFO) << "==================output data================"; - printf_tensor(outputs[0]); - CompareOutLeakyRelu(output_tensor, standard_answer_file); -} -} // namespace mindspore
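
Note (illustrative, not part of the patch): the expected-output files the new test reads (/data/local/tmp/relu.bin, relu6.bin, leaky_relu.bin, sigmoid.bin) are flat float32 dumps of the same element-wise functions that the kernels in activation.cl compute. A minimal host-side sketch that could generate or cross-check such files is given below; the Act enum and Reference function are hypothetical names, and alpha = 0.3 simply mirrors the value passed by the test.

// Host-side reference for the four activations in activation.cl.
// Sketch only, not part of the patch; assumes the flat float32 element
// order of in_data.bin and the alpha used by activation_tests.cc.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

enum class Act { kRelu, kRelu6, kLeakyRelu, kSigmoid };

std::vector<float> Reference(const std::vector<float> &in, Act type, float alpha = 0.3f) {
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    float x = in[i];
    switch (type) {
      case Act::kRelu:      out[i] = x >= 0.f ? x : 0.f; break;                 // Relu kernel
      case Act::kRelu6:     out[i] = x >= 0.f ? std::min(x, 6.f) : 0.f; break;  // Relu6 kernel
      case Act::kLeakyRelu: out[i] = x >= 0.f ? x : x * alpha; break;           // ReluScalar kernel
      case Act::kSigmoid:   out[i] = 1.f / (1.f + std::exp(-x)); break;         // Sigmoid kernel
    }
  }
  return out;
}

int main() {
  std::vector<float> in = {-3.f, -0.5f, 0.f, 2.f, 7.f};
  for (float v : Reference(in, Act::kLeakyRelu)) printf("%.3f ", v);  // -0.900 -0.150 0.000 2.000 7.000
  printf("\n");
  return 0;
}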