fix prelu scalar weight bug

pull/7206/head
wangdongxu 4 years ago
parent 8d77d4fa90
commit 362daeb2b6

@@ -190,8 +190,8 @@ if (PLATFORM_ARM64)
endif ()
if (BUILD_MINDDATA STREQUAL "lite" OR BUILD_MINDDATA STREQUAL "full")
# TODO: add sentencepiece dependency
#include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake)
# add sentencepiece dependency
# include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake)
# opencv
set(OpenCV_DIR ${TOP_DIR}/third_party/opencv/build)
find_package(OpenCV REQUIRED)

@@ -96,7 +96,7 @@ endif ()
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND PLATFORM_ARM)
add_custom_command(TARGET mindspore-lite POST_BUILD
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so)
${CMAKE_BINARY_DIR}/src/libmindspore-lite.so)
endif ()
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
@@ -124,10 +124,10 @@ endif ()
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64))
add_custom_command(TARGET mindspore-lite-optimize POST_BUILD COMMAND
${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite-optimize.so)
${CMAKE_BINARY_DIR}/src/libmindspore-lite-optimize.so)
add_custom_command(TARGET mindspore-lite-fp16 POST_BUILD COMMAND
${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite-fp16.so)
${CMAKE_BINARY_DIR}/src/libmindspore-lite-fp16.so)
endif ()

@@ -1,41 +1,80 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define SLICES 4
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
__read_only image2d_t alpha, const int data_type, const int bias_dim) {
int H = input_shape.y;
int C = input_shape.w; // channel size
C = UP_DIV(C, SLICES);
if (C == 0 || H == 0) {
#define NHWC4 2
#define NC4HW4 100
__kernel void PRelu_scalar(__read_only image2d_t input, __write_only image2d_t output, float weight, int4 shape,
int data_format) {
int h = get_global_id(0);
int w = get_global_id(1);
int slice = get_global_id(2);
int H = shape.y;
int W = shape.z;
int SLICES = shape.w;
if (h >= H || w >= W || slice >= SLICES) {
return;
}
int Y = get_global_id(0); // height id
int X = get_global_id(1); // width id
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
FLT4 tmp;
int index = 0;
if (data_type == 1) { // NHWC4
index = X % C;
} else if (data_type == 2) { // NC4HW4
index = Y / H;
int x, y;
if (data_format == 2) {
x = w * SLICES + slice;
y = h;
} else {
x = w;
y = slice * H + h;
}
FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));
if (out.x < 0) {
out.x *= weight;
}
if (out.y < 0) {
out.y *= weight;
}
if (out.z < 0) {
out.z *= weight;
}
if (out.w < 0) {
out.w *= weight;
}
WRITE_IMAGE(output, (int2)(x, y), out);
}
__kernel void PRelu_vector(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight_vector,
int4 shape, int data_format) {
int h = get_global_id(0);
int w = get_global_id(1);
int slice = get_global_id(2);
int H = shape.y;
int W = shape.z;
int SLICES = shape.w;
if (h >= H || w >= W || slice >= SLICES) {
return;
}
if (bias_dim == 1) {
index = 0;
}
FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(index, 0));
FLT4 bias = weight;
if (bias_dim == 1) {
bias.y = weight.x;
bias.z = weight.x;
bias.w = weight.x;
}
tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * bias.x;
tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * bias.y;
tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * bias.z;
tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * bias.w;
WRITE_IMAGE(output, (int2)(X, Y), tmp);
FLT4 weight = weight_vector[slice];
int x, y;
if (data_format == 2) {
x = w * SLICES + slice;
y = h;
} else {
x = w;
y = slice * H + h;
}
FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));
if (out.x < 0) {
out.x *= weight.x;
}
if (out.y < 0) {
out.y *= weight.y;
}
if (out.z < 0) {
out.z *= weight.z;
}
if (out.w < 0) {
out.w *= weight.w;
}
WRITE_IMAGE(output, (int2)(x, y), out);
}
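For reference, a minimal host-side sketch (not part of the commit; the function name and plain NHWC float layout are assumptions for illustration) of the element-wise operation the PRelu_scalar and PRelu_vector kernels above compute:
#include <cstddef>
#include <vector>
// weight.size() == 1        -> one shared slope, as in PRelu_scalar
// weight.size() == channels -> one slope per channel, as in PRelu_vector
void PReluReference(const std::vector<float> &input, const std::vector<float> &weight,
                    size_t channels, std::vector<float> *output) {
  output->resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    float w = (weight.size() == 1) ? weight[0] : weight[i % channels];
    (*output)[i] = (input[i] >= 0.f) ? input[i] : input[i] * w;
  }
}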

@@ -359,6 +359,8 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNHWC4() {
code += "#define padBottom " + std::to_string(padBottom) + "\n";
code += "#define padLeft " + std::to_string(padLeft) + "\n";
code += "#define padRight " + std::to_string(padRight) + "\n";
code += "#define dilationH " + std::to_string(param->dilation_h_) + "\n";
code += "#define dilationW " + std::to_string(param->dilation_w_) + "\n";
code += "#define CI_SLICES " + std::to_string(CI_SLICES_) + "\n";
code += "#define CO_SLICES " + std::to_string(CO_SLICES_) + "\n\n";
@@ -398,10 +400,10 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNHWC4() {
code +=
" for (int kh = 0; kh < KH; ++kh)\n"
" {\n"
" int ih = kh + oh * strideH - padTop;\n"
" int ih = kh * dilationH + oh * strideH - padTop;\n"
" for (int kw = 0; kw < KW; ++kw)\n"
" {\n"
" int iw = kw + ow * strideW - padLeft;\n"
" int iw = kw * dilationW + ow * strideW - padLeft;\n"
" if (ih >= 0 && ih < IH && iw >= 0 && iw < IW)\n"
" {\n"
" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++)\n"
@@ -491,7 +493,9 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() {
code += " #define strideH " + std::to_string(strideH) + "\n";
code += " #define strideW " + std::to_string(strideW) + "\n";
code += " #define padTop " + std::to_string(padTop) + "\n";
code += " #define padLeft " + std::to_string(padLeft) + "\n\n";
code += " #define padLeft " + std::to_string(padLeft) + "\n";
code += " #define dilationH " + std::to_string(param->dilation_h_) + "\n";
code += " #define dilationW " + std::to_string(param->dilation_w_) + "\n";
code +=
" if (n_oh >= N_OH || ow >= OW || co_slice >= CO_SLICES) {\n"
@@ -513,7 +517,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() {
"\n"
" for (int kh = 0; kh < KH; ++kh)\n"
" {\n"
" int ih = kh + oh * strideH - padTop;\n"
" int ih = kh * dilationH + oh * strideH - padTop;\n"
" for (int kw = 0; kw < KW; ++kw)\n"
" {\n";
@@ -523,7 +527,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() {
"{\n";
}
code += " int iw0 = kw + (ow + 0) * strideW - padLeft;\n";
code += " int iw0 = kw * dilationW + (ow + 0) * strideW - padLeft;\n";
if (check_ow) {
code +=
" if (last_is_double)\n"
@@ -531,7 +535,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() {
}
code +=
" int iw1 = kw + (ow + 1) * strideW - padLeft;\n"
" int iw1 = kw * dilationW + (ow + 1) * strideW - padLeft;\n"
" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++)\n"
" {\n"
" FLT4 in0 = READ_IMAGE(input, smp_zero, (int2)(iw0, (n * CI_SLICES + ci_slice) * IH + ih));\n"
@@ -916,4 +920,5 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::Tenso
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Conv2D, OpenCLConvolutionKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Conv2D, OpenCLConvolutionKernelCreator)
} // namespace mindspore::kernel
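As a side note on the dilation change above, a small sketch (illustrative values only, not part of the commit) of the input row each filter tap samples once kh * dilationH is added to the index:
#include <cstdio>
int main() {
  const int KH = 3, strideH = 1, padTop = 1, dilationH = 2;
  const int oh = 0;  // first output row
  for (int kh = 0; kh < KH; ++kh) {
    int ih = kh * dilationH + oh * strideH - padTop;  // -1, 1, 3 (would be -1, 0, 1 without dilation)
    std::printf("kh=%d -> ih=%d\n", kh, ih);
  }
  return 0;
}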

@@ -18,7 +18,6 @@
#include <set>
#include <vector>
#include <map>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@@ -36,85 +35,116 @@ namespace mindspore::kernel {
void PReluOpenCLKernel::InitBuffer() {
auto allocator = ocl_runtime_->GetAllocator();
int elem_num = in_tensors_[0]->shape().size() == 2 ? in_tensors_[0]->shape()[1] : in_tensors_[0]->shape()[3];
int elem_num_c4 = UP_DIV(elem_num, C4NUM);
size_t img_dtype = CL_FLOAT;
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
std::vector<size_t> img_size{size_t(elem_num_c4), 1, img_dtype};
PReluWeight_ = allocator->Malloc(elem_num_c4 * C4NUM * fp_size, img_size);
PReluWeight_ = allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true);
memset(PReluWeight_, 0x00, elem_num_c4 * C4NUM * fp_size);
if (enable_fp16_) {
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
auto PReluWeight_fp16 = reinterpret_cast<uint16_t *>(PReluWeight_);
auto in_tensor_data_fp32 = reinterpret_cast<float *>(in_tensors_[1]->data_c());
for (int i = 0; i < elem_num; i++) {
PReluWeight_fp16[i] = static_cast<float16_t>(in_tensor_data_fp32[i]);
}
auto weight_tensor = in_tensors_[1];
if (weight_is_scalar) {
if (weight_tensor->data_type() == kNumberTypeFloat16) {
weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c()));
} else {
memcpy(PReluWeight_, in_tensors_[1]->data_c(), elem_num * fp_size);
weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data_c());
}
} else {
if (in_tensors_[1]->data_type() == kNumberTypeFloat16) {
auto PReluWeight_fp32 = reinterpret_cast<float *>(PReluWeight_);
auto in_tensor_data_fp16 = reinterpret_cast<float16_t *>(in_tensors_[1]->data_c());
for (int i = 0; i < elem_num; i++) {
PReluWeight_fp32[i] = static_cast<float>(in_tensor_data_fp16[i]);
auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
size_t weight_size = UP_ROUND(C_, C4NUM) * sizeof_FLT;
weight_vector_ = allocator->Malloc(weight_size);
allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true);
memset(weight_vector_, 0x00, weight_size);
if (weight_tensor->data_type() == kNumberTypeFloat16) {
if (enable_fp16_) {
memcpy(weight_vector_, weight_tensor->data_c(), C_ * sizeof_FLT);
} else {
auto weight_fp32 = reinterpret_cast<float *>(weight_vector_);
auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c());
for (int i = 0; i < C_; ++i) {
weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]);
}
}
} else {
memcpy(PReluWeight_, in_tensors_[1]->data_c(), elem_num * fp_size);
if (enable_fp16_) {
auto weight_fp16 = reinterpret_cast<float16_t *>(weight_vector_);
auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c());
for (int i = 0; i < C_; ++i) {
weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]);
}
} else {
memcpy(weight_vector_, weight_tensor->data_c(), C_ * sizeof_FLT);
}
}
allocator->UnmapBuffer(weight_vector_);
}
allocator->UnmapBuffer(PReluWeight_);
}
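The buffer sizing above relies on the channel-packing helpers; here is a small sketch with assumed definitions (UP_DIV, UP_ROUND, and C4NUM are defined elsewhere in the tree and are mirrored here purely for illustration) of why the weight buffer is rounded up and zero-filled before the copy:
#include <cstdio>
constexpr int kC4Num = 4;                                        // assumed value of C4NUM
constexpr int UpDiv(int x, int y) { return (x + y - 1) / y; }    // assumed UP_DIV
constexpr int UpRound(int x, int y) { return UpDiv(x, y) * y; }  // assumed UP_ROUND
int main() {
  const int channels = 6;  // example value of C_
  std::printf("slices = %d, padded channels = %d\n", UpDiv(channels, kC4Num), UpRound(channels, kC4Num));
  // -> slices = 2, padded channels = 8: 8 * sizeof_FLT bytes are allocated and the two
  //    trailing slots keep the zero written by memset.
  return 0;
}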
int PReluOpenCLKernel::Init() {
if (in_tensors_[0]->shape().size() != 4) {
MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << in_tensors_[0]->shape().size();
auto input_tensor = in_tensors_[0];
auto weight_tensor = in_tensors_[1];
if (input_tensor->shape().size() != 4) {
MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << input_tensor->shape().size();
return RET_ERROR;
}
batch_size_ = input_tensor->Batch();
C_ = input_tensor->Channel();
H_ = input_tensor->Height();
W_ = input_tensor->Width();
if (input_tensor->GetFormat() != schema::Format_NC4HW4 && input_tensor->GetFormat() != schema::Format_NHWC4) {
MS_LOG(ERROR) << "PRelu only support Format_NC4HW4 and Format_NHWC4";
return RET_ERROR;
}
int C_Weight = in_tensors_[1]->shape()[0];
int C = in_tensors_[0]->shape()[3];
if (C_Weight != 1 && UP_DIV(C_Weight, C4NUM) != UP_DIV(C, C4NUM)) {
if (batch_size_ != 1) {
MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch.";
return RET_ERROR;
}
auto weight_channel = weight_tensor->shape()[0];
if (weight_channel != 1 && weight_channel != C_) {
MS_LOG(ERROR)
<< "PRelu weight channel size must be 1 or must be equal with in_teneors channel size, but your weight size is "
<< C_Weight << " and your input channel size is " << C;
<< weight_channel << " and your input channel size is " << C_;
return RET_ERROR;
}
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
input_shape_.s[i] = in_tensors_[0]->shape()[i];
weight_is_scalar = weight_channel == 1;
if (weight_tensor->data_type() != kNumberTypeFloat16 && weight_tensor->data_type() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "PRelu weight must be float32 or float16";
return RET_ERROR;
}
enable_fp16_ = ocl_runtime_->GetFp16Enable();
in_ori_format_ = input_tensor->GetFormat();
out_ori_format_ = out_tensors_[0]->GetFormat();
input_tensor->SetFormat(op_format_);
out_tensors_[0]->SetFormat(op_format_);
std::set<std::string> build_options;
std::string source = prelu_source;
std::string program_name = "PRelu";
std::string kernel_name = "PRelu";
enable_fp16_ = ocl_runtime_->GetFp16Enable();
fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
InitBuffer();
std::string kernel_name = "PRelu_" + std::string(weight_is_scalar ? "scalar" : "vector");
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(op_format_);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(op_format_);
InitBuffer();
MS_LOG(DEBUG) << program_name << " init Done!";
return RET_OK;
}
int PReluOpenCLKernel::Run() {
MS_LOG(DEBUG) << op_parameter_->name_ << " Running!";
std::map<schema::Format, int> data_type{{schema::Format::Format_NHWC4, 1}, {schema::Format::Format_NC4HW4, 2}};
auto CO_SLICES_ = UP_DIV(C_, C4NUM);
cl_int4 shape = {batch_size_, H_, W_, CO_SLICES_};
int arg_idx = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_shape_);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, PReluWeight_);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, data_type[op_format_]);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<int>(in_tensors_[1]->shape()[0]));
std::vector<size_t> local = {1, 1};
std::vector<size_t> global = {static_cast<size_t>(global_shape_.s[1]), static_cast<size_t>(global_shape_.s[2])};
if (weight_is_scalar) {
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_scalar_);
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_);
}
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, shape);
if (op_format_ == schema::Format_NHWC4) {
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 2);
} else { // Format_NC4HW4 = 100
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 100);
}
std::vector<size_t> local = {4, 4, 1};
std::vector<size_t> global = {static_cast<size_t>(H_), static_cast<size_t>(W_), static_cast<size_t>(CO_SLICES_)};
auto ret = ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error.";
@@ -124,22 +154,26 @@ int PReluOpenCLKernel::Run() {
}
int PReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t img_dtype = CL_FLOAT;
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
global_shape_ = input_shape_;
if (op_format_ == schema::Format::Format_NC4HW4) {
global_shape_.s[1] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[1];
} else if (op_format_ == schema::Format::Format_NHWC4) {
global_shape_.s[2] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[2];
size_t im_dst_x, im_dst_y;
auto CO_SLICES_ = UP_DIV(C_, C4NUM);
if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
if (W_ * CO_SLICES_ <= MAX_IMAGE2D_SIZE) {
{
im_dst_y = batch_size_ * H_;
im_dst_x = W_ * CO_SLICES_;
}
} else {
im_dst_y = W_;
im_dst_x = batch_size_ * H_ * CO_SLICES_;
}
} else {
MS_LOG(ERROR) << "op_format_:" << op_format_ << " is do not support!";
return RET_ERROR;
im_dst_y = batch_size_ * CO_SLICES_ * H_;
im_dst_x = W_;
}
size_t img_dtype = enable_fp16_ ? CL_HALF_FLOAT : CL_FLOAT;
img_size->clear();
img_size->push_back(global_shape_.s[2]);
img_size->push_back(global_shape_.s[1]);
img_size->push_back(im_dst_x);
img_size->push_back(im_dst_y);
img_size->push_back(img_dtype);
return RET_OK;
}
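For intuition, a tiny sketch (illustrative values only) of the image2d extents the branches above produce for the two supported layouts, ignoring the MAX_IMAGE2D_SIZE fallback:
#include <cstdio>
int main() {
  const int N = 1, H = 5, W = 7, CO_SLICES = 2;  // e.g. 6 channels packed into 2 slices
  std::printf("NHWC4  image: width=%d height=%d\n", W * CO_SLICES, N * H);  // 14 x 5
  std::printf("NC4HW4 image: width=%d height=%d\n", W, N * CO_SLICES * H);  // 7 x 10
  return 0;
}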
@@ -152,16 +186,11 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::Tensor *> &
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
return nullptr;
}
if (inputs[0]->shape()[0] > 1) {
MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch.";
return nullptr;
}
auto *kernel = new (std::nothrow) PReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init PRelu kernel failed!";

@@ -39,11 +39,14 @@ class PReluOpenCLKernel : public OpenCLKernel {
private:
cl::Kernel kernel_;
void *PReluWeight_;
cl_int4 input_shape_;
cl_int4 global_shape_;
size_t fp_size;
bool enable_fp16_{false};
int batch_size_{};
int C_{};
int H_{};
int W_{};
void *weight_vector_{nullptr};
float weight_scalar_{0.f};
bool weight_is_scalar{false};
};
} // namespace mindspore::kernel

@@ -235,14 +235,16 @@ void PrintTensor(lite::Tensor *tensor, int num, const std::string &out_file) {
if (tensor->data_c() == nullptr) {
return;
}
auto runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->SyncCommandQueue();
auto allocator = runtime->GetAllocator();
auto origin_data = tensor->data_c();
allocator->MapBuffer(origin_data, CL_MAP_READ, nullptr, true);
allocator->MapBuffer(origin_data, CL_MAP_READ | CL_MAP_WRITE, nullptr, true);
tensor->SetData(origin_data);
auto Batch = tensor->Batch();
auto Height = tensor->shape().size() == 4 ? tensor->Height() : 1;
auto Width = tensor->shape().size() == 4 ? tensor->Width() : 1;
auto SLICES = UP_DIV(tensor->Channel(), C4NUM);
@@ -250,17 +252,8 @@ void PrintTensor(lite::Tensor *tensor, int num, const std::string &out_file) {
auto dtype_size = tensor->data_type() == kNumberTypeFloat16 ? sizeof(cl_half4) : sizeof(cl_float4);
auto row_pitch = (Width * SLICES + alignment - 1) / alignment * alignment * dtype_size;
auto row_size = Width * SLICES * dtype_size;
std::cout << "tensor->GetFormat() =" << tensor->GetFormat() << "\n";
std::cout << "Height =" << Height << "\n";
std::cout << "Width =" << Width << "\n";
std::cout << "SLICES =" << SLICES << "\n";
std::cout << "image_alignment =" << alignment << "\n";
std::cout << "dtype_size =" << dtype_size << "\n";
std::cout << "row_pitch =" << row_pitch << "\n";
std::cout << "row_size =" << row_size << "\n";
std::cout << "tensor->Size() =" << tensor->Size() << "\n";
std::vector<char> data(tensor->Size());
for (int i = 0; i < Height; ++i) {
for (int i = 0; i < Batch * Height; ++i) {
memcpy(static_cast<char *>(data.data()) + i * row_size, static_cast<char *>(origin_data) + i * row_pitch, row_size);
}
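The copy loop above de-pads the image read-back row by row; a minimal standalone sketch (hypothetical helper, not in the tree) of the same idea:
#include <cstddef>
#include <cstring>
// Each logical row of row_size bytes sits at the start of a physical row of
// row_pitch bytes, so rows are copied one at a time into a tight buffer.
void CopyRowsDropPitch(const char *src, char *dst, int rows, size_t row_size, size_t row_pitch) {
  for (int i = 0; i < rows; ++i) {
    std::memcpy(dst + i * row_size, src + i * row_pitch, row_size);
  }
}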
