diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl new file mode 100644 index 0000000000..e2e1f5b0bb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl @@ -0,0 +1,31 @@ +#pragma OPENCL EXTENSION cl_arm_printf : enable + +#define SLICES 4 +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define FLT4 float4 +#define READ_FLT4 read_imagef +#define WRITE_FLT4 write_imagef +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, + __global float *alpha, const int dim) { + int C = input_shape.w; // channel size + + int Y = get_global_id(0); // height id + int X = get_global_id(1); // weight id + for (int num = 0; num < UP_DIV(C, SLICES); ++num) { + FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC + FLT4 tmp; + int index = 0; + if (dim == 2) { + index = X * 4; + } else { + index = num * 4; + } + tmp.x = in_c4.x + alpha[index]; + tmp.y = in_c4.y + alpha[index + 1]; + tmp.z = in_c4.z + alpha[index + 2]; + tmp.w = in_c4.w + alpha[index + 3]; + WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc new file mode 100644 index 0000000000..151412194b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc @@ -0,0 +1,167 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/opencl/kernel/biasadd.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/cl/biasadd.cl.inc" + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BiasAdd; + +namespace mindspore::kernel { + +void BiasAddOpenCLKernel::InitBuffer() { + int C = in_tensors_[1]->shape()[0]; + int div_ci = UP_DIV(C, C4NUM); + auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); + BiasAdd_ = reinterpret_cast(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t))); + BiasAdd_ = reinterpret_cast(allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true)); + memset(BiasAdd_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t)); + auto origin_weight = reinterpret_cast(in_tensors_[1]->Data()); + for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) { + BiasAdd_[i] = origin_weight[i]; + } + allocator->UnmapBuffer(BiasAdd_); +} + +int BiasAddOpenCLKernel::Init() { + in_size_ = in_tensors_[0]->shape().size(); + out_size_ = out_tensors_[0]->shape().size(); + if (in_size_ != 4 && in_size_ != 2) { + MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_; + return RET_ERROR; + } + int C = in_tensors_[0]->shape()[3]; + int Bias_Size = in_tensors_[1]->shape()[0]; + if (UP_DIV(Bias_Size, C4NUM) != UP_DIV(C, C4NUM)) { + MS_LOG(ERROR) << "BiasAdd weight channel size:" << Bias_Size << " must be equal with in_teneors channel size:" << C; + return RET_ERROR; + } + InitBuffer(); + std::set build_options; + std::string source = biasadd_source; + std::string program_name = "BiasAdd"; + std::string kernel_name = "BiasAdd"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); + + in_ori_format_ = in_tensors_[0]->GetFormat(); + out_ori_format_ = out_tensors_[0]->GetFormat(); + std::map format{{4, schema::Format_NHWC4}, {2, schema::Format_NC4}}; + if (format.count(out_size_) == 0) { + MS_LOG(ERROR) << "Not found output tensor format"; + return RET_ERROR; + } + in_tensors_[0]->SetFormat(format[in_size_]); + out_tensors_[0]->SetFormat(format[out_size_]); + if (in_size_ == 2) { + in_ori_format_ = format[in_size_]; + out_ori_format_ = format[out_size_]; + } + MS_LOG(DEBUG) << program_name << " Init Done!"; + return RET_OK; +} + +int BiasAddOpenCLKernel::Run() { + cl_int4 input_shape = GetImg2dShape(); + MS_LOG(DEBUG) << op_parameter_->name_ << " Running!"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, BiasAdd_); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_size_); + std::vector local = {1, 1}; + std::vector global = {static_cast(input_shape.s[1]), static_cast(input_shape.s[2])}; + auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error."; + return RET_ERROR; + } + return RET_OK; +} + +cl_int4 BiasAddOpenCLKernel::GetImg2dShape() { + cl_int4 img2d_shape = {0, 0, 0, 0}; + for (int i = 0; i < in_size_; ++i) { + img2d_shape.s[i + 4 - in_size_] = in_tensors_[0]->shape()[i]; + } + if (in_size_ == 2) { + img2d_shape.s[1] = img2d_shape.s[2]; + img2d_shape.s[2] = UP_DIV(img2d_shape.s[3], C4NUM); + img2d_shape.s[3] = C4NUM; + } + return img2d_shape; +} + +int BiasAddOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { + cl_int4 img_shape = GetImg2dShape(); +#ifdef ENABLE_FP16 + size_t img_dtype = CL_HALF_FLOAT; +#else + size_t img_dtype = CL_FLOAT; +#endif + + img_size->clear(); + img_size->push_back(img_shape.s[2] * UP_DIV(img_shape.s[3], C4NUM)); + img_size->push_back(img_shape.s[1]); + img_size->push_back(img_dtype); + return RET_OK; +} + +kernel::LiteKernel *OpenCLBiasAddKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) { + if (inputs.size() == 0) { + MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); + return nullptr; + } + if (inputs[0]->shape()[0] > 1) { + MS_LOG(ERROR) << "Input data size unsupported multi-batch."; + return nullptr; + } + auto *kernel = new (std::nothrow) BiasAddOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Kernel " << opParameter->name_ << "is nullptr."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init BiasAdd kernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h new file mode 100644 index 0000000000..56535afddb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ + +#include +#include + +#include "src/ir/tensor.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "schema/model_generated.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::kernel { + +class BiasAddOpenCLKernel : public OpenCLKernel { + public: + explicit BiasAddOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) {} + ~BiasAddOpenCLKernel() override{}; + + int Init() override; + int Run() override; + int GetImageSize(size_t idx, std::vector *img_size) override; + void InitBuffer(); + cl_int4 GetImg2dShape(); + + private: + cl::Kernel kernel_; + FLOAT_t *BiasAdd_; + int in_size_; + int out_size_; +}; + +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 7b1e91a602..6767810eb8 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -158,6 +158,7 @@ if (SUPPORT_GPU) ${LITE_DIR}/src/runtime/kernel/opencl/kernel/caffe_prelu.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/prelu.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/biasadd.cc ) endif() ### minddata lite @@ -338,6 +339,7 @@ if (SUPPORT_GPU) ${TEST_DIR}/ut/src/runtime/kernel/opencl/caffe_prelu_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/prelu_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/reshape_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/biasadd_tests.cc ) endif() diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc new file mode 100644 index 0000000000..cd939cbc47 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc @@ -0,0 +1,202 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h" + +using mindspore::kernel::BiasAddOpenCLKernel; +using mindspore::kernel::LiteKernel; +using mindspore::kernel::SubGraphOpenCLKernel; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore { +class TestBiasAddOpenCL : public mindspore::CommonTest {}; + +void LoadDataBiasAdd(void *dst, size_t dst_size, const std::string &file_path) { + if (file_path.empty()) { + memset(dst, 0x00, dst_size); + } else { + auto src_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); + memcpy(dst, src_data, dst_size); + } +} + +void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { + auto *output_data = reinterpret_cast(output_tensor->Data()); + size_t output_size = output_tensor->ElementsNum(); + auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); + constexpr float atol = 0.0002; + for (int i = 0; i < output_tensor->ElementsNum(); ++i) { + if (std::fabs(output_data[i] - expect_data[i]) > atol) { + printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); + return; + } + } + printf("compare success!\n"); + printf("compare success!\n"); + printf("compare success!\n\n\n"); +} + +void printf_tensor_BiasAdd(mindspore::lite::tensor::Tensor *in_data, int size) { + auto input_data = reinterpret_cast(in_data->Data()); + for (int i = 0; i < size; ++i) { + printf("%f ", input_data[i]); + } + printf("\n"); + MS_LOG(INFO) << "Print tensor done"; +} + +void printf_float_BiasAdd(float *data, int num = 0) { + float *temp = data; + for (int i = 0; i < num; ++i) { + std::cout << *temp << " "; + temp++; + } + std::cout << std::endl; +} + +TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { + std::string in_file = "/data/local/tmp/in_data.bin"; + std::string weight_file = "/data/local/tmp/weight_data.bin"; + std::string standard_answer_file = "/data/local/tmp/biasadd.bin"; + MS_LOG(INFO) << "BiasAdd Begin test:"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << "BiasAdd init tensors."; + + std::vector input_shape = {1, 9}; + std::vector output_shape = {1, 9}; + auto data_type = kNumberTypeFloat32; + auto tensor_type = schema::NodeType_ValueNode; + auto *input_tensor = + new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type); + if (input_tensor == nullptr) { + MS_LOG(ERROR) << "new input tensor error!"; + return; + } + auto *output_tensor = + new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NC, tensor_type); + if (output_tensor == nullptr) { + MS_LOG(ERROR) << "new output tensor error!"; + delete input_tensor; + return; + } + auto *weight_tensor = new (std::nothrow) + lite::tensor::Tensor(data_type, std::vector{input_shape[1]}, schema::Format_NHWC, tensor_type); + if (weight_tensor == nullptr) { + MS_LOG(ERROR) << "new weight tensor error!"; + delete output_tensor; + delete input_tensor; + return; + } + std::vector inputs{input_tensor, weight_tensor}; + std::vector outputs{output_tensor}; + inputs[0]->MallocData(allocator); + inputs[1]->MallocData(allocator); + LoadDataBiasAdd(input_tensor->Data(), input_tensor->Size(), in_file); + MS_LOG(INFO) << "BiasAdd==================input data================"; + printf_tensor_BiasAdd(inputs[0], input_tensor->ElementsNum()); + LoadDataBiasAdd(weight_tensor->Data(), weight_tensor->Size(), weight_file); + MS_LOG(INFO) << "BiasAdd==================weight data================"; + printf_tensor_BiasAdd(inputs[1], weight_tensor->ElementsNum()); + + auto *param = new (std::nothrow) OpParameter(); + if (param == nullptr) { + delete input_tensor; + delete output_tensor; + delete weight_tensor; + MS_LOG(ERROR) << "new OpParameter error!"; + return; + } + auto *biasadd_kernel = + new (std::nothrow) kernel::BiasAddOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (biasadd_kernel == nullptr) { + MS_LOG(ERROR) << "Create biasadd kernel error."; + delete input_tensor; + delete output_tensor; + delete weight_tensor; + delete param; + return; + } + + auto ret = biasadd_kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "biasadd kernel init error."; + delete input_tensor; + delete output_tensor; + delete weight_tensor; + delete param; + delete biasadd_kernel; + return; + } + + MS_LOG(INFO) << "initialize sub_graph"; + std::vector kernels{biasadd_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({input_tensor}, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(ERROR) << "Create sub_graph kernel error."; + delete input_tensor; + delete output_tensor; + delete weight_tensor; + delete param; + delete biasadd_kernel; + return; + } + ret = sub_graph->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "sub_graph init error."; + delete input_tensor; + delete output_tensor; + delete weight_tensor; + delete sub_graph; + delete param; + delete biasadd_kernel; + return; + } + MS_LOG(INFO) << "Sub graph begin running!"; + ret = sub_graph->Run(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "sub_graph run error."; + delete input_tensor; + delete output_tensor; + delete weight_tensor; + delete sub_graph; + delete param; + delete biasadd_kernel; + return; + } + + MS_LOG(INFO) << "BiasAdd==================output data================"; + printf_tensor_BiasAdd(outputs[0], output_tensor->ElementsNum()); + CompareOutBiasAdd(output_tensor, standard_answer_file); + delete input_tensor; + delete weight_tensor; + delete output_tensor; + delete sub_graph; + delete param; + delete biasadd_kernel; +} +} // namespace mindspore