add fp16 for opencl depthwise

pull/5143/head
wandongdong 5 years ago
parent 167dc5f09e
commit b972ea6262

@@ -1,15 +1,5 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void to_format_NCHW_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NHWC_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
@@ -47,58 +37,17 @@ __kernel void to_format_NHWC4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only
}
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), src_data[(X * size.y + Y) * size.z + Z]);
}
__kernel void to_format_NC4HW4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NCHW_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NHWC_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NHWC4_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NC4HW4_to_NC4HW4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
// size(h, w, c4, 1), shape(n, c, h, w)
int X = get_global_id(0); // h
int Y = get_global_id(1); // w
int Z = get_global_id(2); // c4
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// FLT4 src_final = src_data[(((Z)*src_size.y + (y_c)) * src_size.x + (x_c))];
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), src_data[(Y * size.z + Z) * size.x + X]);
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), src_data[(Z * size.x + X) * size.y + Y]);
}
__kernel void to_format_NCHW_to_NCHW_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
@@ -109,56 +58,6 @@ __kernel void to_format_NCHW_to_NCHW_BUF(__read_only image2d_t src_data, __globa
}
dst_data[(Z * size.y + Y) * size.x + X] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.x + X, Z));
}
__kernel void to_format_NHWC_to_NCHW_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NHWC4_to_NCHW_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NC4HW4_to_NCHW_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NCHW_to_NHWC_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NHWC_to_NHWC_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NHWC4_to_NHWC_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
@@ -185,25 +84,16 @@ __kernel void to_format_NHWC4_to_NHWC_BUF(__read_only image2d_t src_data, __glob
}
}
}
__kernel void to_format_NC4HW4_to_to_NHWC_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
// WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
}
__kernel void to_format_NC4HW4_to_NC4HW4_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
// size(h, w, c, 1), shape(n, c, h, w)
int X = get_global_id(0); // h
int Y = get_global_id(1); // w
int Z = get_global_id(2); // c
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
dst_data[(Y * size.z + Z) * size.x + X] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X));
dst_data[(Z * size.x + X) * size.y + Y] = READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X));
}
__kernel void to_format_NHWC4_to_NHWC4_BUF(__read_only image2d_t src_data, __global FLT4 *dst_data, int4 size,
int4 shape) {

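Note on the to_format.cl hunks above: the NC4HW4 image kernels swap their addressing from (Y * size.z + Z, X) to (Y, Z * size.x + X), so the image x-axis carries width and the y-axis carries the c4 * H product, consistent with the new size(h, w, c4, 1) comments (X = h, Y = w, Z = c4). A minimal host-side sketch of the new mapping, with illustrative names that are not part of the commit:

#include <array>
#include <cstddef>

// Illustrative mapping for the rewritten to_format_NC4HW4_* kernels:
// X = h, Y = w, Z = c4 and size = {H, W, C4, 1}; each element is one FLT4.
struct Nc4hw4Coord {
  std::size_t buffer_index;  // index into the linear FLT4 buffer
  int img_x;                 // image2d x coordinate (width)
  int img_y;                 // image2d y coordinate (c4 * H + h)
};

inline Nc4hw4Coord MapNc4hw4(int h, int w, int c4, const std::array<int, 3> &size) {
  const int H = size[0];  // size.x
  const int W = size[1];  // size.y
  Nc4hw4Coord out;
  out.buffer_index = static_cast<std::size_t>((c4 * H + h) * W + w);  // src_data[(Z * size.x + X) * size.y + Y]
  out.img_x = w;                                                      // (int2)(Y, ...)
  out.img_y = c4 * H + h;                                             // (int2)(..., Z * size.x + X)
  return out;
}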
@@ -20,9 +20,10 @@
#include <utility>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "nnacl/fp32/common_func.h"
#include "nnacl/op_base.h"
#include "include/errorcode.h"
#include "nnacl/pack.h"
#ifndef PROGRAM_WITH_IL
@@ -81,30 +82,50 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto allocator = ocl_runtime->GetAllocator();
bool is_fp16 = ocl_runtime->GetFp16Enable();
// weight: o, h, w, i; o == group, i == 1
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_.at(kWeightIndex)->Data());
void *origin_weight = in_tensors_.at(kWeightIndex)->Data();
int CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
int pack_weight_size = C4NUM * CO4 * parameter->kernel_h_ * parameter->kernel_w_;
packed_weight_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(pack_weight_size * sizeof(FLOAT_t)));
packed_weight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true));
int plane = parameter->kernel_h_ * parameter->kernel_w_;
#ifdef ENABLE_FP16
PackNCHWToNC4HW4Fp16(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel());
#else
PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel());
#endif
if (is_fp16) {
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(int16_t));
packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true);
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
std::function<int16_t(float)> to_dtype = Float32ToShort;
PackNCHWToNC4HW4<float, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else {
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
}
} else {
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true);
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
std::function<float(float)> to_dtype = [](float x) -> float { return (float)x; };
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else {
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
}
}
allocator->UnmapBuffer(packed_weight_);
if (in_tensors_.size() == kInputSize2) {
bias_data_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(C4NUM * CO4 * sizeof(FLOAT_t)));
bias_data_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(bias_data_, CL_MAP_WRITE, nullptr, true));
size_t up_co_size = C4NUM * CO4 * sizeof(FLOAT_t);
size_t dtype_size = sizeof(float);
if (is_fp16 && in_tensors_.at(kBiasIndex)->data_type() == kNumberTypeFloat16) {
dtype_size = sizeof(int16_t);
}
bias_data_ = allocator->Malloc(C4NUM * CO4 * dtype_size);
bias_data_ = allocator->MapBuffer(bias_data_, CL_MAP_WRITE, nullptr, true);
size_t up_co_size = C4NUM * CO4 * dtype_size;
memset(bias_data_, 0, up_co_size);
auto ori_bias = reinterpret_cast<FLOAT_t *>(in_tensors_.at(kBiasIndex)->Data());
memcpy(bias_data_, ori_bias, out_tensors_[0]->Channel() * sizeof(FLOAT_t));
auto ori_bias = in_tensors_.at(kBiasIndex)->Data();
memcpy(bias_data_, ori_bias, out_tensors_[0]->Channel() * dtype_size);
allocator->UnmapBuffer(bias_data_);
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
@@ -124,11 +145,10 @@ int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *i
im_dst_y = out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
}
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
if (lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable()) {
img_dtype = CL_HALF_FLOAT;
}
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
@@ -204,5 +224,6 @@ kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::t
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_DepthwiseConv2D, OpenCLDepthwiseConv2dKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, OpenCLDepthwiseConv2dKernelCreator)
} // namespace mindspore::kernel

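InitBuffer above now packs the weights through a std::function<T2(T1)> converter: an identity lambda when the weight tensor already matches the target precision, and Float32ToShort when float32 weights must be narrowed to the 16-bit storage used on the fp16 path. Float32ToShort itself comes from src/common/utils and is not shown in this diff; a minimal standalone sketch of such a conversion (truncating the mantissa rather than rounding, hypothetical name) could look like:

#include <cstdint>
#include <cstring>

// Hypothetical stand-in for the Float32ToShort helper referenced above.
// Converts an IEEE-754 float32 to float16 bits stored in an int16_t,
// flushing subnormals to zero; NaN collapses to Inf in this simplified version.
static inline int16_t Float32ToFp16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t sign = (bits >> 16) & 0x8000u;
  const int32_t exp = static_cast<int32_t>((bits >> 23) & 0xFFu) - 127 + 15;
  const uint32_t mant = bits & 0x007FFFFFu;
  if (exp <= 0) {                   // too small for a normal half: flush to signed zero
    return static_cast<int16_t>(sign);
  }
  if (exp >= 31) {                  // overflow, Inf or NaN: saturate to Inf
    return static_cast<int16_t>(sign | 0x7C00u);
  }
  return static_cast<int16_t>(sign | (static_cast<uint32_t>(exp) << 10) | (mant >> 13));
}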
@@ -20,7 +20,6 @@
#include <vector>
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "nnacl/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
namespace mindspore::kernel {
@@ -46,8 +45,8 @@ class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
int GetLocalSize(size_t idx, const std::vector<size_t> &global_size, std::vector<size_t> *local_size) override;
private:
FLOAT_t *packed_weight_;
FLOAT_t *bias_data_;
void *packed_weight_;
void *bias_data_;
cl::Kernel kernel_;
};
} // namespace mindspore::kernel

@@ -172,5 +172,6 @@ kernel::LiteKernel *OpenCLToFormatKernelCreator(const std::vector<lite::tensor::
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ToFormat, OpenCLToFormatKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ToFormat, OpenCLToFormatKernelCreator)
} // namespace mindspore::kernel

@@ -93,11 +93,10 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
}
out_tensors->emplace_back(new_tensor);
#ifdef ENABLE_FP16
KernelKey desc{kGPU, kNumberTypeFloat16, schema::PrimitiveType_ToFormat};
#else
KernelKey desc{kGPU, kNumberTypeFloat32, schema::PrimitiveType_ToFormat};
#endif
if (lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable()) {
desc.data_type = kNumberTypeFloat16;
}
OpenCLToFormatParameter *parameter = new (std::nothrow) OpenCLToFormatParameter;
MS_ASSERT(parameter);
if (parameter == nullptr) {

@@ -23,6 +23,7 @@
#include "utils/log_adapter.h"
#include "nnacl/op_base.h"
#include "src/lite_kernel.h"
#include "src/common//utils.h"
namespace mindspore::lite {
kernel::LiteKernel *GetOpenCLKernel(const std::vector<tensor::Tensor *> &in_tensors,
@@ -89,6 +90,73 @@ std::vector<size_t> GetCommonLocalSize(const std::vector<size_t> &global, int ma
std::string CLErrorCode(cl_int error_code);
template <class T1, class T2>
void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane, int channel,
const std::function<T2(T1)> &to_dtype) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * channel;
int dst_offset = b * plane * c4 * C4NUM;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_rem = c % C4NUM;
int src_c_offset = src_offset + c * plane;
int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k;
int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
(static_cast<T2 *>(dst) + dst_kernel_offset)[0] =
to_dtype((static_cast<T1 *>(src) + src_kernel_offset)[0]);
}
}
}
}
template <class T1, class T2>
void PackNHWCToNHWC4(void *src, void *dst, int batch, int plane, int channel,
const std::function<T2(T1)> &to_dtype) {
int c4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
int ic_remainder_ = channel % C4NUM;
if (ic_remainder_ != 0) {
int nhwc4_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int i = 0; i < plane; ++i) {
for (int c = 0; c < channel; ++c) {
(static_cast<T2 *>(dst) + nhwc4_batch_offset + i * c4 * C4NUM + c)[0] =
to_dtype((static_cast<T1 *>(src) + batch_offset + i * channel + c)[0]);
}
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
} else {
size_t ori_input_size = batch * plane * channel;
for (size_t n = 0; n < ori_input_size; ++n) {
(static_cast<T2 *>(dst) + n)[0] = to_dtype((static_cast<T1 *>(src) + n)[0]);
}
}
}
template <class T1, class T2>
void PackNHWCToNC4HW4(void *src, void *dst, int batch, int plane, int channel,
const std::function<T2(T1)> &to_dtype) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_oc_offset = b * plane * channel;
int dst_oc_offset = b * plane * c4 * C4NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_oc_offset + k * channel;
int dst_kernel_offset = dst_oc_offset + k * C4NUM;
for (int i = 0; i < channel; i++) {
int c4_block_num = i / C4NUM;
int c4_block_rem = i % C4NUM;
int src_ic_offset = src_kernel_offset + i;
int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem;
(static_cast<T2 *>(dst) + dst_ic_offset)[0] = to_dtype((static_cast<T1 *>(src) + src_ic_offset)[0]);
}
}
}
}
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_

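The templated PackNCHWToNC4HW4 added above combines the layout transform with an element-wise dtype conversion through the to_dtype functor. A small standalone usage sketch (not part of the commit) for the identity float32 case, matching the fp32 branch of InitBuffer:

// Minimal usage sketch of the PackNCHWToNC4HW4 template above.
// C4NUM (== 4) and UP_DIV come in through the included header chain.
#include <cstdio>
#include <functional>
#include <vector>
#include "src/runtime/kernel/opencl/utils.h"

int main() {
  const int batch = 1, channel = 3, plane = 2 * 2;           // e.g. a 1x3x2x2 NCHW weight
  const int c4 = UP_DIV(channel, C4NUM);                     // one channel group of 4
  std::vector<float> src(batch * channel * plane);
  for (size_t i = 0; i < src.size(); ++i) src[i] = static_cast<float>(i);
  std::vector<float> dst(batch * c4 * plane * C4NUM, 0.0f);  // zero-init pads the unused lane

  // Identity conversion: same dtype in and out.
  std::function<float(float)> to_dtype = [](float x) -> float { return x; };
  mindspore::kernel::PackNCHWToNC4HW4<float, float>(src.data(), dst.data(), batch, plane, channel, to_dtype);

  // Source element (c = 1, k = 2) lands at dst[k * C4NUM + c] = dst[2 * 4 + 1].
  std::printf("src=%g -> dst=%g\n", src[1 * plane + 2], dst[2 * C4NUM + 1]);
  return 0;
}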