!8379 [MS][LITE][Develop] add new ops named sparse_to_dense , shape and fill for GPU

From: @pengyongrong
Reviewed-by: @ddwsky,@zhanghaibo5
Signed-off-by: @ddwsky
pull/8379/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 651a19589f

@ -0,0 +1,131 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#define C4NUM 4
__kernel void SparseToDenseScalarDim0(__read_only image2d_t input, __write_only image2d_t output, float weight,
int2 input_shape, float default_value) {
FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, 0));
FLT4 result = {default_value, default_value, default_value, default_value};
int integer = index_input.x / C4NUM;
int decimal = (int)(index_input.x) % C4NUM;
if (decimal == 0) {
result.x = weight;
} else if (decimal == 1) {
result.y = weight;
} else if (decimal == 2) {
result.z = weight;
} else {
result.w = weight;
}
WRITE_IMAGE(output, (int2)(0, integer), result);
return;
}
__kernel void SparseToDenseScalarDim1(__read_only image2d_t input, __write_only image2d_t output, float weight,
int2 input_shape, float default_value) {
for (int i = 0; i < input_shape.x; ++i) {
FLT4 result = READ_IMAGE(input, smp_zero, (int2)(0, i));
int Y = result.x;
result.x = weight;
WRITE_IMAGE(output, (int2)(0, Y), result);
}
}
__kernel void SparseToDenseVectorDim1(__read_only image2d_t input, __write_only image2d_t output,
__global float *weight, int2 input_shape, float default_value) {
int index_weight = 0;
for (int i = 0; i < input_shape.x; ++i) {
FLT4 result = READ_IMAGE(input, smp_zero, (int2)(0, i));
int Y = result.x;
result.x = weight[index_weight++];
WRITE_IMAGE(output, (int2)(0, Y), result);
}
}
__kernel void SparseToDenseScalarDim2Shape2(__read_only image2d_t input, __write_only image2d_t output, float weight,
int2 input_shape, float default_value) {
FLT temp[8] = {default_value, default_value, default_value, default_value,
default_value, default_value, default_value, default_value};
FLT result_temp[8] = {default_value, default_value, default_value, default_value,
default_value, default_value, default_value, default_value};
int index = 0; // 0~4
int X = 0;
FLT4 index_begin = READ_IMAGE(input, smp_zero, (int2)(0, 0));
int Y = (int)index_begin.x; // N
temp[index] = index_begin.y; // c/4
for (int i = 1; i < input_shape.x && index < C4NUM; ++i) {
FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, i));
if ((((int)temp[index]) / C4NUM == ((int)index_input.y) / C4NUM) && (Y == (int)index_input.x)) {
index++;
if (index < C4NUM) {
temp[index] = index_input.y;
}
} else {
for (int j = 0; j <= index && index < C4NUM; ++j) {
int decimal = (int)temp[j] % C4NUM;
result_temp[decimal] = weight;
X = ((int)temp[0]) / C4NUM;
}
FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]};
WRITE_IMAGE(output, (int2)(X, Y), result);
index = 0;
Y = (int)index_input.x;
temp[0] = index_input.y;
temp[1] = temp[2] = temp[3] = default_value;
result_temp[0] = result_temp[1] = result_temp[2] = result_temp[3] = default_value;
}
}
// judge the last element for input
X = ((int)temp[0]) / C4NUM;
for (int i = 0; i <= index && index < C4NUM; ++i) {
int decimal = (int)temp[i] % C4NUM;
result_temp[decimal] = weight;
}
FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]};
WRITE_IMAGE(output, (int2)(X, Y), result);
}
__kernel void SparseToDenseVectorDim2Shape2(__read_only image2d_t input, __write_only image2d_t output,
__global float *weight, int2 input_shape, float default_value) {
FLT temp[8] = {default_value, default_value, default_value, default_value,
default_value, default_value, default_value, default_value};
FLT result_temp[8] = {default_value, default_value, default_value, default_value,
default_value, default_value, default_value, default_value};
int index = 0; // 0~4
int weight_index = 0;
int X = 0;
FLT4 index_begin = READ_IMAGE(input, smp_zero, (int2)(0, 0));
int Y = (int)index_begin.x; // N
temp[index] = index_begin.y; // c/4
for (int i = 1; i < input_shape.x && index < C4NUM; ++i) {
FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, i));
if ((((int)temp[index]) / C4NUM == ((int)index_input.y) / C4NUM) && (Y == (int)index_input.x)) {
index++;
if (index < C4NUM) {
temp[index] = index_input.y;
}
} else {
for (int j = 0; j <= index && index < C4NUM; ++j) {
int decimal = (int)temp[j] % C4NUM;
result_temp[decimal] = weight[weight_index++];
X = ((int)temp[0]) / C4NUM;
}
FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]};
WRITE_IMAGE(output, (int2)(X, Y), result);
index = 0;
Y = (int)index_input.x;
temp[0] = index_input.y;
temp[1] = temp[2] = temp[3] = default_value;
result_temp[0] = result_temp[1] = result_temp[2] = result_temp[3] = default_value;
}
}
// judge the last element for input
X = ((int)temp[0]) / C4NUM;
for (int i = 0; i <= index && index < C4NUM; ++i) {
int decimal = (int)temp[i] % C4NUM;
result_temp[decimal] = weight[weight_index++];
}
FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]};
WRITE_IMAGE(output, (int2)(X, Y), result);
}

