fix bug: Scale op when running Caffe models

pull/8086/head
chenzupeng 4 years ago
parent 2e667a6353
commit ed7e2e0ab3

@@ -5,6 +5,7 @@ __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE |
#define ActType_Relu 1
#define ActType_Sigmod 2
#define ActType_Relu6 3
#define C4NUM 4
__kernel void Scale_IMG(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset,
                        __write_only image2d_t output, const int2 output_shape, const int act_type) {
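The only functional change in this first hunk is the new C4NUM define. The OpenCL backend packs tensors four channels per RGBA texel, and the host code rounds channel counts up with UP_DIV. A minimal illustration of that arithmetic (conventional definitions written out for clarity, not copied from the project headers):

// Illustrative only: the C4 packing arithmetic used throughout this backend.
constexpr int kC4Num = 4;                                      // channels per RGBA texel
constexpr int UpDiv(int x, int y) { return (x + y - 1) / y; }  // ceil(x / y)
// Example: 10 channels occupy UpDiv(10, kC4Num) == 3 texels, the last one half padded.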
@@ -63,3 +64,40 @@ __kernel void Scale_C_IMG(__read_only image2d_t input, __read_only image2d_t sca
  }
  WRITE_IMAGE(output, (int2)(X, Y), out);
}

__kernel void Scale_H_IMG(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset,
                          __write_only image2d_t output, const int2 output_shape, const int H, const int act_type) {
  int X = get_global_id(0);
  int Y = get_global_id(1);
  if (X >= output_shape.x || Y >= output_shape.y) {
    return;
  }
  int h = Y % H;
  int h_quotient = h / C4NUM;
  int h_remainder = h % C4NUM;
  FLT4 in = READ_IMAGE(input, smp_none, (int2)(X, Y));
  FLT4 s = READ_IMAGE(scale, smp_none, (int2)(h_quotient, 0));
  FLT4 o = READ_IMAGE(offset, smp_none, (int2)(h_quotient, 0));
  FLT s_real;
  FLT o_real;
  if (h_remainder == 0) {
    s_real = s.x;
    o_real = o.x;
  } else if (h_remainder == 1) {
    s_real = s.y;
    o_real = o.y;
  } else if (h_remainder == 2) {
    s_real = s.z;
    o_real = o.z;
  } else {
    s_real = s.w;
    o_real = o.w;
  }
  FLT4 out = in * s_real + o_real;
  if (act_type == ActType_Relu) {
    out = max(out, (FLT4)(0.0f));
  } else if (act_type == ActType_Relu6) {
    out = clamp(out, (FLT4)(0.0f), (FLT4)(6.0f));
  }
  WRITE_IMAGE(output, (int2)(X, Y), out);
}
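The new Scale_H_IMG kernel broadcasts a per-row scale and offset over an NHWC4-packed image: h / C4NUM selects the texel of the packed scale/offset vectors and h % C4NUM selects the lane inside that FLT4. For reference, the same computation on a plain NHWC float buffer looks like the sketch below (function name and layout are illustrative assumptions, not MindSpore Lite code):

// CPU reference for the per-H scale: out[n][h][w][c] = in[n][h][w][c] * scale[h] + offset[h].
// Illustrative sketch only; the GPU kernel above works on the packed image layout instead.
#include <cstddef>
#include <vector>

void ScaleAlongH(const std::vector<float> &in, const std::vector<float> &scale,
                 const std::vector<float> &offset, std::vector<float> *out,
                 int N, int H, int W, int C) {
  for (int n = 0; n < N; ++n) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        for (int c = 0; c < C; ++c) {
          std::size_t idx = ((static_cast<std::size_t>(n) * H + h) * W + w) * C + c;
          (*out)[idx] = in[idx] * scale[h] + offset[h];  // broadcast over N, W and C
        }
      }
    }
  }
}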

@@ -48,19 +48,12 @@ ScaleOpenCLKernel::~ScaleOpenCLKernel() {
void ScaleOpenCLKernel::Image2dGetWorkGroupSize() {
  local_size_ = {16, 16};
  if (out_tensors_[0]->shape().size() == 2) {
    size_t H = out_tensors_[0]->shape()[0];
    size_t W = UP_DIV(out_tensors_[0]->shape()[1], C4NUM);
    global_size_ = {W, H};
  } else {
    size_t H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
    size_t W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
    global_size_ = {W, H};
  }
  auto image2d_info = Image2DInfo(out_tensors_[0]);
  global_size_ = {image2d_info.width, image2d_info.height};
}
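This hunk drops the hand-rolled {W, H} computation in favour of the extents reported by Image2DInfo. Assuming the usual NHWC4 packing, those extents boil down to the arithmetic below; the removed 2-D branch is the same formula with N = 1, W = 1 and C = shape[1] (struct and helper names here are hypothetical, only meant to spell out the math):

// Hypothetical spelling-out of the image2d extents for an NHWC4-packed tensor.
// The real values come from Image2DInfo; this only restates the arithmetic.
#include <cstddef>

struct ImageExtents {
  std::size_t width;   // W * UP_DIV(C, 4) texels per row
  std::size_t height;  // N * H rows
};

inline ImageExtents NHWC4Extents(std::size_t N, std::size_t H, std::size_t W, std::size_t C) {
  const std::size_t c4 = (C + 3) / 4;  // UP_DIV(C, 4)
  return {W * c4, N * H};
}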
int ScaleOpenCLKernel::InitBuffer() {
  if (!element_flag_) {
  if (!weight_vector_flag_) {
    return RET_OK;
  }
  if (in_tensors_[1]->IsConst()) {
@@ -68,17 +61,18 @@ int ScaleOpenCLKernel::InitBuffer() {
    std::vector<size_t> img_size;
    GetImageSize(0, &img_size);
    img_size[2] = in_tensors_[1]->data_type() == kNumberTypeFloat16 ? CL_HALF_FLOAT : CL_FLOAT;
    if (scale_C_flag_) {
    if (broadcast_flag_) {
      img_size[1] = 1;
      img_size[0] = UP_DIV(in_tensors_[1]->shape()[0], C4NUM);
      scale_ptr_ = allocator->CreateImageFromHost(in_tensors_[1]->data_c(), in_tensors_[1]->ElementsNum(), img_size);
      offset_ptr_ = allocator->CreateImageFromHost(in_tensors_[2]->data_c(), in_tensors_[2]->ElementsNum(), img_size);
      return RET_OK;
    }
    int pack_weight_size = in_tensors_[1]->ElementsC4Num();
    int plane = in_tensors_[1]->Height() * in_tensors_[1]->Width();
    int channel = in_tensors_[1]->Channel();
    int batch = in_tensors_[1]->Batch();
    auto image2d_info = Image2DInfo(in_tensors_[1]);
    int pack_weight_size = image2d_info.ElementsC4Num;
    int plane = image2d_info.H * image2d_info.W;
    int channel = image2d_info.C;
    int batch = image2d_info.N;
    if (in_tensors_[0]->GetFormat() == in_tensors_[1]->GetFormat()) {
      if (in_tensors_[0]->data_type() == in_tensors_[1]->data_type()) {
        scale_ptr_ = allocator->CreateImageFromHost(in_tensors_[1]->data_c(), in_tensors_[1]->ElementsNum(), img_size);
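In the broadcast path above, the 1-D scale and offset tensors are uploaded as a single-row image: img_size[1] = 1 row and img_size[0] = UP_DIV(len, C4NUM) texels. The hypothetical helper below only makes that memory shape explicit; the actual upload is done by allocator->CreateImageFromHost:

// Hypothetical illustration of the single-row C4 layout used for the scale/offset vectors.
// Not project code: the real copy happens inside CreateImageFromHost.
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> PackVectorToC4Row(const std::vector<float> &v) {
  const std::size_t texels = (v.size() + 3) / 4;  // img_size[0] = UP_DIV(len, C4NUM)
  std::vector<float> packed(texels * 4, 0.0f);    // one row, padded to a whole texel
  std::copy(v.begin(), v.end(), packed.begin());
  return packed;
}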
@@ -157,16 +151,27 @@ int ScaleOpenCLKernel::Init() {
  }
  if (scale_shape.size() != in_shape.size()) {
    if (scale_tensor->ElementsNum() == 1) {
      element_flag_ = false;
      weight_vector_flag_ = false;
      kernel_name = "BoardcastScale";
    } else if (((in_shape.size() == 4 && axis_ == 3) || (in_shape.size() == 2 && axis_ == 1)) &&
               scale_shape.size() == 1) {
      element_flag_ = true;
      scale_C_flag_ = true;
      kernel_name = "Scale_C";
    } else if (scale_shape.size() == 1) {
      weight_vector_flag_ = true;
      broadcast_flag_ = true;
      if ((in_shape.size() == 4 && axis_ == 3) || (in_shape.size() == 2 && axis_ == 1)) {
        kernel_name = "Scale_C";
      } else if (in_shape.size() == 4 && axis_ == 1) {
        kernel_name = "Scale_H";
        broadcast_H_flag_ = true;
      } else {
        MS_LOG(ERROR) << "unsupported scale axis " << axis_;
        return RET_ERROR;
      }
    } else {
      MS_LOG(ERROR) << "unsupported scale axis " << axis_ << ", in shape " << in_shape << ", scale shape"
                    << scale_shape;
      return RET_ERROR;
    }
  } else {
    element_flag_ = true;
    weight_vector_flag_ = true;
    kernel_name = "Scale";
  }
  lite::STATUS error_code;
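Init() now distinguishes four cases instead of three. The free-standing sketch below restates the selection logic for readability; it is not the actual class method:

// Condensed, illustrative restatement of the kernel selection above.
#include <string>
#include <vector>

std::string PickScaleKernel(const std::vector<int> &in_shape, const std::vector<int> &scale_shape,
                            int axis, int scale_elements) {
  if (scale_shape.size() == in_shape.size()) return "Scale";    // full element-wise scale
  if (scale_elements == 1) return "BoardcastScale";             // single scalar broadcast
  if (scale_shape.size() == 1) {
    if ((in_shape.size() == 4 && axis == 3) || (in_shape.size() == 2 && axis == 1)) return "Scale_C";
    if (in_shape.size() == 4 && axis == 1) return "Scale_H";    // newly added per-H broadcast
  }
  return "";  // unsupported combination; Init() reports an error and bails out
}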
@@ -206,7 +211,7 @@ int ScaleOpenCLKernel::Run() {
  int arg_idx = 0;
  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());
  if (element_flag_) {
  if (weight_vector_flag_) {
    void *scale = scale_ptr_ == nullptr ? in_tensors_[1]->data_c() : scale_ptr_;
    void *offset = offset_ptr_ == nullptr ? in_tensors_[2]->data_c() : offset_ptr_;
    ocl_runtime_->SetKernelArg(kernel_, arg_idx++, scale);
@@ -230,8 +235,12 @@ int ScaleOpenCLKernel::Run() {
  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
  cl_int2 output_shape{static_cast<int>(global_size_[0]), static_cast<int>(global_size_[1])};
  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
  if (element_flag_ && scale_C_flag_) {
    ocl_runtime_->SetKernelArg(kernel_, arg_idx++, UP_DIV(in_tensors_[1]->shape()[0], C4NUM));
  if (weight_vector_flag_ && broadcast_flag_) {
    if (broadcast_H_flag_) {
      ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[1]->shape()[0]);
    } else {
      ocl_runtime_->SetKernelArg(kernel_, arg_idx++, UP_DIV(in_tensors_[1]->shape()[0], C4NUM));
    }
  }
  ocl_runtime_->SetKernelArg(kernel_, arg_idx++, act_type);
  ocl_runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr);
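Note the asymmetry in the extra kernel argument: Scale_C keeps receiving the packed texel count UP_DIV(len, C4NUM), while the new Scale_H receives the raw H because the kernel must first recover h = Y % H before doing its own / C4NUM and % C4NUM. A standalone demo of that index math (with made-up N and H values):

// Standalone demo of the Scale_H index math; values are hypothetical.
#include <cstdio>

int main() {
  const int kC4Num = 4;
  const int N = 2, H = 6;            // hypothetical batch and height
  for (int Y = 0; Y < N * H; ++Y) {  // Y walks the N*H rows of the packed image
    const int h = Y % H;             // row inside the current batch element
    std::printf("Y=%d -> h=%d, scale texel=%d, lane=%d\n", Y, h, h / kC4Num, h % kC4Num);
  }
  return 0;
}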

@@ -38,8 +38,9 @@ class ScaleOpenCLKernel : public OpenCLKernel {
  void Image2dGetWorkGroupSize();
  cl::Kernel kernel_;
  bool element_flag_{true};
  bool scale_C_flag_{false};
  bool weight_vector_flag_{true};
  bool broadcast_flag_{false};
  bool broadcast_H_flag_{false};
  void *scale_ptr_{nullptr};
  void *offset_ptr_{nullptr};
  int axis_{0};
