From 2eb55946f672625da4c79de18d89364499a00247 Mon Sep 17 00:00:00 2001 From: liuzhongkai Date: Tue, 11 Aug 2020 23:44:10 -0700 Subject: [PATCH] add opencl leaky relu kernel --- .../kernel/arm/nnacl/leaky_relu_parameter.h | 27 ++++ .../kernel/opencl/cl/fp32/leaky_relu.cl | 1 - .../kernel/opencl/kernel/conv2d_transpose.cc | 4 + .../kernel/opencl/kernel/depthwise_conv2d.cc | 4 + .../kernel/opencl/kernel/leaky_relu.cc | 148 +++++++++--------- .../runtime/kernel/opencl/kernel/leaky_relu.h | 17 +- .../runtime/kernel/opencl/kernel/matmul.cc | 4 + .../runtime/kernel/opencl/kernel/softmax.cc | 4 + .../runtime/kernel/opencl/kernel/transpose.cc | 4 + mindspore/lite/test/CMakeLists.txt | 4 +- .../runtime/kernel/opencl/leakyrelu_tests.cc | 11 +- 11 files changed, 140 insertions(+), 88 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h new file mode 100644 index 0000000000..e1b9eb0fd5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LEAKYRELU_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LEAKYRELU_H_ + +#include "nnacl/op_base.h" + +typedef struct LeakyReluParameter { + OpParameter op_parameter_; + float alpha; +} LeakyReluParameter; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LEAKYRELU_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl index 0330b8590a..388f4c983a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl @@ -16,7 +16,6 @@ __kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t outp int Y = get_global_id(0); // height id int X = get_global_id(1); // weight id - for (int num = 0; num < UP_DIV(C, SLICES); ++num) { FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC FLT4 tmp; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index 3a67a66ac5..5835cb972a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -175,6 +175,10 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; + return nullptr; + } auto ret = kernel->Init(); if (0 != ret) { // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str() diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc index 401a95afe8..077d03c5db 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc @@ -194,6 +194,10 @@ kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; + return nullptr; + } auto ret = kernel->Init(); if (0 != ret) { MS_LOG(ERROR) << "Init DepthwiseConv2dOpenCLKernel failed!"; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc index d0c630e973..c804d0629c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc @@ -18,97 +18,105 @@ #include #include +#include #include "src/kernel_registry.h" #include "include/errorcode.h" #include "src/runtime/kernel/opencl/kernel/leaky_relu.h" #include "src/runtime/opencl/opencl_runtime.h" #include "src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl.inc" +#include "src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_LeakyReLU; namespace mindspore::kernel { - int LeakyReluOpenCLKernel::Init() { - if (inputs_[0]->shape().size() != 4) { - MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << inputs_[0]->shape().size(); - } - std::set build_options; - std::string source = leaky_relu_source_fp32; - std::string program_name = "LeakyRelu"; - std::string kernel_name = "LeakyRelu"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - ocl_runtime->LoadSource(program_name, source); - ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); - - MS_LOG(DEBUG) << kernel_name << " Init Done!"; - return RET_OK; +int LeakyReluOpenCLKernel::Init() { + if (in_tensors_[0]->shape().size() != 4) { + MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << in_tensors_[0]->shape().size(); + return RET_ERROR; } - - - int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { - int H = inputs_[0]->shape()[1]; - int W = inputs_[0]->shape()[2]; - int C = inputs_[0]->shape()[3]; + std::set build_options; + std::string source = leaky_relu_source_fp32; + std::string program_name = "LeakyRelu"; + std::string kernel_name = "LeakyRelu"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); + + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return RET_OK; +} + +int LeakyReluOpenCLKernel::Run() { + auto param = reinterpret_cast(op_parameter_); + MS_LOG(DEBUG) << " Running!"; + int N = in_tensors_[0]->shape()[0]; + int H = in_tensors_[0]->shape()[1]; + int W = in_tensors_[0]->shape()[2]; + int C = in_tensors_[0]->shape()[3]; + cl_int4 input_shape = {N, H, W, C}; + + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha); + + std::vector local = {1, 1}; + std::vector global = {static_cast(H), static_cast(W)}; + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return RET_OK; +} + +int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { + int H = in_tensors_[0]->shape()[1]; + int W = in_tensors_[0]->shape()[2]; + int C = in_tensors_[0]->shape()[3]; #ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; + size_t img_dtype = CL_HALF_FLOAT; #else - size_t img_dtype = CL_FLOAT; + size_t img_dtype = CL_FLOAT; #endif - img_size->clear(); - img_size->push_back(W * UP_DIV(C, C4NUM)); - img_size->push_back(H); - img_size->push_back(img_dtype); - return RET_OK; - } - - int LeakyReluOpenCLKernel::Run() { - auto param = reinterpret_cast(this->opParameter); - MS_LOG(DEBUG) << this->Name() << " Running!"; - int N = inputs_[0]->shape()[0]; - int H = inputs_[0]->shape()[1]; - int W = inputs_[0]->shape()[2]; - int C = inputs_[0]->shape()[3]; - cl_int4 input_shape = {N, H, W, C}; - - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - int arg_idx = 0; - ocl_runtime->SetKernelArg(kernel_, arg_idx++, inputs_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, outputs_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha); - - std::vector local = {1, 1}; - std::vector global = {static_cast(H), static_cast(W)}; - ocl_runtime->RunKernel(kernel_, global, local, nullptr); - return 0; + img_size->clear(); + img_size->push_back(W * UP_DIV(C, C4NUM)); + img_size->push_back(H); + img_size->push_back(img_dtype); + return RET_OK; +} + +kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::Primitive *primitive) { + if (inputs.size() == 0) { + MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); + return nullptr; } - - kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc, const lite::Primitive *primitive) { - auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); - if (inputs.size() == 0) { - MS_LOG(ERROR) << "Input data size must must be greater than 0, but your size is " << inputs.size(); - } - if (inputs[0]->shape()[0] > 1) { - MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch."; - } - auto ret = kernel->Init(); - if (0 != ret) { - MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!"; - delete kernel; - return nullptr; - } - return kernel; + if (inputs[0]->shape()[0] > 1) { + MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch."; + return nullptr; + } + auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; + return nullptr; } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} - REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator) } // namespace mindspore::kernel - diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h index 8ad56bba11..935e1a1e90 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h @@ -14,18 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_ -#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_ #include - -#include "src/runtime/opencl/opencl_runtime.h" +#include +#include "src/ir/tensor.h" #include "src/runtime/kernel/opencl/opencl_kernel.h" - -struct LeakyReluParameter { - OpParameter op_parameter_; - cl_float alpha; -}; +#include "schema/model_generated.h" +#include "src/runtime/opencl/opencl_runtime.h" namespace mindspore::kernel { @@ -46,4 +43,4 @@ class LeakyReluOpenCLKernel : public OpenCLKernel { } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc index 674e3c0b86..ca297b50cb 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -161,6 +161,10 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector(opParameter))->has_bias_; } auto *kernel = new MatMulOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs, hasBias); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; + return nullptr; + } auto ret = kernel->Init(); if (0 != ret) { // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str() diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc index f8389b3449..0148ae35fb 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -87,6 +87,10 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; + return nullptr; + } if (inputs[0]->shape()[0] > 1) { MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported multi-batch."; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc index 246597e917..4a5e59be31 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc @@ -110,6 +110,10 @@ kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; + return nullptr; + } auto ret = kernel->Init(); if (0 != ret) { delete kernel; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 548dd99f9f..f37a10d18f 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -142,7 +142,7 @@ if (SUPPORT_GPU) ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc - # ${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc ) @@ -320,7 +320,7 @@ if (SUPPORT_GPU) ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc - # ${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc ) endif() diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc index 8fa9a75d51..8123138272 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc @@ -21,12 +21,14 @@ #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" #include "mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h" + +using mindspore::kernel::LeakyReluOpenCLKernel; +using mindspore::kernel::LiteKernel; +using mindspore::kernel::SubGraphOpenCLKernel; namespace mindspore { -class TestLeakyReluOpenCL : public mindspore::Common { - public: - TestLeakyReluOpenCL() {} -}; +class TestLeakyReluOpenCL : public mindspore::CommonTest {}; void LoadDataLeakyRelu(void *dst, size_t dst_size, const std::string &file_path) { if (file_path.empty()) { @@ -99,7 +101,6 @@ TEST_F(TestLeakyReluOpenCL, LeakyReluFp32_dim4) { LoadDataLeakyRelu(input_tensor->Data(), input_tensor->Size(), in_file); MS_LOG(INFO) << "==================input data================"; printf_tensor(inputs[0]); - sub_graph->Run(); MS_LOG(INFO) << "==================output data================";