@ -0,0 +1,115 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/opencl/kernel/fill.h"
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Fill;
using mindspore::schema::PrimitiveType_Shape;
namespace mindspore::kernel {
int FillOpenCLKernel::RunFill() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto param = reinterpret_cast<FillParameter *>(this->op_parameter_);
default_ = param->num_dims_;
std::vector<size_t> img_size;
cl_float4 fill_value = {};
fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
auto src_data = out_tensors_[0]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK;
}
int FillOpenCLKernel::RunShape() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto src_data = out_tensors_[0]->data_c();
cl_float4 fill_value = {default_, default_, default_, default_};
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
fill_value.s[0] = in_tensors_[0]->shape()[i];
size_t index = static_cast<size_t>(i);
auto src_origin = cl::array<cl::size_type, 3U>{0, index, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
}
return RET_OK;
}
int FillOpenCLKernel::Init() {
auto param = this->op_parameter_;
if (out_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << " only support dim <= 4";
return RET_ERROR;
}
if (in_tensors_[0]->shape().size() > 1 && param->type_ == PrimitiveType_Fill) {
MS_LOG(ERROR) << " fill only support dim = 1";
return RET_ERROR;
}
return RET_OK;
}
int FillOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto param = this->op_parameter_;
if (param->type_ == PrimitiveType_Fill) {
RunFill();
} else {
RunShape();
}
return RET_OK;
}
kernel::LiteKernel *FillOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) FillOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << " new FillOpenCLKernel failed ";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << " Init kernel failed, name: fill ";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Fill, FillOpenCLKernelCreator);
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Shape, FillOpenCLKernelCreator);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Fill, FillOpenCLKernelCreator);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Shape, FillOpenCLKernelCreator);
} // namespace mindspore::kernel

@ -0,0 +1,49 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_
#include <vector>
#include "mindspore/lite/nnacl/fp32/fill.h"
#include "mindspore/lite/nnacl/shape.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
namespace mindspore::kernel {
class FillOpenCLKernel : public OpenCLKernel {
public:
FillOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~FillOpenCLKernel() override = default;
int Init() override;
int Run() override;
private:
int RunFill();
int RunShape();
cl::Kernel kernel_;
private:
float default_{0.0f};
};
} // namespace mindspore::kernel
#endif

