|
|
|
@ -50,6 +50,12 @@ std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const {
|
|
|
|
|
|
|
|
|
|
void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() {
|
|
|
|
|
local_size_ = {16, 16};
|
|
|
|
|
if (out_tensors_[0]->shape().size() == 2) {
|
|
|
|
|
size_t H = out_tensors_[0]->shape()[0];
|
|
|
|
|
size_t W = UP_DIV(out_tensors_[0]->shape()[1], C4NUM);
|
|
|
|
|
global_size_ = {W, H};
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
|
|
|
|
|
size_t H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
|
|
|
|
size_t W = out_tensors_[0]->Width();
|
|
|
|
@ -74,6 +80,10 @@ void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() {
|
|
|
|
|
|
|
|
|
|
int ArithmeticOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
|
|
|
|
size_t im_dst_x, im_dst_y;
|
|
|
|
|
if (out_tensors_[0]->shape().size() == 2) {
|
|
|
|
|
im_dst_x = UP_DIV(out_tensors_[0]->shape()[1], C4NUM);
|
|
|
|
|
im_dst_y = out_tensors_[0]->shape()[0];
|
|
|
|
|
} else {
|
|
|
|
|
if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
|
|
|
|
|
im_dst_x = out_tensors_[0]->Width();
|
|
|
|
|
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
|
|
|
@ -87,6 +97,7 @@ int ArithmeticOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_si
|
|
|
|
|
MS_LOG(ERROR) << "Unsupport data format " << out_tensors_[0]->GetFormat();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t img_dtype = CL_FLOAT;
|
|
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeFloat16) {
|
|
|
|
@ -335,22 +346,7 @@ int ArithmeticOpenCLKernel::Run() {
|
|
|
|
|
}
|
|
|
|
|
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
|
|
|
|
|
|
|
|
|
|
int H = 0;
|
|
|
|
|
int W = 0;
|
|
|
|
|
if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
|
|
|
|
|
H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
|
|
|
|
W = out_tensors_[0]->Width();
|
|
|
|
|
} else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
|
|
|
|
|
H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
|
|
|
|
|
W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
|
|
|
|
} else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
|
|
|
|
|
H = out_tensors_[0]->Batch();
|
|
|
|
|
W = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Error output type " << out_tensors_[0]->GetFormat();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
cl_int2 output_shape{W, H};
|
|
|
|
|
cl_int2 output_shape{static_cast<int>(global_size_[0]), static_cast<int>(global_size_[1])};
|
|
|
|
|
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
|
|
|
|
|
ocl_runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr);
|
|
|
|
|
return RET_OK;
|
|
|
|
|