parent
2eb55946f6
commit
56d32b4e77
@ -0,0 +1,70 @@
|
||||
#pragma OPENCL EXTENSION cl_arm_printf : enable
|
||||
|
||||
#define SLICES 4
|
||||
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
|
||||
#define FLT4 float4
|
||||
#define MIN(X, Y) (X < Y ? X : Y)
|
||||
#define READ_FLT4 read_imagef
|
||||
#define WRITE_FLT4 write_imagef
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void ReluScalar(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
|
||||
const float alpha) {
|
||||
int C = input_shape.w; // channel size
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha;
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha;
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha;
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha;
|
||||
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void Relu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) {
|
||||
int C = input_shape.w; // channel size
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : 0;
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : 0;
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : 0;
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : 0;
|
||||
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) {
|
||||
int C = input_shape.w; // channel size
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = in_c4.x >= 0 ? MIN(in_c4.x, 6) : 0;
|
||||
tmp.y = in_c4.y >= 0 ? MIN(in_c4.y, 6) : 0;
|
||||
tmp.z = in_c4.z >= 0 ? MIN(in_c4.z, 6) : 0;
|
||||
tmp.w = in_c4.w >= 0 ? MIN(in_c4.w, 6) : 0;
|
||||
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) {
|
||||
int C = input_shape.w; // channel size
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = 1 / (1 + exp(-in_c4.x));
|
||||
tmp.y = 1 / (1 + exp(-in_c4.y));
|
||||
tmp.z = 1 / (1 + exp(-in_c4.z));
|
||||
tmp.w = 1 / (1 + exp(-in_c4.w));
|
||||
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
@ -1,28 +0,0 @@
|
||||
#pragma OPENCL EXTENSION cl_arm_printf : enable
|
||||
|
||||
#define SLICES 4
|
||||
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
|
||||
#define FLT4 float4
|
||||
#define READ_FLT4 read_imagef
|
||||
#define WRITE_FLT4 write_imagef
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
|
||||
const float alpha) {
|
||||
// int B = input_shape.x; // size
|
||||
// int H = input_shape.y; //
|
||||
// int W = input_shape.z;
|
||||
int C = input_shape.w;
|
||||
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha;
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha;
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha;
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha;
|
||||
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
@ -0,0 +1,146 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
#include "src/runtime/kernel/opencl/kernel/activation.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/ops/ops.h"
|
||||
#include "src/runtime/kernel/opencl/cl/fp32/activation.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::ActivationType_LEAKY_RELU;
|
||||
using mindspore::schema::ActivationType_RELU;
|
||||
using mindspore::schema::ActivationType_RELU6;
|
||||
using mindspore::schema::ActivationType_SIGMOID;
|
||||
using mindspore::schema::PrimitiveType_Activation;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int ActivationOpenClKernel::Init() {
|
||||
const int max_shape_dim = 4;
|
||||
if (in_tensors_[0]->shape().size() != max_shape_dim) {
|
||||
MS_LOG(ERROR) << "Activate fun only support dim=4, but your dim=" << in_tensors_[0]->shape().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::string program_name = "";
|
||||
std::string kernel_name = "";
|
||||
std::string source = activation_source_fp32;
|
||||
if (type_ == ActivationType_RELU) {
|
||||
program_name = "RELU";
|
||||
kernel_name = "Relu";
|
||||
} else if (type_ == ActivationType_RELU6) {
|
||||
program_name = "RELU6";
|
||||
kernel_name = "Relu6";
|
||||
} else if (type_ == ActivationType_LEAKY_RELU) {
|
||||
program_name = "LEAKY_RELU";
|
||||
kernel_name = "ReluScalar";
|
||||
} else if (type_ == ActivationType_SIGMOID) {
|
||||
program_name = "SIGMOID";
|
||||
kernel_name = "Sigmoid";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Activation type error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::set<std::string> build_options;
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
MS_LOG(DEBUG) << op_parameter_->name_ << " init Done!";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ActivationOpenClKernel::Run() {
|
||||
MS_LOG(DEBUG) << op_parameter_->name_ << " begin running!";
|
||||
int N = in_tensors_[0]->shape()[0];
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
cl_int4 input_shape = {N, H, W, C};
|
||||
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
int arg_idx = 0;
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
|
||||
if (type_ == ActivationType_LEAKY_RELU) {
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_);
|
||||
}
|
||||
std::vector<size_t> local = {1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
|
||||
std::cout << type_ << " " << std::endl;
|
||||
auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run kernel:" << op_parameter_->name_ << " fail.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ActivationOpenClKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
|
||||
#ifdef ENABLE_FP16
|
||||
size_t img_dtype = CL_HALF_FLOAT;
|
||||
#else
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
#endif
|
||||
|
||||
img_size->clear();
|
||||
img_size->push_back(W * UP_DIV(C, C4NUM));
|
||||
img_size->push_back(H);
|
||||
img_size->push_back(img_dtype);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *OpenClActivationFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
|
||||
return nullptr;
|
||||
}
|
||||
if (inputs[0]->shape()[0] > 1) {
|
||||
MS_LOG(ERROR) << "Activation kernel:" << opParameter->name_ << " failed: Unsupported multi-batch.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel =
|
||||
new (std::nothrow) ActivationOpenClKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "New kernel:" << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init activation kernel:" << opParameter->name_ << " failed!";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Activation, OpenClActivationFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -1,122 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
*
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/leaky_relu.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl.inc"
|
||||
#include "src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_LeakyReLU;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int LeakyReluOpenCLKernel::Init() {
|
||||
if (in_tensors_[0]->shape().size() != 4) {
|
||||
MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << in_tensors_[0]->shape().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::set<std::string> build_options;
|
||||
std::string source = leaky_relu_source_fp32;
|
||||
std::string program_name = "LeakyRelu";
|
||||
std::string kernel_name = "LeakyRelu";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
|
||||
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LeakyReluOpenCLKernel::Run() {
|
||||
auto param = reinterpret_cast<LeakyReluParameter *>(op_parameter_);
|
||||
MS_LOG(DEBUG) << " Running!";
|
||||
int N = in_tensors_[0]->shape()[0];
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
cl_int4 input_shape = {N, H, W, C};
|
||||
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
int arg_idx = 0;
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha);
|
||||
|
||||
std::vector<size_t> local = {1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
|
||||
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
int H = in_tensors_[0]->shape()[1];
|
||||
int W = in_tensors_[0]->shape()[2];
|
||||
int C = in_tensors_[0]->shape()[3];
|
||||
|
||||
#ifdef ENABLE_FP16
|
||||
size_t img_dtype = CL_HALF_FLOAT;
|
||||
#else
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
#endif
|
||||
|
||||
img_size->clear();
|
||||
img_size->push_back(W * UP_DIV(C, C4NUM));
|
||||
img_size->push_back(H);
|
||||
img_size->push_back(img_dtype);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
|
||||
return nullptr;
|
||||
}
|
||||
if (inputs[0]->shape()[0] > 1) {
|
||||
MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,185 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_allocator.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h"
|
||||
|
||||
using mindspore::kernel::LiteKernel;
|
||||
using mindspore::kernel::SubGraphOpenCLKernel;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::ActivationType_LEAKY_RELU;
|
||||
using mindspore::schema::ActivationType_RELU;
|
||||
using mindspore::schema::ActivationType_RELU6;
|
||||
using mindspore::schema::ActivationType_SIGMOID;
|
||||
using mindspore::schema::PrimitiveType_Activation;
|
||||
|
||||
namespace mindspore {
|
||||
class TestActivationOpenCL : public mindspore::CommonTest {};
|
||||
|
||||
void LoadActivationData(void *dst, size_t dst_size, const std::string &file_path) {
|
||||
if (file_path.empty()) {
|
||||
memset(dst, 0x00, dst_size);
|
||||
} else {
|
||||
auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size));
|
||||
memcpy(dst, src_data, dst_size);
|
||||
}
|
||||
}
|
||||
|
||||
void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
|
||||
auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
|
||||
size_t output_size = output_tensor->Size();
|
||||
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
|
||||
constexpr float atol = 0.0002;
|
||||
for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
|
||||
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
|
||||
return;
|
||||
}
|
||||
}
|
||||
printf("compare success!\n");
|
||||
printf("compare success!\n");
|
||||
printf("compare success!\n\n\n");
|
||||
}
|
||||
|
||||
void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
|
||||
auto input_data = reinterpret_cast<float *>(in_data->Data());
|
||||
for (int i = 0; i < in_data->ElementsNum(); ++i) {
|
||||
printf("%f ", input_data[i]);
|
||||
}
|
||||
printf("\n");
|
||||
MS_LOG(INFO) << "Print tensor done";
|
||||
}
|
||||
|
||||
kernel::ActivationOpenClKernel *create_kernel(lite::opencl::OpenCLAllocator *allocator,
|
||||
const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, std::string test_name,
|
||||
int type, std::string in_file, float alpha = 0.2) {
|
||||
auto *param = new (std::nothrow) ActivationParameter();
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "New ActivationParameter fail.";
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(param->op_parameter_.name_, test_name.c_str(), test_name.size());
|
||||
param->alpha_ = alpha;
|
||||
param->type_ = type;
|
||||
auto *kernel =
|
||||
new (std::nothrow) kernel::ActivationOpenClKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Kernel:" << test_name << " create fail.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init " << test_name << " fail.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "Initialize input data";
|
||||
LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
|
||||
MS_LOG(INFO) << "==================input data================";
|
||||
printf_tensor(inputs[0]);
|
||||
return kernel;
|
||||
}
|
||||
|
||||
int RunSubGraphOpenCLKernel(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
kernel::ActivationOpenClKernel *kernel) {
|
||||
MS_LOG(INFO) << "Create kernel SubGraphOpenCLKernel.";
|
||||
std::vector<kernel::LiteKernel *> kernels{kernel};
|
||||
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
|
||||
if (sub_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Kernel SubGraphOpenCLKernel create fail.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "Initialize sub_graph.";
|
||||
auto ret = sub_graph->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init sub_graph error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "Run SubGraphOpenCLKernel.";
|
||||
ret = sub_graph->Run();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) {
|
||||
MS_LOG(INFO) << "Begin test:";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 4, 3, 8};
|
||||
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
|
||||
auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
|
||||
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
|
||||
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
|
||||
// freamework to do!!! allocate memory by hand
|
||||
inputs[0]->MallocData(allocator);
|
||||
|
||||
std::map<std::string, int> Test_Activation_Type;
|
||||
std::map<std::string, std::string> Test_Res_File;
|
||||
Test_Activation_Type["Relu"] = ActivationType_RELU;
|
||||
Test_Activation_Type["Leaky_Relu"] = ActivationType_LEAKY_RELU;
|
||||
Test_Activation_Type["Relu6"] = ActivationType_RELU6;
|
||||
Test_Activation_Type["Sigmoid"] = ActivationType_SIGMOID;
|
||||
Test_Res_File["Leaky_Relu"] = "/data/local/tmp/leaky_relu.bin";
|
||||
Test_Res_File["Relu"] = "/data/local/tmp/relu.bin";
|
||||
Test_Res_File["Relu6"] = "/data/local/tmp/relu6.bin";
|
||||
Test_Res_File["Sigmoid"] = "/data/local/tmp/sigmoid.bin";
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
|
||||
std::map<std::string, int>::iterator it = Test_Activation_Type.begin();
|
||||
while (it != Test_Activation_Type.end()) {
|
||||
auto kernel = create_kernel(allocator, inputs, outputs, it->first, it->second, in_file, 0.3);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Create kernel:" << it->first << " error.";
|
||||
return;
|
||||
}
|
||||
|
||||
auto ret = RunSubGraphOpenCLKernel(inputs, outputs, kernel);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << it->first << " RunSubGraphOpenCLKernel error.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "==================output data================";
|
||||
printf_tensor(outputs[0]);
|
||||
CompareRes(output_tensor, Test_Res_File[it->first]);
|
||||
delete kernel;
|
||||
it++;
|
||||
}
|
||||
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
return;
|
||||
}
|
||||
} // namespace mindspore
|
@ -1,110 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/pack.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h"
|
||||
|
||||
using mindspore::kernel::LeakyReluOpenCLKernel;
|
||||
using mindspore::kernel::LiteKernel;
|
||||
using mindspore::kernel::SubGraphOpenCLKernel;
|
||||
|
||||
namespace mindspore {
|
||||
class TestLeakyReluOpenCL : public mindspore::CommonTest {};
|
||||
|
||||
void LoadDataLeakyRelu(void *dst, size_t dst_size, const std::string &file_path) {
|
||||
if (file_path.empty()) {
|
||||
memset(dst, 0x00, dst_size);
|
||||
} else {
|
||||
auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size));
|
||||
memcpy(dst, src_data, dst_size);
|
||||
}
|
||||
}
|
||||
|
||||
void CompareOutLeakyRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
|
||||
auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
|
||||
size_t output_size = output_tensor->Size();
|
||||
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
|
||||
constexpr float atol = 0.0002;
|
||||
for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
|
||||
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
|
||||
return;
|
||||
}
|
||||
}
|
||||
printf("compare success!\n");
|
||||
printf("compare success!\n");
|
||||
printf("compare success!\n\n\n");
|
||||
}
|
||||
|
||||
void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
|
||||
auto input_data = reinterpret_cast<float *>(in_data->Data());
|
||||
for (int i = 0; i < in_data->ElementsNum(); ++i) {
|
||||
printf("%f ", input_data[i]);
|
||||
}
|
||||
printf("\n");
|
||||
MS_LOG(INFO) << "Print tensor done";
|
||||
}
|
||||
|
||||
TEST_F(TestLeakyReluOpenCL, LeakyReluFp32_dim4) {
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
std::string standard_answer_file = "/data/local/tmp/out_data.bin";
|
||||
MS_LOG(INFO) << "Begin test:";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 4, 3, 8};
|
||||
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
|
||||
auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
|
||||
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
|
||||
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
|
||||
|
||||
// freamework to do!!! allocate memory by hand
|
||||
inputs[0]->MallocData(allocator);
|
||||
|
||||
auto param = new LeakyReluParameter();
|
||||
param->alpha = 0.3;
|
||||
auto *leakyrelu_kernel = new kernel::LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
leakyrelu_kernel->Init();
|
||||
|
||||
MS_LOG(INFO) << "initialize sub_graph";
|
||||
std::vector<kernel::LiteKernel *> kernels{leakyrelu_kernel};
|
||||
auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
|
||||
sub_graph->Init();
|
||||
|
||||
MS_LOG(INFO) << "initialize input data";
|
||||
LoadDataLeakyRelu(input_tensor->Data(), input_tensor->Size(), in_file);
|
||||
MS_LOG(INFO) << "==================input data================";
|
||||
printf_tensor(inputs[0]);
|
||||
sub_graph->Run();
|
||||
|
||||
MS_LOG(INFO) << "==================output data================";
|
||||
printf_tensor(outputs[0]);
|
||||
CompareOutLeakyRelu(output_tensor, standard_answer_file);
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in new issue