@ -0,0 +1,203 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/opencl/kernel/sparse_to_dense.h"
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/cl/sparse_to_dense.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SparseToDense;
namespace mindspore::kernel {
int SparseToDenseOpenCLKernel::InitOutputToDefault() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
cl_float4 fill_value = {};
fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
auto src_data = out_tensors_[0]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK;
}
int SparseToDenseOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
auto weight_tensor = in_tensors_[2];
size_t size = 1;
for (int i = 0; i < weight_tensor->shape().size(); ++i) {
size *= weight_tensor->shape()[i];
}
if (weight_scalar_) {
if (weight_tensor->data_type() == kNumberTypeFloat16) {
weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c()));
} else {
weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data_c());
}
} else {
auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
size_t weight_size = UP_ROUND(size, C4NUM) * sizeof_FLT;
weight_vector_ = allocator->Malloc(weight_size);
allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true);
memset(weight_vector_, 0x00, weight_size);
if (weight_tensor->data_type() == kNumberTypeFloat16) {
if (enable_fp16_) {
memcpy(weight_vector_, weight_tensor->data_c(), size * sizeof_FLT);
} else {
auto weight_fp32 = reinterpret_cast<float *>(weight_vector_);
auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c());
for (int i = 0; i < size; ++i) {
weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]);
}
}
} else {
if (enable_fp16_) {
auto weight_fp16 = reinterpret_cast<float16_t *>(weight_vector_);
auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c());
for (int i = 0; i < size; ++i) {
weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]);
}
} else {
memcpy(weight_vector_, weight_tensor->data_c(), size * sizeof_FLT);
}
}
allocator->UnmapBuffer(weight_vector_);
}
return RET_OK;
}
int SparseToDenseOpenCLKernel::Init() {
if (out_tensors_[0]->shape().size() > 2 || in_tensors_.size() < 3) {
MS_LOG(ERROR) << " only support dim <= 2 and in_tensors_.size >= 3";
return RET_ERROR;
}
if ((in_tensors_[0]->shape()[1] > 3) && (input_dim_ == 2)) {
MS_LOG(ERROR) << "in_tensors_indices shape[1] must be 1 2 or 3 && input_dim_=2 ,but your shapes is: "
<< in_tensors_[0]->shape()[1] << "your input_dim_ is: " << input_dim_;
return ERROR;
}
input_dim_ = in_tensors_[0]->shape().size();
weight_scalar_ = in_tensors_[2]->IsScalar();
std::string kernel_name = "SparseToDense" + std::string(weight_scalar_ ? "ScalarDim" : "VectorDim") +
std::to_string(in_tensors_[0]->shape()[1] == 1 ? 1 : input_dim_);
if (input_dim_ == 2 && in_tensors_[0]->shape()[1] != 1) {
kernel_name += "Shape" + std::to_string(in_tensors_[0]->shape()[1]);
}
std::set<std::string> build_options;
std::string source = sparse_to_dense_source;
std::string program_name = "SparseToDense";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
if (in_tensors_.size() > 3) {
auto input_tensor3 = in_tensors_[3];
if (input_tensor3->data_type() == kNumberTypeFloat16) {
default_ = static_cast<float>(*reinterpret_cast<float16_t *>(input_tensor3->data_c()));
} else {
default_ = *reinterpret_cast<float *>(input_tensor3->data_c());
}
}
InitWeights();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK;
}
int SparseToDenseOpenCLKernel::InferShapeTo4D() {
if (in_tensors_[0]->shape().size() <= 4) {
if (in_tensors_[0]->shape().size() == 1) {
N_ = in_tensors_[0]->shape()[0];
} else if (in_tensors_[0]->shape().size() == 2) {
N_ = in_tensors_[0]->shape()[0];
C_ = in_tensors_[0]->shape()[1];
} else if (in_tensors_[0]->shape().size() == 3) {
N_ = in_tensors_[0]->shape()[0];
W_ = in_tensors_[0]->shape()[1];
C_ = in_tensors_[0]->shape()[2];
} else {
N_ = in_tensors_[0]->shape()[0];
H_ = in_tensors_[0]->shape()[1];
W_ = in_tensors_[0]->shape()[2];
C_ = in_tensors_[0]->shape()[3];
}
} else {
MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size();
return RET_ERROR;
}
return RET_OK;
}
int SparseToDenseOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
InferShapeTo4D();
cl_int2 input_shape = {static_cast<cl_int>(N_ * H_), static_cast<cl_int>(W_ * UP_DIV(C_, C4NUM))};
InitOutputToDefault();
std::vector<size_t> local = {1, 1};
std::vector<size_t> global = {1, 1};
int arg_cn = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
if (weight_scalar_) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_scalar_);
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_vector_);
}
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, default_);
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}
kernel::LiteKernel *SparseToDenseOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::InnerContext *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (inputs.empty()) {
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
free(opParameter);
return nullptr;
}
auto *kernel = new (std::nothrow) SparseToDenseOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << " new HswishOpenCLKernel failed ";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << " Init kernel failed, name: hswish ";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, SparseToDenseOpenCLKernelCreator);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SparseToDense, SparseToDenseOpenCLKernelCreator);
} // namespace mindspore::kernel

