!4337 [MS][LITE] opencl operator `softmax` support image2d

Merge pull request !4337 from chenzhongming/lite
pull/4337/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 55ef18a8d8

@@ -32,6 +32,10 @@ enum Format : int {
CKHW,
KHWC,
CHWK,
HW,
HW4,
NC,
NC4,
NC4HW4 = 100,
NUM_OF_FORMAT
}
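The four new entries extend the format list to 2D tensors: `HW` and `NC` are plain layouts, while the `4` suffix (`HW4`, `NC4`, as in the existing `NC4HW4`) marks a trailing dimension padded to a multiple of 4 so that one float4 vector or texel holds four values. A minimal sketch of the padding arithmetic, assuming the `UP_DIV(x, y) = (x + y - 1) / y` helper used elsewhere in this PR:

```cpp
#include <cstdio>

// Illustrative only: mirrors the UP_DIV macro used by the kernels below.
static int UpDiv(int x, int y) { return (x + y - 1) / y; }

int main() {
  // An NC tensor {1, 6} viewed as NC4: 6 channels become
  // UpDiv(6, 4) = 2 float4 slices, i.e. 8 padded channel lanes.
  int channels = 6;
  printf("slices=%d padded_channels=%d\n", UpDiv(channels, 4), UpDiv(channels, 4) * 4);
  return 0;
}
```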

@@ -104,7 +104,7 @@ int Executor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format d
allocator->Free(src_data);
return RET_OK;
} else {
-MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
+MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}
@@ -116,7 +116,7 @@ int Executor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format
MS_ASSERT(4 == tensor->shape().size());
// auto src_format = tensor->GetFormat();
// todo
-MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
+MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in uint8";
return RET_ERROR;
}

@@ -104,8 +104,8 @@ bool Tensor::operator==(const Value &other) const {
}
int32_t Tensor::Batch() const {
-if (this->shape_.size() != 4) {
-MS_LOG(ERROR) << "tensor should have 4 dim";
+if (this->shape_.size() != 4 && this->shape_.size() != 2) {
+MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {
@@ -115,6 +115,8 @@ int32_t Tensor::Batch() const {
case schema::Format_NC4HW4:
case schema::Format_KCHW:
case schema::Format_KHWC:
case schema::Format_NC:
case schema::Format_NC4:
return this->shape_[0];
case schema::Format_HWCK:
case schema::Format_CHWK:
@@ -124,19 +126,21 @@ int32_t Tensor::Batch() const {
case schema::Format_CKHW:
return this->shape_[1];
default:
-MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_);
+MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_);
return -1;
}
}
int32_t Tensor::Channel() const {
-if (this->shape_.size() != 4) {
-MS_LOG(ERROR) << "tensor should have 4 dim";
+if (this->shape_.size() != 4 && this->shape_.size() != 2) {
+MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {
case schema::Format_NCHW:
case schema::Format_KCHW:
case schema::Format_NC:
case schema::Format_NC4:
return this->shape_[1];
case schema::Format_HWCK:
return this->shape_[2];
@@ -155,8 +159,8 @@ int32_t Tensor::Channel() const {
}
int32_t Tensor::Height() const {
-if (this->shape_.size() != 4) {
-MS_LOG(ERROR) << "tensor should have 4 dim";
+if (this->shape_.size() != 4 && this->shape_.size() != 2) {
+MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {
@@ -172,16 +176,18 @@ int32_t Tensor::Height() const {
return this->shape_[1];
case schema::Format_HWCK:
case schema::Format_HWKC:
case schema::Format_HW:
case schema::Format_HW4:
return this->shape_[0];
default:
-MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_);
+MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_);
return -1;
}
}
int32_t Tensor::Width() const {
-if (this->shape_.size() != 4) {
-MS_LOG(ERROR) << "tensor should have 4 dim";
+if (this->shape_.size() != 4 && this->shape_.size() != 2) {
+MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {
@@ -197,12 +203,24 @@ int32_t Tensor::Width() const {
return this->shape_[2];
case schema::Format_HWCK:
case schema::Format_HWKC:
case schema::Format_HW:
case schema::Format_HW4:
return this->shape_[1];
default:
return -1;
}
}
int32_t Tensor::ElementsC4Num() const {
int32_t result = 0;
if (this->shape_.size() == 4) {
result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4);
} else if (this->shape_.size() == 2) {
result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4);
}
return result;
}
std::string Tensor::ToString() const {
std::ostringstream oss;
oss << "Format: " << schema::EnumNameFormat(this->format_);
@@ -235,7 +253,7 @@ std::string Tensor::ToString() const {
}
} break;
default:
-oss << "Unsupport data type to print";
+oss << "Unsupported data type to print";
break;
}
return oss.str();

@@ -66,7 +66,7 @@ class Tensor : public mindspore::tensor::MetaTensor {
int32_t Width() const;
-int32_t ElementsC4Num() const { return Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); }
+int32_t ElementsC4Num() const;
int DataSize() const { return this->ElementsNum(); }

@@ -37,7 +37,7 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
return RET_INPUT_TENSOR_ERROR;
}
if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) {
-MS_LOG(ERROR) << "Unsupport input data type " << input->data_type();
+MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
return RET_INPUT_TENSOR_ERROR;
}
if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) {

@@ -74,7 +74,7 @@ int CastCPUKernel::DoCast(int thread_id) {
Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset,
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
} else {
-MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type;
+MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
return RET_ERROR;
}
} else {
@@ -88,7 +88,7 @@ int CastCPUKernel::DoCast(int thread_id) {
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
-MS_LOG(ERROR) << "Unsupport input data type " << input_data_type;
+MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR;
}
}

@@ -1,21 +1,15 @@
-#define SLICES 4
-int DivideRoundUp(int n, int div) {
-int q = n / div;
-return n % div == 0 ? q : q + 1;
-}
-__kernel void SoftMax(__global float4 *input, __global float4 *output, const int4 input_shape) {
-int X = get_global_id(0); // width
-int Y = get_global_id(1); // height
-int H = input_shape.y;
-int W = input_shape.z;
-int C = input_shape.w;
+__kernel void SoftMax_BUF(__global float4 *input, __global float4 *output, const int4 input_shape) {
+int X = get_global_id(0);
+int Y = get_global_id(1);
+int H = input_shape.x;
+int W = input_shape.y;
+int C = input_shape.z;
+int S = input_shape.w;
if (X >= W || Y >= H) return;
float sum = 0.0f;
-for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) {
+for (int d = 0; d < S; ++d) {
float4 t = input[(Y * W + X * H) * C + d];
sum += exp(t.x);
if (d * 4 + 1 < C) sum += exp(t.y);
@@ -23,10 +17,34 @@ __kernel void SoftMax(__global float4 *input, __global float4 *output, const int
if (d * 4 + 3 < C) sum += exp(t.w);
}
-for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) {
+for (int d = 0; d < S; ++d) {
float4 t = input[(Y * W + X * H) * C + d];
t = exp(t) / sum;
float4 result = convert_float4(t);
output[(Y * W + X * H) * C + d] = result;
}
}
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
if (X >= input_shape.x || Y >= input_shape.y) return;
float sum = 0.0f;
for (int d = 0; d < input_shape.w; ++d) {
float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X));
sum += exp(t.x);
if (d * 4 + 1 < input_shape.z) sum += exp(t.y);
if (d * 4 + 2 < input_shape.z) sum += exp(t.z);
if (d * 4 + 3 < input_shape.z) sum += exp(t.w);
}
for (int d = 0; d < input_shape.w; ++d) {
float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X));
t = exp(t) / sum;
float4 result = convert_float4(t);
write_imagef(output, (int2)(Y * input_shape.w + d, X), result);
}
}
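In `SoftMax_IMG`, `input_shape` is `(H, W, C, S)` with `S = UP_DIV(C, 4)` (set by the host code later in this diff), `X = get_global_id(0)` walks H and `Y = get_global_id(1)` walks W, so the NHWC4 image stores one float4 channel slice per texel at `(w * S + d, h)`. A host-side sketch of that addressing, illustrative only:

```cpp
#include <cstdio>

// Texel coordinate read by SoftMax_IMG for spatial position (h, w) and
// channel slice d, given S = UP_DIV(C, 4) slices per pixel.
struct Int2 { int x, y; };
static Int2 Nhwc4Texel(int h, int w, int d, int S) { return {w * S + d, h}; }

int main() {
  // H=2, W=2, C=8 -> S=2: slice 1 of pixel (h=0, w=1) sits at texel (3, 0),
  // inside a 4x2 image (W * S wide, H tall).
  Int2 t = Nhwc4Texel(0, 1, 1, 2);
  printf("texel=(%d, %d)\n", t.x, t.y);
  return 0;
}
```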

@@ -0,0 +1,50 @@
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
// mask zeroes the padded lanes of the last channel slice; slices_x32 = UP_DIV(slices, 32),
// the number of passes each of the 32 work-items makes over the slices.
__kernel void SoftMax1x1_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask,
const int slices, const int slices_x32) {
int tid = get_local_id(0);
int slices_count = 0;
int offset = 0;
float sum = 0.0f;
do {
int z = offset + tid;
if (z < slices) {
float4 mask_temp = z == slices - 1 ? mask : (float4)(1.0f);
float4 src = read_imagef(input, smp_none, (int2)(0, 0));
sum += dot(mask_temp, exp(src));
offset += 32;
}
slices_count++;
} while (slices_count < slices_x32);
__local float4 tmp[8];
__local float *tmpx1 = (__local float *)tmp;
tmpx1[tid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (tid == 0) {
sum = dot((float4)(1.0f), tmp[0]);
sum += dot((float4)(1.0f), tmp[1]);
sum += dot((float4)(1.0f), tmp[2]);
sum += dot((float4)(1.0f), tmp[3]);
sum += dot((float4)(1.0f), tmp[4]);
sum += dot((float4)(1.0f), tmp[5]);
sum += dot((float4)(1.0f), tmp[6]);
sum += dot((float4)(1.0f), tmp[7]);
tmpx1[0] = 1.0f / sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
sum = tmpx1[0];
offset = 0;
slices_count = 0;
do {
int z = offset + tid;
if (z < slices) {
float4 res = convert_float4(exp(read_imagef(input, smp_none, (int2)(0, 0))) * sum);
write_imagef(output, (int2)(0, 0), res);
offset += 32;
}
slices_count++;
} while (slices_count < slices_x32);
}
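The kernel dispatches 32 work-items; each accumulates `exp` over every 32nd channel slice (`slices_x32` passes), the `mask` zeroes the padded lanes of the final slice, and the 32 partial sums land in `tmp` as 8 float4 vectors that work-item 0 folds with `dot((float4)(1.0f), tmp[i])` before broadcasting `1 / sum` through local memory. A serial sketch of that two-stage reduction, with the assumptions noted in comments:

```cpp
#include <cstdio>

int main() {
  // Assume each of the 32 work-items produced a masked partial exp-sum.
  float partial[32];
  for (int tid = 0; tid < 32; ++tid) partial[tid] = 0.5f;  // dummy values
  // tmp[8] of float4 aliases these 32 floats; the dot(1.0f, tmp[i]) calls
  // are just 4-wide additions, giving the full sum on work-item 0:
  float sum = 0.0f;
  for (int i = 0; i < 8; ++i)
    for (int lane = 0; lane < 4; ++lane) sum += partial[i * 4 + lane];
  printf("sum=%g scale=%g\n", sum, 1.0f / sum);  // kernel broadcasts 1/sum
  return 0;
}
```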

@@ -1,5 +1,5 @@
/**
-* Copyright 2019 Huawei Technologies Co., Ltd
+* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,69 +17,143 @@
#include "src/runtime/kernel/opencl/kernel/softmax.h"
#include <string>
#include <set>
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/utils.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc"
#include "src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl.inc"
#endif
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_SoftMax;
-namespace mindspore {
-namespace kernel {
+namespace mindspore::kernel {
std::vector<float> SoftmaxOpenCLKernel::GetMaskForLastChannel(int channels) {
std::vector<float> mask(4, 0.0f);
const int remainder = channels % 4 == 0 ? 4 : channels % 4;
for (int i = 0; i < remainder; ++i) {
mask[i] = 1.0f;
}
return mask;
}
int SoftmaxOpenCLKernel::InitGlobalSize() {
const size_t global_x = out_tensors_[0]->Height();
const size_t global_y = out_tensors_[0]->Width();
const size_t global_z = 1;
global_size_ = {global_x, global_y, global_z};
return lite::RET_OK;
}
int SoftmaxOpenCLKernel::SetWorkGroupSize() {
// set work group size
InitGlobalSize();
int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())());
local_size_ = GetCommonLocalSize(global_size_, max_work_group_size);
global_size_ = GetCommonGlobalSize(local_size_, global_size_);
return lite::RET_OK;
}
int SoftmaxOpenCLKernel::SetWorkGroupSize1x1() {
local_size_ = {32, 1, 1};
global_size_ = {32, 1, 1};
return lite::RET_OK;
}
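`SetWorkGroupSize1x1` pins the dispatch to the 32 work-items the 1x1 kernel's `tid`/`tmp[8]` reduction expects, while `SetWorkGroupSize` derives a local size from the device's max work-group size and rounds the global size up to a multiple of it; that rounding contract is an assumption based on the helper names, sketched below:

```cpp
#include <cstdio>

// Assumed behavior of GetCommonGlobalSize: round each global dimension up
// to a multiple of the local size, which is why the kernels bounds-check.
static size_t RoundUpToLocal(size_t global, size_t local) {
  return (global + local - 1) / local * local;
}

int main() {
  printf("%zu -> %zu\n", (size_t)100, RoundUpToLocal(100, 16));  // 100 -> 112
  return 0;
}
```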
int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t im_dst_x, im_dst_y;
if (onexone_flag_) {
im_dst_x = UP_DIV(in_tensors_[0]->shape()[1], C4NUM);
im_dst_y = 1;
} else {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height();
}
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
return RET_OK;
}
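Plugging the two paths of `GetImageSize` into concrete shapes (values derived from the code above):

```cpp
#include <cstdio>

static int UpDiv(int x, int y) { return (x + y - 1) / y; }

int main() {
  // 1x1 path, input {1, 100}: one row of UP_DIV(100, 4) = 25 float4 texels.
  printf("1x1 image: %dx%d\n", UpDiv(100, 4), 1);
  // General path, output {1, 2, 2, 8}: W * UP_DIV(C, 4) wide, H tall.
  printf("4d  image: %dx%d\n", 2 * UpDiv(8, 4), 2);
  return 0;
}
```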
int SoftmaxOpenCLKernel::Init() {
std::string kernel_name = "SoftMax";
-if (parameter_->axis_ != -1 && parameter_->axis_ != 3) {
-MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_;
-return -1;
-}
+std::string program_name = "SoftMax";
+std::string source = softmax_source_fp32;
+runtime_ = lite::opencl::OpenCLRuntime::GetInstance();
-auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) {
+// support 4d tensor
+onexone_flag_ = false;
+} else if (in_tensors_[0]->shape().size() == 2 && parameter_->axis_ == 1) {
+// support 2d tensor
+kernel_name += "1x1";
+program_name += "1x1";
+source = softmax1x1_source_fp32;
+onexone_flag_ = true;
+} else {
+MS_LOG(EXCEPTION) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_;
+}
#ifdef PROGRAM_WITH_IL
-ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
+runtime_->CreateKernelFromIL(kernel_(), kernel_name);
#else
+if (mem_type_ == MEM_TYPE::BUF) {
+kernel_name += "_BUF";
+program_name += "_BUF";
+} else {
+kernel_name += "_IMG";
+program_name += "_IMG";
+}
std::set<std::string> build_options;
-std::string source = softmax_source_fp32;
-std::string program_name = "SoftMax";
-ocl_runtime->LoadSource(program_name, source);
-ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
-#endif
+runtime_->LoadSource(program_name, source);
+out_tensors_[0]->SetFormat(schema::Format_NHWC4);
+runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
+#endif
MS_LOG(DEBUG) << kernel_name << " Init Done!";
-return 0;
+return lite::RET_OK;
}
int SoftmaxOpenCLKernel::InitBuffer() { return 0; }
int SoftmaxOpenCLKernel::ReSize() { return 0; }
int SoftmaxOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
-auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
-auto allocator = ocl_runtime->GetAllocator();
-// global and local workers
-const uint32_t grid_x = in_tensors_[0]->shape()[2]; // W
-const uint32_t grid_y = in_tensors_[0]->shape()[1]; // H
-const uint32_t grid_z = 1;
-std::vector<size_t> global = {grid_x, grid_y, grid_z};
-std::vector<size_t> local = {1, 1, 1};
-// input and output
-cl::Buffer *input = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(in_tensors_[0]->Data()));
-cl::Buffer *output = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(out_tensors_[0]->Data()));
-cl_int4 input_size = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2],
-in_tensors_[0]->shape()[3]};
-std::cout << "run" << std::endl;
-// attribute
int arg_idx = 0;
-ocl_runtime->SetKernelArg(kernel_, arg_idx++, *input);
-ocl_runtime->SetKernelArg(kernel_, arg_idx++, *output);
-ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_size);
+if (onexone_flag_) {
+int channel_size = in_tensors_[0]->shape()[1];
+int slices = UP_DIV(channel_size, C4NUM);
+cl_int slices_x32 = UP_DIV(slices, 32);
+auto mask_ = GetMaskForLastChannel(channel_size);
+cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]};
-// run opengl kernel
-ocl_runtime->RunKernel(kernel_, global, local, nullptr);
+runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
+runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
+runtime_->SetKernelArg(kernel_, arg_idx++, mask);
+runtime_->SetKernelArg(kernel_, arg_idx++, slices);
+runtime_->SetKernelArg(kernel_, arg_idx, slices_x32);
+SetWorkGroupSize1x1();
+} else {
+int slices = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
+cl_int4 input_shape = {in_tensors_[0]->Height(), in_tensors_[0]->Width(), in_tensors_[0]->Channel(), slices};
-return 0;
+runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
+runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
+runtime_->SetKernelArg(kernel_, arg_idx, input_shape);
+SetWorkGroupSize();
+}
+// run opencl kernel
+runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr);
+return lite::RET_OK;
}
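For the 1x1 branch, the kernel arguments follow directly from the channel count; a worked example matching the shapes used by the unit tests at the end of this diff (the `{1, 100}` case is the commented-out `Softmax_1x1` test):

```cpp
#include <cstdio>

static int UpDiv(int x, int y) { return (x + y - 1) / y; }

int main() {
  int channels = 100;                  // shape {1, 100}, axis 1
  int slices = UpDiv(channels, 4);     // 25 float4 slices
  int slices_x32 = UpDiv(slices, 32);  // 1 pass of the 32-item loop
  // 100 % 4 == 0, so GetMaskForLastChannel yields {1, 1, 1, 1}.
  printf("slices=%d slices_x32=%d\n", slices, slices_x32);
  return 0;
}
```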
kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
@@ -104,5 +178,4 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::T
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SoftMax, OpenCLSoftMaxKernelCreator)
-} // namespace kernel
-} // namespace mindspore
+} // namespace mindspore::kernel

@@ -23,29 +23,37 @@
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
#include "src/runtime/opencl/opencl_runtime.h"
-namespace mindspore {
-namespace kernel {
-class SoftmaxOpenCLKernel : public LiteKernel {
+namespace mindspore::kernel {
+class SoftmaxOpenCLKernel : public OpenCLKernel {
public:
-explicit SoftmaxOpenCLKernel(OpParameter *parameter,
-const std::vector<lite::tensor::Tensor *> &inputs,
+explicit SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
-: LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {
+: OpenCLKernel(parameter, inputs, outputs) {
parameter_ = reinterpret_cast<SoftmaxParameter *>(parameter);
}
-~SoftmaxOpenCLKernel() override{};
+~SoftmaxOpenCLKernel() override{};
int Init() override;
int ReSize() override;
int Run() override;
int InitBuffer();
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
int InitGlobalSize();
int SetWorkGroupSize1x1();
int SetWorkGroupSize();
std::vector<float> GetMaskForLastChannel(int channels);
private:
-SoftmaxParameter *parameter_;
cl::Kernel kernel_;
+SoftmaxParameter *parameter_;
lite::opencl::OpenCLRuntime *runtime_;
enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG};
bool onexone_flag_{false};
std::vector<size_t> local_size_;
std::vector<size_t> global_size_;
};
-} // namespace kernel
-} // namespace mindspore
-#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_
+} // namespace mindspore::kernel
+#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_

@@ -175,4 +175,3 @@ std::string CLErrorCode(cl_int error_code) {
}
} // namespace kernel
} // namespace mindspore

@@ -85,4 +85,3 @@ std::string CLErrorCode(cl_int error_code);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_

@@ -19,6 +19,7 @@
#include "utils/log_adapter.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/opencl/utils.h"
namespace mindspore::lite::opencl {
@@ -128,7 +129,7 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t>& img_size)
cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format,
img_size[0], img_size[1], 0, nullptr, &ret);
if (ret != CL_SUCCESS) {
-MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")";
+MS_LOG(ERROR) << "Create OpenCL Image2D failed! " << kernel::CLErrorCode(ret);
UnLock();
delete buffer;
return nullptr;
@@ -187,7 +188,7 @@ void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::v
cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
image_format, img_size[0], img_size[1], 0, data, &ret);
if (ret != CL_SUCCESS) {
-MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")";
+MS_LOG(ERROR) << "Create OpenCL Image2D failed - " << kernel::CLErrorCode(ret);
UnLock();
delete buffer;
return nullptr;

@@ -52,6 +52,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso
std::vector<size_t> img_size;
op_kernel->GetImageSize(i, &img_size);
auto data_ptr = op_allocator->Malloc(output->Size(), img_size);
output->SetData(data_ptr);
} else {
output->MallocData(allocator);
@@ -109,7 +110,7 @@ int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format
case kNumberTypeFloat32:
return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir);
default:
-MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
+MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format);
return RET_ERROR;
}
@@ -160,7 +161,7 @@ int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema
// TODO(wandongdong): add support !!
return RET_OK;
} else {
-MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
+MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}
@@ -194,7 +195,7 @@ int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema:
allocator_->Free(src_data);
return RET_OK;
} else {
-MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
+MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}
@@ -216,7 +217,7 @@ int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schem
allocator_->Free(src_data);
return RET_OK;
} else {
-MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
+MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}
@@ -228,7 +229,7 @@ int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::F
MS_ASSERT(4 == tensor->shape().size());
// auto src_format = tensor->GetFormat();
// todo
-MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
+MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in uint8";
return RET_ERROR;
}

@@ -17,76 +17,90 @@
#include <memory>
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
class TestSoftmaxOpenCL : public mindspore::CommonTest {};
-void InitSoftaxParam(SoftmaxParameter *param) { param->axis_ = -1; }
-TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) {
-std::cout << "======" << std::endl;
-MS_LOG(INFO) << "start TEST_F TestSoftmaxOpenCL";
+void RunTestCase(std::vector<int> input_shape, std::vector<int> output_shape, std::string input_file,
+std::string expect_file, SoftmaxParameter *param, schema::Format format) {
+std::cout << "runtime" << std::endl;
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
+auto allocator = ocl_runtime->GetAllocator();
-MS_LOG(INFO) << "create SoftmaxParameter";
-auto param = new SoftmaxParameter();
-InitSoftaxParam(param);
-// define tensor
-MS_LOG(INFO) << "defineTensor";
-std::cout << "defineTensor" << std::endl;
+MS_LOG(INFO) << "create Tensors";
-std::vector<int> shape_in = {1, 2, 2, 1};
-std::vector<int> shape_out = {1, 2, 2, 1};
auto data_type = kNumberTypeFloat32;
auto tensorType = schema::NodeType_ValueNode;
-lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NCHW, tensorType);
-lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NCHW, tensorType);
-std::vector<lite::tensor::Tensor *> inputs{tensor_in};
-std::vector<lite::tensor::Tensor *> outputs{tensor_out};
+auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, format, tensorType);
+auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, format, tensorType);
+std::vector<lite::tensor::Tensor *> inputs{input_tensor};
+std::vector<lite::tensor::Tensor *> outputs{output_tensor};
-MS_LOG(INFO) << "create OpenCL Kernel";
-auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
-Softmax_kernel->Init();
-std::vector<kernel::LiteKernel *> kernels{Softmax_kernel};
-// run
+MS_LOG(INFO) << "NewOpenCLKernel";
+std::cout << "NewOpenCLKernel" << std::endl;
+auto *kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+MS_LOG(INFO) << "KernelInit";
+std::cout << "KernelInit" << std::endl;
+kernel->Init();
MS_LOG(INFO) << "create SubGraphOpenCLKernel";
+std::cout << "LiteKernel" << std::endl;
+std::vector<kernel::LiteKernel *> kernels{kernel};
+inputs[0]->MallocData(allocator);
+std::cout << "SubGraphOpenCLKernel" << std::endl;
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+MS_LOG(INFO) << "pGraphinit";
pGraph->Init();
-MS_LOG(INFO) << "initialize data";
-std::vector<lite::tensor::Tensor *> tensor_map = {tensor_in};
-for (auto &tensor_file : tensor_map) {
-auto tensor = tensor_file;
-size_t size = tensor->Size();
-const float data[4] = {std::log(1.0f), std::log(2.0f), std::log(3.0f), std::log(4.0f)};
-memcpy(tensor->Data(), data, size);
+// load data
+MS_LOG(INFO) << "load data1";
+LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file);
+auto *input_data = reinterpret_cast<float *>(input_tensor->Data());
+printf("\ninput[0:10]:");
+for (int i = 0; i < 10; i++) {
+printf("[%d]:%.3f ", i, input_data[i]);
}
+printf("\n\n");
-MS_LOG(INFO) << "pGraph->Run()";
+MS_LOG(INFO) << "Run";
pGraph->Run();
-MS_LOG(INFO) << "==================output data=================";
-float *output_data = reinterpret_cast<float *>(tensor_out->Data());
-size_t output_size = tensor_out->Size();
+MS_LOG(INFO) << "compare result";
+std::cout << "compare result" << std::endl;
+CompareOutput(output_tensor, expect_file);
+}
-printf("output:");
-for (int i = 0; i < 4; i++) {
-printf("%.3f ", output_data[i]);
-}
-printf("\n");
-float expect[4] = {1.0f, 2.0f, 3.0f, 4.0f};
+TEST_F(TestSoftmaxOpenCL, Softmax_1) {
+std::vector<int> input_shape = {1, 2, 2, 8};
+std::vector<int> output_shape = {1, 2, 2, 8};
+std::string input_file = "softmax_in.bin";
+std::string expect_file = "softmax_out.bin";
+auto param = new SoftmaxParameter;
+param->axis_ = 3;
+schema::Format format = schema::Format_NHWC4;
-for (int i = 0; i < tensor_out->ElementsNum(); ++i) {
-if (std::fabs(output_data[i] - expect[i]) > 1e-5) {
-printf("idx[%d] except=%.3f output=%.3f .", i, expect[i], output_data[i]);
-}
-}
-printf("\nTest all close OK for %zu!\n", output_size);
-lite::CompareOutputData(output_data, expect, 4);
+RunTestCase(input_shape, output_shape, input_file, expect_file, param, format);
}
// TEST_F(TestSoftmaxOpenCL, Softmax_1x1) {
// std::vector<int> input_shape = {1, 100};
// std::vector<int> output_shape = {1, 100};
// std::string input_file = "softmax1x1_in.bin";
// std::string expect_file = "softmax1x1_out.bin";
// auto param = new SoftmaxParameter;
// param->axis_ = 1;
// schema::Format format = schema::Format_NHWC4;
//
// RunTestCase(input_shape, output_shape, input_file, expect_file, param, format);
//}
} // namespace mindspore

@@ -40,13 +40,13 @@ void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_
size_t output_size = output_tensor->Size();
float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
-printf("output[0:10]:");
-for (int i = 0; i < 10; i++) {
+printf("output[0:12]:");
+for (int i = 0; i < 12; i++) {
printf("[%d]:%.3f ", i, output_data[i]);
}
printf("\n");
-printf("expect[0:10]:");
-for (int i = 0; i < 10; i++) {
+printf("expect[0:12]:");
+for (int i = 0; i < 12; i++) {
printf("[%d]:%.3f ", i, expect_data[i]);
}
printf("\n");

@@ -157,7 +157,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_CHWK;
} else {
-MS_LOG(ERROR) << "unsupport format";
+MS_LOG(ERROR) << "Unsupported format";
return -1;
}
} break;

@@ -184,7 +184,7 @@ size_t GetDataTypeSize(const TypeId &data_type) {
return sizeof(int64_t);
default:
MS_LOG(ERROR) << data_type;
-MS_LOG(ERROR) << "unsupport datatype";
+MS_LOG(ERROR) << "Unsupported datatype";
return RET_ERROR;
}
}
