!7813 [MS][LITE][Develop] add new ops for GPU named hswish

Merge pull request !7813 from pengyongrong/stack
pull/7813/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e7a6ae63bc

@ -11,13 +11,14 @@ __kernel void FullConnection_NHWC4(__read_only image2d_t input, __global FLT16 *
int lidy = get_local_id(1);
int ci4 = UP_DIV(in_shape.w, C4NUM);
int hwci4 = ci4 * in_shape.y * in_shape.z;
int wci4 = ci4 * in_shape.z;
int co4 = UP_DIV(out_shape.y, C4NUM);
int n = out_shape.x;
if (gidx >= co4 || gidz >= n) return;
bool inside = gidx < co4 && gidz < n;
FLT4 result = (FLT4)(0.0f);
for (uint i = lidy; i < hwci4; i += 4) {
int index_h = i / (ci4 * in_shape.z);
int index_wci4 = i % (ci4 * in_shape.z);
for (uint i = lidy; i < hwci4 && inside; i += 4) {
int index_h = i / wci4;
int index_wci4 = i % wci4;
FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_wci4, gidz * in_shape.y + index_h));
FLT16 w = weight[i * co4 + gidx];
result.x += dot(v, w.s0123);
@ -25,13 +26,13 @@ __kernel void FullConnection_NHWC4(__read_only image2d_t input, __global FLT16 *
result.z += dot(v, w.s89ab);
result.w += dot(v, w.scdef);
}
__local FLT4 temp[4];
temp[lidy] = result;
__local FLT4 temp[32][4];
temp[lidx][lidy] = result;
barrier(CLK_LOCAL_MEM_FENCE);
if (lidy == 0) {
result += temp[1];
result += temp[2];
result += temp[3];
if (lidy == 0 && inside) {
result += temp[lidx][1];
result += temp[lidx][2];
result += temp[lidx][3];
result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(gidx, gidz), result);

@ -0,0 +1,19 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void hswish(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 tensor_shape) {
int X = get_global_id(0); // n*h n: default =1
int Y = get_global_id(1); // w
int Z = get_global_id(2); // c
if (X >= tensor_shape.x * tensor_shape.y || Y >= tensor_shape.z || Z >= tensor_shape.w || tensor_shape.y == 0) {
return;
}
int n = X / tensor_shape.y;
int h = X % tensor_shape.y;
FLT4 temp = READ_IMAGE(src_data, smp_none, (int2)((Y)*tensor_shape.w + Z, (n * tensor_shape.y + h)));
FLT4 result = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
result.x = temp.x <= -3 ? 0 : (temp.x >= 3 ? 1 : temp.x / 6 + 0.5f);
result.y = temp.y <= -3 ? 0 : (temp.y >= 3 ? 1 : temp.y / 6 + 0.5f);
result.z = temp.z <= -3 ? 0 : (temp.z >= 3 ? 1 : temp.z / 6 + 0.5f);
result.w = temp.w <= -3 ? 0 : (temp.w >= 3 ? 1 : temp.w / 6 + 0.5f);
WRITE_IMAGE(dst_data, (int2)((Y)*tensor_shape.w + Z, (n * tensor_shape.y + h)), result);
}

@ -171,7 +171,7 @@ void FullConnectionOpenCLKernel::PadWeight() {
int FullConnectionOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
std::vector<size_t> local = {1, 4, 1};
std::vector<size_t> local = {32, 4, 1};
std::vector<size_t> global = {UP_DIV(outShape.C, C4NUM), 4, outShape.N};
int arg_count = 0;
cl_int4 in_shape = {static_cast<int>(inShape.N), static_cast<int>(inShape.H), static_cast<int>(inShape.W),

@ -0,0 +1,128 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/opencl/kernel/hswish.h"
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/cl/hswish.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Activation;
namespace mindspore::kernel {
int HswishOpenCLKernel::Init() {
if (out_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << " only support dim <= 4";
return RET_ERROR;
}
std::string kernel_name = "hswish";
std::set<std::string> build_options;
std::string source = hswish_source;
std::string program_name = "hswish";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK;
}
void HswishGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
const int max_divider = 8;
const int max_x = 2, max_y = 8;
int x = std::min(GetMaxDivisorStrategy1(global[0], max_divider), max_x);
int yz = max_size / x;
int y = std::min(std::min(GetMaxDivisorStrategy1(global[1], max_divider), yz), max_y);
int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
local->clear();
local->push_back(x);
local->push_back(y);
local->push_back(z);
}
int HswishOpenCLKernel::InferShapeTo4D() {
if (in_tensors_[0]->shape().size() <= 4) {
if (in_tensors_[0]->shape().size() == 1) {
N_ = in_tensors_[0]->shape()[0];
} else if (in_tensors_[0]->shape().size() == 2) {
N_ = in_tensors_[0]->shape()[0];
C_ = in_tensors_[0]->shape()[1];
} else if (in_tensors_[0]->shape().size() == 3) {
N_ = in_tensors_[0]->shape()[0];
W_ = in_tensors_[0]->shape()[1];
C_ = in_tensors_[0]->shape()[2];
} else {
N_ = in_tensors_[0]->shape()[0];
H_ = in_tensors_[0]->shape()[1];
W_ = in_tensors_[0]->shape()[2];
C_ = in_tensors_[0]->shape()[3];
}
} else {
MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size();
return RET_ERROR;
}
return RET_OK;
}
int HswishOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto output_shape = out_tensors_[0]->shape();
InferShapeTo4D();
cl_int4 output_shape_ = {static_cast<cl_int>(N_), static_cast<cl_int>(H_), static_cast<cl_int>(W_),
static_cast<cl_int>(UP_DIV(C_, C4NUM))};
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
std::vector<size_t> local = {1, 1, 1};
uint32_t OH = N_ * H_;
uint32_t OW = W_;
uint32_t OC = UP_DIV(C_, C4NUM);
std::vector<size_t> global = {OH, OW, OC};
HswishGetWorkGroup(global, &local, max_global[0]);
int arg_cn = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}
kernel::LiteKernel *HswishOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) HswishOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << " new HswishOpenCLKernel failed ";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << " Init kernel failed, name: hswish ";
delete kernel;
return nullptr;
}
return kernel;
}
} // namespace mindspore::kernel