@ -0,0 +1,58 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPARSE_TO_DENSE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPARSE_TO_DENSE_H_
#include <vector>
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "mindspore/lite/nnacl/fp32/sparse_to_dense.h"
namespace mindspore::kernel {
class SparseToDenseOpenCLKernel : public OpenCLKernel {
public:
SparseToDenseOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~SparseToDenseOpenCLKernel() override = default;
int Init() override;
int Run() override;
int InitWeights() override;
private:
int InferShapeTo4D();
int InitOutputToDefault();
private:
cl::Kernel kernel_;
// bool IndicesIsScalar{false};
bool enable_fp16_{false};
float default_{0.0f};
float weight_scalar_{0.f};
void *weight_vector_{nullptr};
int input_dim_{1};
std::vector<int32_t> output_shape_;
size_t N_{1};
size_t H_{1};
size_t W_{1};
size_t C_{1};
};
} // namespace mindspore::kernel
#endif

@ -142,6 +142,7 @@ bool LoadLibraryFromPath(const std::string &library_path, void *handle) {
LOAD_OPENCL_FUNCTION_PTR(clRetainDevice);
LOAD_OPENCL_FUNCTION_PTR(clReleaseDevice);
LOAD_OPENCL_FUNCTION_PTR(clCreateImage);
LOAD_OPENCL_FUNCTION_PTR(clEnqueueFillImage);
#endif
#if CL_HPP_TARGET_OPENCL_VERSION >= 200
LOAD_OPENCL_FUNCTION_PTR(clCreateCommandQueueWithProperties);
@ -228,6 +229,7 @@ CL_DEFINE_FUNC_PTR(clEnqueueCopyImageToBuffer);
CL_DEFINE_FUNC_PTR(clRetainDevice);
CL_DEFINE_FUNC_PTR(clReleaseDevice);
CL_DEFINE_FUNC_PTR(clCreateImage);
CL_DEFINE_FUNC_PTR(clEnqueueFillImage);
#endif
#if CL_HPP_TARGET_OPENCL_VERSION >= 200
CL_DEFINE_FUNC_PTR(clGetKernelSubGroupInfoKHR);
@ -666,6 +668,14 @@ cl_mem clCreateImage(cl_context context, cl_mem_flags flags, const cl_image_form
return func(context, flags, image_format, image_desc, host_ptr, errcode_ret);
}
cl_int clEnqueueFillImage(cl_command_queue command_queue, cl_mem image, const void *fill_color, const size_t *origin,
const size_t *region, cl_uint num_events_in_wait_list, const cl_event *event_wait_list,
cl_event *event) {
auto func = mindspore::lite::opencl::clEnqueueFillImage;
MS_ASSERT(func != nullptr);
return func(command_queue, image, fill_color, origin, region, num_events_in_wait_list, event_wait_list, event);
}
#endif
#if CL_HPP_TARGET_OPENCL_VERSION >= 200

@ -127,6 +127,8 @@ using clRetainDeviceFunc = cl_int (*)(cl_device_id);
using clReleaseDeviceFunc = cl_int (*)(cl_device_id);
using clCreateImageFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, const cl_image_desc *, void *,
cl_int *);
using clEnqueueFillImageFunc = cl_int (*)(cl_command_queue, cl_mem, const void *, const size_t *, const size_t *,
cl_uint, const cl_event *, cl_event *);
#endif
#if CL_HPP_TARGET_OPENCL_VERSION >= 200
using clCreateProgramWithILFunc = cl_program (*)(cl_context, const void *, size_t, cl_int *);
@ -199,6 +201,7 @@ CL_DECLARE_FUNC_PTR(clEnqueueCopyImageToBuffer);
CL_DECLARE_FUNC_PTR(clRetainDevice);
CL_DECLARE_FUNC_PTR(clReleaseDevice);
CL_DECLARE_FUNC_PTR(clCreateImage);
CL_DECLARE_FUNC_PTR(clEnqueueFillImage);
#endif
#if CL_HPP_TARGET_OPENCL_VERSION >= 200
CL_DECLARE_FUNC_PTR(clGetKernelSubGroupInfoKHR);

