!11438 【MS】【LITE】【GPU】optimize malloc api

From: @wangdongxu6
Reviewed-by: 
Signed-off-by:
pull/11438/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit da92f1affb

@ -63,7 +63,7 @@ __kernel void DepthToSpace(__read_only image2d_t src_data, __write_only image2d_
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H * N
if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
if (out_shape.y == 0 || co_size == 0) return;
if (out_shape.y == 0 || block_size == 0) return;
int N = Z / out_shape.y;
int H = Z % out_shape.y;
int co_base = X * C4NUM;

@ -45,7 +45,7 @@ class ActivationOpenCLKernel : public OpenCLKernel {
static std::string GetActTypeString(int act_type);
int type_;
float alpha_;
GpuTensorInfo outShape = GpuTensorInfo(nullptr);
GpuTensorInfo outShape;
};
} // namespace mindspore::kernel

@ -30,6 +30,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::lite::opencl::MemType;
using mindspore::schema::ActivationType_NO_ACTIVATION;
using mindspore::schema::ActivationType_RELU;
@ -45,7 +46,7 @@ int ArithmeticOpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
if (param->broadcasting_ && out_tensors_[0]->shape()[0] > 1) {
if (param->broadcasting_ && out_tensors_.front()->DimensionSize(0) > 1) {
MS_LOG(ERROR) << "Broadcasting don't support N > 1";
return RET_ERROR;
}
@ -63,85 +64,29 @@ int ArithmeticOpenCLKernel::CheckSpecs() {
void ArithmeticOpenCLKernel::SetGlobalLocal() {
if (element_flag_) {
local_size_ = {};
auto out_shape = out_tensors_[0]->shape();
if (out_shape.size() == 2) {
size_t H = out_shape[0];
size_t W = UP_DIV(out_shape[1], C4NUM);
global_size_ = {W, H};
} else {
size_t H = out_shape[0] * out_shape[1];
size_t W = out_shape[2] * UP_DIV(out_shape[3], C4NUM);
global_size_ = {W, H};
}
global_size_ = {out_shape_.width, out_shape_.height};
} else {
local_size_ = {};
auto out_shape = GetNHWCShape(out_tensors_[0]->shape());
global_size_ = {static_cast<size_t>(UP_DIV(out_shape[3], C4NUM)), static_cast<size_t>(out_shape[2]),
static_cast<size_t>(out_shape[1] * out_shape[0])};
global_size_ = {out_shape_.Slice, out_shape_.W, out_shape_.H * out_shape_.N};
}
AlignGlobalLocal(global_size_, local_size_);
AlignGlobalLocal(global_size_, {});
}
int ArithmeticOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
auto fp16_enable = ocl_runtime_->GetFp16Enable();
auto data_size = fp16_enable ? sizeof(float16_t) : sizeof(float);
for (auto in_tensor_ : in_tensors_) {
auto nhwc_shape = GetNHWCShape(in_tensor_->shape());
inputs_nhwc_shapes_.push_back(nhwc_shape);
if (!in_tensor_->IsConst()) {
inputs_weight_ptrs_.push_back(nullptr);
for (int i = 0; i < 2; ++i) {
const auto &in_tensor = in_tensors_.at(i);
GpuTensorInfo *in_shape = (i == 0) ? &in0_shape_ : &in1_shape_;
if (in_tensor->IsConst()) {
std::vector<char> weight(in_shape->Image2DSize, 0);
bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, *in_shape);
size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{in_shape->width, in_shape->height, dtype};
auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
weight_ptrs_.push_back(weight_ptr_);
} else {
auto allocator = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size = GetImage2dShapeFromNHWC(nhwc_shape, schema::Format_NHWC4);
int pack_weight_size = img_size[0] * img_size[1] * C4NUM;
int plane = nhwc_shape[1] * nhwc_shape[2];
int channel = nhwc_shape[3];
int batch = nhwc_shape[0];
img_size.push_back(fp16_enable ? CL_HALF_FLOAT : CL_FLOAT);
if (!fp16_enable) {
float *weight = new (std::nothrow) float[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNHWCToNHWC4<float, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
PackNHWCToNHWC4<float16_t, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
inputs_weight_ptrs_.push_back(weight_ptr_);
delete[] weight;
} else {
float16_t *weight = new (std::nothrow) float16_t[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNHWCToNHWC4<float, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float16_t(float16_t)> to_dtype = [](float16_t x) -> float16_t { return x; };
PackNHWCToNHWC4<float16_t, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
inputs_weight_ptrs_.push_back(weight_ptr_);
delete[] weight;
}
weight_ptrs_.push_back(nullptr);
}
}
return RET_OK;
@ -150,21 +95,21 @@ int ArithmeticOpenCLKernel::InitWeights() {
void ArithmeticOpenCLKernel::SetConstArgs() {
int arg_idx = 3;
if (!element_flag_) {
cl_int4 input0_shape = {inputs_nhwc_shapes_[0][0], inputs_nhwc_shapes_[0][1], inputs_nhwc_shapes_[0][2],
UP_DIV(inputs_nhwc_shapes_[0][3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input0_shape);
cl_int4 input1_shape = {inputs_nhwc_shapes_[1][0], inputs_nhwc_shapes_[1][1], inputs_nhwc_shapes_[1][2],
UP_DIV(inputs_nhwc_shapes_[1][3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input1_shape);
auto out_shape = GetNHWCShape(out_tensors_[0]->shape());
cl_int4 output_shape{out_shape[0], out_shape[1], out_shape[2], UP_DIV(out_shape[3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
cl_int4 in0_shape = {static_cast<int>(in0_shape_.N), static_cast<int>(in0_shape_.H), static_cast<int>(in0_shape_.W),
static_cast<int>(in0_shape_.Slice)};
cl_int4 in1_shape = {static_cast<int>(in1_shape_.N), static_cast<int>(in1_shape_.H), static_cast<int>(in1_shape_.W),
static_cast<int>(in1_shape_.Slice)};
cl_int4 out_shape = {static_cast<int>(out_shape_.N), static_cast<int>(out_shape_.H), static_cast<int>(out_shape_.W),
static_cast<int>(out_shape_.Slice)};
int broadcastC_flag = 0; // do not need broadcast in C4
if (inputs_nhwc_shapes_[0][3] == 1 && inputs_nhwc_shapes_[1][3] != 1) {
if (in0_shape_.C == 1 && in1_shape_.C != 1) {
broadcastC_flag = 1; // BroadCast C4 in input0
} else if (inputs_nhwc_shapes_[0][3] != 1 && inputs_nhwc_shapes_[1][3] == 1) {
} else if (in0_shape_.C != 1 && in1_shape_.C == 1) {
broadcastC_flag = 2; // BroadCast C4 in input1
}
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in0_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in1_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, broadcastC_flag);
} else {
cl_int2 output_shape{static_cast<int>(global_range_[0]), static_cast<int>(global_range_[1])};
@ -175,11 +120,14 @@ void ArithmeticOpenCLKernel::SetConstArgs() {
}
int ArithmeticOpenCLKernel::Prepare() {
lite::STATUS error_code = RET_OK;
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name_);
#else
in0_shape_ = GpuTensorInfo(in_tensors_[0]);
in1_shape_ = GpuTensorInfo(in_tensors_[1]);
out_shape_ = GpuTensorInfo(out_tensors_[0]);
auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
if (Type() == PrimitiveType_BiasAdd) {
const_cast<ArithmeticParameter *>(param)->broadcasting_ = true;
@ -197,7 +145,7 @@ int ArithmeticOpenCLKernel::Prepare() {
std::string program_name = "Arithmetic";
std::string source = arithmetic_source;
ocl_runtime_->LoadSource(program_name, source);
error_code = ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name_);
int error_code = ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name_);
#endif
if (error_code != RET_OK) {
return error_code;
@ -212,11 +160,10 @@ int ArithmeticOpenCLKernel::Prepare() {
int ArithmeticOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
auto input_0_ptr = weight_ptrs_[0] == nullptr ? in_tensors_[0]->data_c() : weight_ptrs_[0];
auto input_1_ptr = weight_ptrs_[1] == nullptr ? in_tensors_[1]->data_c() : weight_ptrs_[1];
int arg_idx = 0;
auto input_0_ptr = inputs_weight_ptrs_[0] == nullptr ? in_tensors_[0]->data_c() : inputs_weight_ptrs_[0];
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_0_ptr);
auto input_1_ptr = inputs_weight_ptrs_[1] == nullptr ? in_tensors_[1]->data_c() : inputs_weight_ptrs_[1];
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_1_ptr);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);

@ -43,8 +43,10 @@ class ArithmeticOpenCLKernel : public OpenCLKernel {
bool element_flag_{true};
float activation_min_{-FLT_MAX};
float activation_max_{FLT_MAX};
std::vector<std::vector<int>> inputs_nhwc_shapes_;
std::vector<void *> inputs_weight_ptrs_;
GpuTensorInfo in0_shape_;
GpuTensorInfo in1_shape_;
GpuTensorInfo out_shape_;
std::vector<void *> weight_ptrs_;
std::string kernel_name_;
};
} // namespace mindspore::kernel

@ -18,7 +18,6 @@
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/cl/concat.cl.inc"
@ -27,22 +26,23 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Concat;
namespace mindspore::kernel {
int ConcatOpenCLKernel::RunAxis0() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
ImageSize img_size;
auto dst_data = out_tensors_[0]->data_c();
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
auto *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
for (int i = 0; i < in_tensors_.size(); i++) {
auto src_data = inputs_weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : inputs_weight_ptrs_.at(i);
auto src_data = weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : weight_ptrs_.at(i);
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
cl::Image2D *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
auto *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region);
dst_origin[1] += region[1];
}
@ -75,8 +75,8 @@ int ConcatOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << " GPU Unsupported shape.size > 4 ";
return RET_ERROR;
}
for (int i = 0; i < in_tensors_.size(); ++i) {
auto in_tensors_shape_size = in_tensors_[i]->shape().size();
for (auto &in_tensor : in_tensors_) {
auto in_tensors_shape_size = in_tensor->shape().size();
if (in_tensors_shape_size > 4) {
MS_LOG(ERROR) << " GPU Unsupported in_tensor shape.size > 4 ";
return RET_ERROR;
@ -109,7 +109,7 @@ int ConcatOpenCLKernel::CheckSpecs() {
void ConcatOpenCLKernel::SetConstArgs() {
GpuTensorInfo img_info(out_tensors_[0]);
size_t dtype = enable_fp16_ ? sizeof(cl_half) : sizeof(cl_float);
size_t dtype = ocl_runtime_->GetFp16Enable() ? sizeof(cl_half) : sizeof(cl_float);
stride_w = img_info.RowPitch() / dtype;
cl_int4 output_shape_ = {};
for (int i = 0; i < out_tensors_[0]->shape().size(); ++i) {
@ -118,22 +118,22 @@ void ConcatOpenCLKernel::SetConstArgs() {
Broadcast2GpuShape(out_shape_.s, output_shape_.s, out_tensors_[0]->shape().size(), 1);
int arg_cn = in_tensors_.size() + 1;
if (axis_ == 3 && !Align_) {
for (int i = 0; i < in_tensors_.size(); ++i) {
for (auto &in_tensor : in_tensors_) {
cl_int4 temp = {};
for (int j = 0; j < in_tensors_[i]->shape().size(); ++j) {
temp.s[j] = in_tensors_[i]->shape()[j];
for (int j = 0; j < in_tensor->shape().size(); ++j) {
temp.s[j] = in_tensor->shape()[j];
}
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensors_[i]->shape().size(), 1);
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensor->shape().size(), 1);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_shape_);
}
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, stride_w);
} else {
for (int i = 0; i < in_tensors_.size(); ++i) {
for (auto &in_tensor : in_tensors_) {
cl_int4 temp = {};
for (int j = 0; j < in_tensors_[i]->shape().size(); ++j) {
temp.s[j] = in_tensors_[i]->shape()[j];
for (int j = 0; j < in_tensor->shape().size(); ++j) {
temp.s[j] = in_tensor->shape()[j];
}
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensors_[i]->shape().size(), 1);
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensor->shape().size(), 1);
in_shape_.s[3] = UP_DIV(in_shape_.s[3], C4NUM);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_shape_);
}
@ -160,84 +160,36 @@ void ConcatOpenCLKernel::SetGlobalLocal() {
OpenCLKernel::AlignGlobalLocal(global_size_, local_size_);
}
int ConcatOpenCLKernel::ConvertWeightToTensor(const std::vector<lite::Tensor *> &in_tensors,
std::vector<void *> *inputs_weight_ptrs, bool fp16_enable,
size_t data_size) {
for (auto in_tensor_ : in_tensors) {
auto nhwc_shape = GetNHWCShape(in_tensor_->shape());
if (!in_tensor_->IsConst()) {
(*inputs_weight_ptrs).push_back(nullptr);
int ConcatOpenCLKernel::ConvertWeightToTensor() {
auto allocator = ocl_runtime_->GetAllocator();
bool fp16_enable = ocl_runtime_->GetFp16Enable();
for (auto in_tensor : in_tensors_) {
auto in_shape = GpuTensorInfo(in_tensor);
if (in_tensor->IsConst()) {
std::vector<char> weight(in_shape.Image2DSize, 0);
bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, in_shape);
size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{in_shape.width, in_shape.height, dtype};
auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
weight_ptrs_.push_back(weight_ptr_);
} else {
auto allocator = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size = GetImage2dShapeFromNHWC(nhwc_shape, schema::Format_NHWC4);
int pack_weight_size = img_size[0] * img_size[1] * C4NUM;
int plane = nhwc_shape[1] * nhwc_shape[2];
int channel = nhwc_shape[3];
int batch = nhwc_shape[0];
img_size.push_back(fp16_enable ? CL_HALF_FLOAT : CL_FLOAT);
if (!fp16_enable) {
float *weight = new (std::nothrow) float[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNHWCToNHWC4<float, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
PackNHWCToNHWC4<float16_t, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
(*inputs_weight_ptrs).push_back(weight_ptr_);
delete[] weight;
} else {
float16_t *weight = new (std::nothrow) float16_t[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNHWCToNHWC4<float, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float16_t(float16_t)> to_dtype = [](float16_t x) -> float16_t { return x; };
PackNHWCToNHWC4<float16_t, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
(*inputs_weight_ptrs).push_back(weight_ptr_);
delete[] weight;
}
weight_ptrs_.push_back(nullptr);
}
}
return RET_OK;
}
int ConcatOpenCLKernel::Prepare() {
enable_fp16_ = ocl_runtime_->GetFp16Enable();
auto data_size = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
ConvertWeightToTensor(in_tensors_, &inputs_weight_ptrs_, enable_fp16_, data_size);
ConvertWeightToTensor();
if (axis_ == 0) {
for (int i = 0; i < in_tensors_.size(); ++i) {
if (in_tensors_.at(i)->shape().size() != 1) {
return RET_OK;
}
if (std::any_of(in_tensors_.begin(), in_tensors_.end(), [](lite::Tensor *t) { return t->shape().size() != 1; })) {
return RET_OK;
}
axis_ = 3;
}
for (int i = 0; i < in_tensors_.size(); ++i) {
int length = in_tensors_[0]->shape().size();
if (in_tensors_[i]->shape()[length - 1] % C4NUM != 0) {
for (auto const &in_tensor : in_tensors_) {
if (in_tensor->shape().back() % C4NUM != 0) {
Align_ = false;
}
}
@ -268,7 +220,7 @@ int ConcatOpenCLKernel::Run() {
}
int arg_cn = 0;
for (int i = 0; i < in_tensors_.size(); ++i) {
auto input_ptr = inputs_weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : inputs_weight_ptrs_.at(i);
auto input_ptr = weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : weight_ptrs_.at(i);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_ptr);
}
if (axis_ == 3 && !Align_) {

@ -43,8 +43,7 @@ class ConcatOpenCLKernel : public OpenCLKernel {
uint32_t OC = {1};
std::vector<size_t> global;
bool Align_{true};
std::vector<void *> inputs_weight_ptrs_;
bool enable_fp16_{false};
std::vector<void *> weight_ptrs_;
cl_int stride_w{1};
cl_int4 in_shape_{};
cl_int4 out_shape_{};
@ -52,8 +51,7 @@ class ConcatOpenCLKernel : public OpenCLKernel {
private:
int RunAxis0();
int ConvertWeightToTensor(const std::vector<lite::Tensor *> &in_tensors, std::vector<void *> *inputs_weight_ptrs,
bool fp16_enable, size_t data_size);
int ConvertWeightToTensor();
};
} // namespace mindspore::kernel

@ -255,7 +255,7 @@ void Conv2DOpenCLKernel::InitFilter() {
size_t height = KH_ * KW_ * UP_ROUND(CI_, CI_TILE);
size_t dtype = use_fp16_ ? CL_HALF_FLOAT : CL_FLOAT;
size = width * height * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size, {width, height, dtype});
packed_filter_ = allocator->Malloc({width, height, dtype});
} else {
size = UP_DIV(CO_SLICES_, Ogroup) * KH_ * KW_ * CI_SLICES_ * Ogroup * CI_TILE * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size);

@ -28,6 +28,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::PrimitiveType_DeConv2D;
@ -193,8 +194,8 @@ int Conv2dTransposeOpenCLKernel::InitWeights() {
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size);
ImageSize img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(img_size);
bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
memset(bias_, 0x00, div_co * C4NUM * data_size);
if (in_tensors_.size() == 3) {

@ -37,6 +37,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::lite::opencl::MemType;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
@ -61,6 +62,7 @@ int DepthwiseConv2dOpenCLKernel::CheckSpecs() {
}
return RET_OK;
}
int DepthwiseConv2dOpenCLKernel::Prepare() {
std::string kernel_name = "DepthwiseConv2d";
if (out_mem_type_ == MemType::BUF) {
@ -114,13 +116,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
int plane_in = parameter->kernel_h_ * parameter->kernel_w_;
int plane_out = plane_in * C4NUM;
std::vector<size_t> img_size;
if (filter_type_ == MemType::IMG) {
int alignment = ocl_runtime_->GetImagePitchAlignment();
plane_out = UP_ROUND(plane_out, alignment) * C4NUM;
pack_weight_size = plane_out * CO4;
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
img_size = {(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype};
}
pack_weight_size = pack_weight_size * dtype_size;
auto ConvertFilter = [](void *src, void *dst, TypeId src_type, TypeId dst_type, size_t plane_in, size_t plane_out,
@ -153,7 +152,13 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
auto src_type = in_tensors_.at(kWeightIndex)->data_type();
auto dst_type = is_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32;
ConvertFilter(origin_weight, temp_filter.data(), src_type, dst_type, plane_in, plane_out, out_info.C);
packed_weight_ = allocator->Malloc(pack_weight_size, img_size, temp_filter.data());
if (filter_type_ == MemType::IMG) {
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype};
packed_weight_ = allocator->Malloc(img_size, temp_filter.data());
} else {
packed_weight_ = allocator->Malloc(pack_weight_size, temp_filter.data());
}
FreeDequantedWeight();
if (packed_weight_ == nullptr) {
return RET_ERROR;
@ -182,12 +187,13 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
auto element_size = in_tensors_.at(kBiasIndex)->ElementsNum();
ConvertBias(in_tensors_.at(kBiasIndex)->data_c(), temp_bias.data(), element_size, dtype_size, src_type, dst_type);
}
bias_data_ = allocator->Malloc(bias_size, {}, temp_bias.data());
bias_data_ = allocator->Malloc(bias_size, temp_bias.data());
if (bias_data_ == nullptr) {
return RET_ERROR;
}
return mindspore::lite::RET_OK;
}
void DepthwiseConv2dOpenCLKernel::SetConstArgs() {
auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
auto in_info = GpuTensorInfo(in_tensors_[0]);
@ -216,6 +222,7 @@ void DepthwiseConv2dOpenCLKernel::SetConstArgs() {
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, relu_clips[parameter->act_type_].first);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, relu_clips[parameter->act_type_].second);
}
void DepthwiseConv2dOpenCLKernel::SetGlobalLocal() {
auto out_info = GpuTensorInfo(out_tensors_[0]);
// set global

@ -26,6 +26,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Fill;
using mindspore::schema::PrimitiveType_Shape;
@ -35,13 +36,13 @@ int FillOpenCLKernel::RunFill() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto param = reinterpret_cast<FillParameter *>(this->op_parameter_);
default_ = param->num_dims_;
std::vector<size_t> img_size;
ImageSize img_size;
cl_float4 fill_value = {};
fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
auto src_data = out_tensors_[0]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK;

@ -29,6 +29,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::ActivationType_TANH;
@ -211,8 +212,8 @@ int FullConnectionOpenCLKernel::InitBias() {
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * dtype_size, img_size);
ImageSize img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(img_size);
bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
memset(bias_, 0x00, co4 * C4NUM * dtype_size);
if (in_tensors_.size() == 3) {

@ -27,6 +27,8 @@
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_MatMul;
namespace mindspore::kernel {
@ -55,7 +57,7 @@ int MatMulOpenCLKernel::CheckSpecs() {
transposeA = param->a_transpose_;
if (transposeA) {
MS_LOG(ERROR) << "matmul only support a_transpose_=false yet.";
return mindspore::lite::RET_ERROR;
return RET_ERROR;
}
transposeB = param->b_transpose_;
act_weight_ = !in_tensors_[1]->IsConst();
@ -63,7 +65,7 @@ int MatMulOpenCLKernel::CheckSpecs() {
if (in_tensors_[0]->shape().size() != out_tensors_[0]->shape().size() || in_tensors_[0]->shape().size() < 2 ||
in_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << "matmul only support input shape size= 2, 3 or 4.";
return mindspore::lite::RET_ERROR;
return RET_ERROR;
}
return RET_OK;
}
@ -100,7 +102,7 @@ int MatMulOpenCLKernel::Prepare() {
SetConstArgs();
SetGlobalLocal();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return mindspore::lite::RET_OK;
return RET_OK;
}
int MatMulOpenCLKernel::InitWeights() {
@ -207,7 +209,7 @@ int MatMulOpenCLKernel::Run() {
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[1]->data_c());
}
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return mindspore::lite::RET_OK;
return RET_OK;
}
kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::Tensor *> &inputs,
@ -244,7 +246,7 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::Tensor *>
return kernel;
}
auto ret = kernel->CheckSpecs();
if (ret != mindspore::lite::RET_OK) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!";
delete kernel;
return nullptr;

@ -22,7 +22,6 @@
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/common/utils.h"
#include "nnacl/matmul_parameter.h"
#define MAXDEPTH 5
namespace mindspore::kernel {

@ -41,8 +41,8 @@ class OneHotOpenCLKernel : public OpenCLKernel {
float on_value_{1.0f};
float off_value_{0.0f};
int axis_{0};
GpuTensorInfo in_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo out_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo in_shape_;
GpuTensorInfo out_shape_;
};
} // namespace mindspore::kernel

@ -39,7 +39,7 @@ class ReduceOpenCLKernel : public OpenCLKernel {
private:
cl_float4 GenC4Mask();
static std::string GetReduceTypeStr(int type);
GpuTensorInfo outShape = GpuTensorInfo(nullptr);
GpuTensorInfo outShape;
bool use_local_{false};
bool wc_reduce_{false};
static const size_t LOCAL_CACHE_THREAD{16};

@ -30,6 +30,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::lite::opencl::MemType;
using mindspore::schema::PrimitiveType_Scale;
@ -64,87 +65,55 @@ void ScaleOpenCLKernel::Image2dGetWorkGroupSize() {
}
int ScaleOpenCLKernel::InitWeights() {
if (!weight_vector_flag_) {
auto *in_tensor = in_tensors_[0];
auto *scale_tensor = in_tensors_[1];
auto *offset_tensor = in_tensors_[2];
auto scale_dtype = scale_tensor->data_type();
if (!weight_vector_flag_ || !scale_tensor->IsConst()) {
return RET_OK;
}
if (in_tensors_[1]->IsConst()) {
auto allocator = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
GetImageSize(0, &img_size);
img_size[2] = in_tensors_[1]->data_type() == kNumberTypeFloat16 ? CL_HALF_FLOAT : CL_FLOAT;
if (broadcast_flag_) {
img_size[1] = 1;
img_size[0] = UP_DIV(in_tensors_[1]->shape()[0], C4NUM);
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, in_tensors_[1]->data_c());
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, in_tensors_[2]->data_c());
return RET_OK;
auto allocator = ocl_runtime_->GetAllocator();
auto fp16_enable = ocl_runtime_->GetFp16Enable();
ImageSize img_size;
GetImageSize(0, &img_size);
img_size.dtype = scale_dtype == kNumberTypeFloat16 ? CL_HALF_FLOAT : CL_FLOAT;
if (broadcast_flag_) {
img_size.height = 1;
img_size.width = UP_DIV(scale_tensor->shape()[0], C4NUM);
scale_ptr_ = allocator->Malloc(img_size, scale_tensor->data_c());
offset_ptr_ = allocator->Malloc(img_size, offset_tensor->data_c());
return RET_OK;
}
if (in_tensor->format() == scale_tensor->format()) {
if (in_tensor->data_type() == scale_tensor->data_type()) {
scale_ptr_ = allocator->Malloc(img_size, scale_tensor->data_c());
offset_ptr_ = allocator->Malloc(img_size, offset_tensor->data_c());
} else {
MS_LOG(ERROR) << "Unsupported data type transpose from " << scale_tensor->data_type() << "to "
<< in_tensor->data_type();
return RET_ERROR;
}
auto image2d_info = GpuTensorInfo(in_tensors_[1]);
int pack_weight_size = image2d_info.ElementsC4Num;
int plane = image2d_info.H * image2d_info.W;
int channel = image2d_info.C;
int batch = image2d_info.N;
if (in_tensors_[0]->format() == in_tensors_[1]->format()) {
if (in_tensors_[0]->data_type() == in_tensors_[1]->data_type()) {
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, in_tensors_[1]->data_c());
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, in_tensors_[2]->data_c());
} else {
MS_LOG(ERROR) << "Unsupport data type transpose from " << in_tensors_[1]->data_type() << "to "
<< in_tensors_[0]->data_type();
return RET_ERROR;
}
} else if (in_tensors_[0]->format() == schema::Format_NHWC) {
if (in_tensors_[1]->format() == schema::Format_NHWC) {
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
auto *scale = new (std::nothrow) float[pack_weight_size];
if (scale == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
auto *offset = new (std::nothrow) float[pack_weight_size];
if (offset == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
delete[] scale;
return RET_ERROR;
}
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNHWCToNHWC4<float, float>(in_tensors_[1]->data_c(), scale, batch, plane, channel, to_dtype);
PackNHWCToNHWC4<float, float>(in_tensors_[2]->data_c(), offset, batch, plane, channel, to_dtype);
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, scale);
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, offset);
delete[] scale;
delete[] offset;
} else if (in_tensors_[0]->data_type() == kNumberTypeFloat16) {
auto *scale = new (std::nothrow) float16_t[pack_weight_size];
if (scale == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
auto *offset = new (std::nothrow) float16_t[pack_weight_size];
if (offset == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
delete[] scale;
return RET_ERROR;
}
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNHWCToNHWC4<float, float16_t>(in_tensors_[1]->data_c(), scale, batch, plane, channel, to_dtype);
PackNHWCToNHWC4<float, float16_t>(in_tensors_[2]->data_c(), offset, batch, plane, channel, to_dtype);
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, scale);
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, offset);
delete[] scale;
delete[] offset;
} else {
MS_LOG(ERROR) << "Unsupport data type transpose from " << in_tensors_[1]->data_type() << "to "
<< in_tensors_[0]->data_type();
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "Unsupport format transpose from " << in_tensors_[1]->format() << "to "
<< in_tensors_[0]->format();
return RET_ERROR;
}
} else if (in_tensor->format() == schema::Format_NHWC && scale_tensor->format() == schema::Format_NHWC) {
if (scale_dtype == kNumberTypeFloat32 || scale_dtype == kNumberTypeFloat16) {
auto image2d_info = GpuTensorInfo(scale_tensor);
int pack_weight_size = image2d_info.ElementsC4Num;
std::vector<char> scale(pack_weight_size, 0);
std::vector<char> offset(pack_weight_size, 0);
bool src_is_fp16 = scale_dtype == kNumberTypeFloat16;
PackNHWCToNHWC4(scale_tensor->data_c(), scale.data(), src_is_fp16, fp16_enable, image2d_info);
PackNHWCToNHWC4(offset_tensor->data_c(), offset.data(), src_is_fp16, fp16_enable, image2d_info);
scale_ptr_ = allocator->Malloc(img_size, scale.data());
offset_ptr_ = allocator->Malloc(img_size, offset.data());
} else {
MS_LOG(ERROR) << "Unsupported data type transpose from " << scale_tensor->data_type() << "to "
<< in_tensor->data_type();
return RET_ERROR;
}
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupported format transpose from " << scale_tensor->format() << "to " << in_tensor->format();
return RET_ERROR;
}
return RET_OK;
}
@ -231,7 +200,7 @@ int ScaleOpenCLKernel::Run() {
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, static_cast<float>(scale));
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, static_cast<float>(offset));
} else {
MS_LOG(ERROR) << "Unsupport data type " << in_tensors_[1]->data_type();
MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[1]->data_type();
return RET_ERROR;
}
}

@ -52,7 +52,7 @@ class SoftmaxOpenCLKernel : public OpenCLKernel {
std::vector<size_t> local_size_;
std::vector<size_t> global_size_;
int axis_{0};
GpuTensorInfo out_shape = GpuTensorInfo(nullptr);
GpuTensorInfo out_shape;
};
} // namespace mindspore::kernel

@ -36,8 +36,8 @@ class SpaceToDepthOpenCLKernel : public OpenCLKernel {
void SetGlobalLocal() override;
private:
GpuTensorInfo in_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo out_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo in_shape_;
GpuTensorInfo out_shape_;
};
} // namespace mindspore::kernel

@ -27,19 +27,20 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_SparseToDense;
namespace mindspore::kernel {
int SparseToDenseOpenCLKernel::InitOutputToDefault() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
ImageSize img_size;
cl_float4 fill_value = {};
fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
auto src_data = out_tensors_[0]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK;
@ -113,7 +114,7 @@ int SparseToDenseOpenCLKernel::CheckSpecs() {
}
auto param = reinterpret_cast<SparseToDenseParameter *>(op_parameter_);
if (param->validate_indices_) {
MS_LOG(ERROR) << "Unspported unordered for in_tensors_indices";
MS_LOG(ERROR) << "Unsupported unordered for in_tensors_indices";
return RET_ERROR;
}
return RET_OK;

@ -26,12 +26,13 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Split;
namespace mindspore::kernel {
int SplitOpenCLKernel::RunAxis0() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
auto src_data = in_tensors_[0]->data_c();
cl::Image2D *in_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
if (in_image == nullptr) {
@ -41,9 +42,10 @@ int SplitOpenCLKernel::RunAxis0() {
auto src_area = cl::array<cl::size_type, 3U>{0, 0, 0};
for (int i = 0; i < out_tensors_.size(); i++) {
auto dst_data = out_tensors_[i]->data_c();
ImageSize img_size;
allocator_->GetImageSize(dst_data, &img_size);
auto dst_area = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
if (out_image == nullptr) {
MS_LOG(ERROR) << "RunAxis0 out_image can not be nullptr";

@ -25,13 +25,14 @@
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Stack;
namespace mindspore::kernel {
int StackOpenCLKernel::RunAxis0() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
ImageSize img_size;
auto dst_data = out_tensors_[0]->data_c();
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
@ -39,7 +40,7 @@ int StackOpenCLKernel::RunAxis0() {
auto src_data = in_tensors_[i]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region);
dst_origin[1] += region[1];

@ -20,19 +20,14 @@
#include "src/runtime/kernel/opencl/kernel/matmul.h"
#include "src/runtime/kernel/opencl/kernel/strassen.h"
#include "src/common/utils.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/strassen.cl.inc"
#endif
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
namespace mindspore::kernel {
int StrassenOpenCLKernel::Prepare() {
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
std::string kernel_name = "MatMul_Strassen_NHWC4_2d";
std::string source = strassen_source;
std::string program_name = "MatMul";
@ -43,8 +38,6 @@ int StrassenOpenCLKernel::Prepare() {
ocl_runtime_->BuildKernel(kernel_back_result, program_name, "Strassen_Back_Result");
ocl_runtime_->BuildKernel(MatMul_StrassenBUFFilled, program_name, "MatMul_BUF_Filled");
ocl_runtime_->BuildKernel(MatMul_StrassenIMGFilled, program_name, "MatMul_IMG_Filled");
#endif
auto ret = InitWeights();
if (ret != RET_OK) {
return ret;
@ -52,31 +45,25 @@ int StrassenOpenCLKernel::Prepare() {
SetConstArgs();
SetGlobalLocal();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return mindspore::lite::RET_OK;
return RET_OK;
}
void StrassenOpenCLKernel::AllocatorMemoryForStrassen(int NumA, int NumB) {
std::vector<size_t> img_size;
img_size.push_back(UP_DIV(NumA, C4NUM));
img_size.push_back(NumA);
auto allocator = ocl_runtime_->GetAllocator();
size_t img_dtype = enable_fp16_ ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{static_cast<size_t>(UP_DIV(NumA, C4NUM)), static_cast<size_t>(NumA), img_dtype};
size_t dtype_size = enable_fp16_ ? sizeof(cl_half) : sizeof(cl_float);
img_size.push_back(img_dtype);
auto allocator = ocl_runtime_->GetAllocator();
size_t memA = NumA * NumA;
size_t memB = NumB * NumB * dtype_size;
for (int depth = 0; depth < MAXDEPTH; depth++) {
B_temp[depth] = allocator->Malloc(memB);
A_temp[depth] = allocator->Malloc(memA, img_size);
M1[depth] = allocator->Malloc(memA, img_size);
M2[depth] = allocator->Malloc(memA, img_size);
M3[depth] = allocator->Malloc(memA, img_size);
M4[depth] = allocator->Malloc(memA, img_size);
M5[depth] = allocator->Malloc(memA, img_size);
M6[depth] = allocator->Malloc(memA, img_size);
M7[depth] = allocator->Malloc(memA, img_size);
A_temp[depth] = allocator->Malloc(img_size);
M1[depth] = allocator->Malloc(img_size);
M2[depth] = allocator->Malloc(img_size);
M3[depth] = allocator->Malloc(img_size);
M4[depth] = allocator->Malloc(img_size);
M5[depth] = allocator->Malloc(img_size);
M6[depth] = allocator->Malloc(img_size);
M7[depth] = allocator->Malloc(img_size);
}
}
@ -333,6 +320,6 @@ int StrassenOpenCLKernel::Run() {
}
DoStrassen(in_tensors_.at(0)->data_c(), padWeight_, out_tensors_.at(0)->data_c(), in_tensors_.at(0)->shape()[0], 0,
threshold);
return mindspore::lite::RET_OK;
return RET_OK;
}
} // namespace mindspore::kernel

@ -21,6 +21,8 @@
#include <vector>
#include "src/runtime/kernel/opencl/kernel/matmul.h"
#define MAXDEPTH 5
namespace mindspore::kernel {
class StrassenOpenCLKernel : public MatMulOpenCLKernel {

@ -98,7 +98,7 @@ void WinogradOpenCLKernel::InitFilter() {
size_t height = CO_SLICES_;
size_t dtype = use_fp16_ ? CL_HALF_FLOAT : CL_FLOAT;
size = width * height * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size, {width, height, dtype});
packed_filter_ = allocator->Malloc({width, height, dtype});
} else {
size = UP_DIV(CO_SLICES_, Ogroup) * 6 * 6 * CI_SLICES_ * Ogroup * CI_TILE * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size);
@ -136,11 +136,11 @@ void WinogradOpenCLKernel::AllocateMemory() {
size_t width = TILE_HW_;
size_t height = CI_SLICES_ * 36;
winograd_mem0_ = allocator->Malloc(width * height * sizeof_FLT_, {width, height, img_dtype});
winograd_mem0_ = allocator->Malloc({width, height, img_dtype});
width = TILE_HW_;
height = CO_SLICES_ * 36;
winograd_mem1_ = allocator->Malloc(width * height * sizeof_FLT_, {width, height, img_dtype});
winograd_mem1_ = allocator->Malloc({width, height, img_dtype});
}
void WinogradOpenCLKernel::SetConstArgs() {

@ -19,6 +19,7 @@
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
namespace mindspore::kernel {
@ -60,7 +61,7 @@ int OpenCLKernel::AlignGlobalLocal(const std::vector<size_t> &global, const std:
return RET_OK;
}
int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) {
MS_ASSERT(img_size);
if (idx >= out_tensors_.size()) {
return RET_ERROR;
@ -133,13 +134,13 @@ int OpenCLKernel::PreProcess() {
auto *output = out_tensors_.at(i);
MS_ASSERT(output);
if (GetMemType() == lite::opencl::MemType::IMG) {
std::vector<size_t> img_size;
ImageSize img_size;
ret = GetImageSize(i, &img_size);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetImageSize failed";
return ret;
}
auto data_ptr = allocator->Malloc(output->Size(), img_size);
auto data_ptr = allocator->Malloc(img_size);
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
return RET_ERROR;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save