fix activation handling bug in arithmetic OpenCL kernels (add ElementAdd+ReLU support)

pull/7056/head
chenzupeng 4 years ago
parent 4f754daccf
commit 5d82670532

@ -15,6 +15,19 @@ __kernel void ElementAdd_IMG(__read_only image2d_t input_a, __read_only image2d_
WRITE_IMAGE(output, (int2)(X, Y), a + b);
}
// Element-wise add of two images followed by a ReLU (clamp negatives to 0).
// Each work-item handles one output pixel; out-of-range items do nothing.
__kernel void ElementAddReLU_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int2 output_shape) {
  const int col = get_global_id(0);
  const int row = get_global_id(1);
  // Guard against work-items launched past the output extent.
  if (col < output_shape.x && row < output_shape.y) {
    const int2 pos = (int2)(col, row);
    FLT4 lhs = READ_IMAGE(input_a, smp_none, pos);
    FLT4 rhs = READ_IMAGE(input_b, smp_none, pos);
    FLT4 sum = lhs + rhs;
    WRITE_IMAGE(output, pos, max(sum, (FLT4)(0.f)));
  }
}
__kernel void ElementSub_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int2 output_shape) {
int X = get_global_id(0);

@ -262,6 +262,22 @@ int ArithmeticOpenCLKernel::Init() {
return RET_ERROR;
}
switch (arithmetic_parameter->activation_type_) {
case schema::ActivationType_NO_ACTIVATION:
break;
case schema::ActivationType_RELU:
if (op_parameter_->type_ == PrimitiveType_Add && element_flag_) {
kernel_name += "ReLU";
} else {
MS_LOG(ERROR) << "Only support ElementAdd + ReLU";
return RET_ERROR;
}
break;
default:
MS_LOG(ERROR) << "Error activation type " << arithmetic_parameter->activation_type_;
return RET_ERROR;
}
lite::STATUS error_code = RET_OK;
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);

@ -115,8 +115,8 @@ int ScaleOpenCLKernel::InitBuffer() {
std::vector<size_t> img_size;
GetImageSize(0, &img_size);
if (in_tensors_[1]->shape().size() == 1 && axis_ == 3) {
img_size[0] = 1;
img_size[1] = UP_DIV(in_tensors_[1]->shape()[0], C4NUM);
img_size[1] = 1;
img_size[0] = UP_DIV(in_tensors_[1]->shape()[0], C4NUM);
scale_ptr_ = allocator->CreateImageFromHost(in_tensors_[1]->data_c(), in_tensors_[1]->ElementsNum(), img_size);
offset_ptr_ = allocator->CreateImageFromHost(in_tensors_[2]->data_c(), in_tensors_[2]->ElementsNum(), img_size);
return RET_OK;

Loading…
Cancel
Save