@ -0,0 +1,145 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h"
using mindspore::lite::Tensor;
using mindspore::schema::PrimitiveType_Fill;
using mindspore::schema::PrimitiveType_Shape;
using mindspore::schema::Format::Format_NHWC;
namespace mindspore {
class TestFillOpenCLCI : public mindspore::CommonTest {
public:
TestFillOpenCLCI() {}
};
TEST_F(TestFillOpenCLCI, Fp32testfill) {
MS_LOG(INFO) << " begin test ";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->Init();
auto allocator = runtime->GetAllocator();
MS_LOG(INFO) << " init tensors ";
std::vector<int> input_shape1 = {2};
float input_data1[] = {3, 3};
float correctOutput[] = {9, 9, 9, 9, 9, 9, 9, 9, 9};
auto data_type = kNumberTypeFloat32;
std::vector<int> output_shape = {3, 3};
auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR);
auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR);
std::vector<lite::Tensor *> inputs{&in_tensor1};
std::vector<lite::Tensor *> outputs{&output_tensor};
MS_LOG(INFO) << " initialize tensors ";
auto param = reinterpret_cast<FillParameter *>(malloc(sizeof(FillParameter)));
param->num_dims_ = 9;
param->op_parameter_.type_ = PrimitiveType_Fill;
if (param == nullptr) {
MS_LOG(INFO) << " new FillParameter failed ";
return;
}
auto *fill_kernel =
new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (fill_kernel == nullptr) {
MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed ";
delete param;
return;
}
fill_kernel->Init();
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{fill_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
delete param;
delete fill_kernel;
return;
}
// to allocate memory for inputs
in_tensor1.MallocData(allocator);
sub_graph->Init();
MS_LOG(INFO) << " initialize input data ";
memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1));
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c());
CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001);
delete sub_graph;
}
TEST_F(TestFillOpenCLCI, Fp32testshape) {
MS_LOG(INFO) << " begin test ";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->Init();
auto allocator = runtime->GetAllocator();
MS_LOG(INFO) << " init tensors ";
std::vector<int> input_shape1 = {2, 4};
float input_data1[] = {-0.4045, -0.0924, -0.617, -0.10114, -0.9893, 0.3342, 2.445, -2.182};
float correctOutput[] = {2, 4};
auto data_type = kNumberTypeFloat32;
std::vector<int> output_shape = {2};
auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR);
auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR);
std::vector<lite::Tensor *> inputs{&in_tensor1};
std::vector<lite::Tensor *> outputs{&output_tensor};
MS_LOG(INFO) << " initialize tensors ";
auto param = reinterpret_cast<ShapeParameter *>(malloc(sizeof(ShapeParameter)));
param->op_parameter_.type_ = PrimitiveType_Shape;
if (param == nullptr) {
MS_LOG(INFO) << " new FillParameter failed ";
return;
}
auto *fill_kernel =
new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (fill_kernel == nullptr) {
MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed ";
delete param;
return;
}
fill_kernel->Init();
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{fill_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
delete param;
delete fill_kernel;
return;
}
// to allocate memory for inputs
in_tensor1.MallocData(allocator);
sub_graph->Init();
MS_LOG(INFO) << " initialize input data ";
memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1));
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c());
CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001);
delete sub_graph;
}
} // namespace mindspore
Loading…
Cancel
Save