@ -0,0 +1,50 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_HSWISH_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_HSWISH_H_
#include <vector>
#include "mindspore/lite/nnacl/fp32/activation.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
namespace mindspore::kernel {
class HswishOpenCLKernel : public OpenCLKernel {
public:
HswishOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~HswishOpenCLKernel() override = default;
int Init() override;
int Run() override;
private:
int InferShapeTo4D();
cl::Kernel kernel_;
private:
size_t N_{1};
size_t H_{1};
size_t W_{1};
size_t C_{1};
};
} // namespace mindspore::kernel
#endif

@ -48,22 +48,6 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::Tensor *> &in_te
}
}
for (size_t i = 0; i < in_tensors.size(); ++i) {
if (in_tensors.at(i)->shape().size() <= 1) {
if (mem_type == OpenCLMemType::IMG) {
for (auto &iv : in_kernels[i]) {
auto tensors = iv->in_tensors();
tensors.emplace_back(in_tensors.at(i));
iv->set_in_tensors(tensors);
}
} else {
for (auto &iv : in_kernels[i]) {
auto tensors = iv->out_tensors();
tensors.emplace_back(in_tensors.at(i));
iv->set_out_tensors(tensors);
}
}
continue;
}
auto dst_format = (mem_type == OpenCLMemType::IMG) ? schema::Format::Format_NHWC4 : schema::Format::Format_NHWC;
auto src_format = (mem_type == OpenCLMemType::IMG) ? schema::Format::Format_NHWC : schema::Format::Format_NHWC4;
auto *new_tensor = new (std::nothrow) lite::Tensor();

@ -16,6 +16,7 @@ mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
mtk_276landmark_0913.tflite
mtk_face_recognition.tflite
mtk_convert_model.tflite
mtk_model_face_dress_fp16.tflite
detection_retinaface_fix
landmark
PoseNet_dla_17_x512

@ -130,3 +130,4 @@ mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
mtk_276landmark_0913.tflite
mtk_face_recognition.tflite
mtk_convert_model.tflite
mtk_model_face_dress_fp16.tflite

@ -47,3 +47,4 @@ cp -fr $TEST_DATA_DIR/testPK ./data
./lite-test --gtest_filter="TestScaleOpenCL*"
./lite-test --gtest_filter="TestFullConnectionOpenCL*"
./lite-test --gtest_filter="TestResizeOpenCL*"
./lite-test --gtest_filter="TestSwishOpenCLCI.Fp32CI"

@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/hswish.h"
using mindspore::lite::Tensor;
using mindspore::schema::Format::Format_NHWC;
namespace mindspore {
class TestSwishOpenCLCI : public mindspore::CommonTest {
public:
TestSwishOpenCLCI() {}
};
TEST_F(TestSwishOpenCLCI, Fp32CI) {
MS_LOG(INFO) << " begin test ";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->Init();
auto allocator = runtime->GetAllocator();
MS_LOG(INFO) << " init tensors ";
std::vector<int> input_shape = {2, 10, 1, 4};
std::vector<int> output_shape = {2, 10, 1, 4};
auto data_type = kNumberTypeFloat32;
auto tensor_type = lite::Tensor::CONST_TENSOR;
float input_data[] = {2.5f, 6.0f, -7.4f, -3.5f, 5.9f, 6.5f, -8.0f, 7.4f, 5.9f, 6.5f, -8.0f, 7.4f, 7.5f, 6.0f,
-7.4f, -3.5f, 7.5f, 6.0f, -7.4f, -3.5f, 5.9f, 6.5f, -8.0f, 7.4f, 5.9f, 6.5f, -8.0f, 7.4f,
7.5f, 6.0f, -7.4f, -3.5f, 7.5f, 6.0f, -7.4f, -3.5f, 5.9f, 6.5f, -8.0f, 7.4f, 5.9f, 6.5f,
-8.0f, 7.4f, 7.5f, 6.0f, -7.4f, -3.5f, 7.5f, 6.0f, -7.4f, -3.5f, 5.9f, 6.5f, -8.0f, 7.4f,
5.9f, 6.5f, -8.0f, 7.4f, 7.5f, 6.0f, -7.4f, -3.5f, 7.5f, 6.0f, -7.4f, -3.5f, 5.9f, 6.5f,
-8.0f, 7.4f, 5.9f, 6.5f, -8.0f, 7.4f, 7.5f, 6.0f, -7.4f, -3.5f};
float correctOutput[] = {0.9167f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f,
0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f,
1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f,
0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f,
1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f};
auto output_tensor = Tensor(data_type, input_shape, Format_NHWC, tensor_type);
auto in_tensor = Tensor(data_type, output_shape, Format_NHWC, tensor_type);
std::vector<lite::Tensor *> inputs{&in_tensor};
std::vector<lite::Tensor *> outputs{&output_tensor};
MS_LOG(INFO) << " initialize tensors ";
auto param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new ActivationParameter failed ";
return;
}
auto *hswish_kernel =
new (std::nothrow) kernel::HswishOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (hswish_kernel == nullptr) {
MS_LOG(INFO) << " new kernel::HswishOpenCLKernel failed ";
delete param;
return;
}
hswish_kernel->Init();
// to do allocate memory for inputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{hswish_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
delete param;
delete hswish_kernel;
return;
}
sub_graph->Init();
MS_LOG(INFO) << " initialize input data ";
memcpy(inputs[0]->data_c(), input_data, sizeof(input_data));
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c());
CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001);
delete sub_graph;
}
} // namespace mindspore
Loading…
Cancel
Save