!5435 [MS][LITE][GPU]reshape transpose fp16

Merge pull request !5435 from chenzupeng/master-lite
pull/5435/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit fa50dd025c

@ -1,3 +1,4 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void reshape(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
int X = get_global_id(0);

@ -1,3 +1,4 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void transpose_IMG(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) {
int X = get_global_id(0);
@ -75,8 +76,8 @@ __kernel void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_dat
result[3].z = x2.w;
result[3].w = x3.w;
dst_data[4 * Y * HW.y + X] = result[0];
dst_data[(4 * Y + 1) * HW.y + X] = result[1];
dst_data[(4 * Y + 2) * HW.y + X] = result[2];
dst_data[(4 * Y + 3) * HW.y + X] = result[3];
if (4 * Y < C.x) dst_data[4 * Y * HW.y + X] = result[0];
if (4 * Y + 1 < C.x) dst_data[(4 * Y + 1) * HW.y + X] = result[1];
if (4 * Y + 2 < C.x) dst_data[(4 * Y + 2) * HW.y + X] = result[2];
if (4 * Y + 3 < C.x) dst_data[(4 * Y + 3) * HW.y + X] = result[3];
}

@ -33,6 +33,7 @@ namespace mindspore::kernel {
int ReshapeOpenCLKernel::Init() {
std::string kernel_name = "reshape";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
enable_fp16_ = ocl_runtime->GetFp16Enable();
in_ori_format_ = in_tensors_[0]->GetFormat();
out_ori_format_ = out_tensors_[0]->GetFormat();
if (in_ori_format_ != schema::Format_NHWC4 && in_ori_format_ != schema::Format_NHWC) {
@ -73,11 +74,10 @@ int ReshapeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size)
int c = shapex[3];
im_dst_x = w * UP_DIV(c, C4NUM);
im_dst_y = h;
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
@ -121,4 +121,5 @@ kernel::LiteKernel *OpenCLReshapeKernelCreator(const std::vector<lite::tensor::T
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Reshape, OpenCLReshapeKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Reshape, OpenCLReshapeKernelCreator)
} // namespace mindspore::kernel

@ -38,6 +38,7 @@ class ReshapeOpenCLKernel : public OpenCLKernel {
private:
cl::Kernel kernel_;
bool enable_fp16_{false};
};
} // namespace mindspore::kernel

@ -35,6 +35,7 @@ namespace mindspore::kernel {
int TransposeOpenCLKernel::Init() {
std::string kernel_name = "transpose";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
enable_fp16_ = ocl_runtime->GetFp16Enable();
if (!is_image_out_) {
kernel_name += "_BUF";
} else {
@ -70,11 +71,10 @@ int TransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz
size_t im_dst_x, im_dst_y;
im_dst_x = UP_DIV(out_tensors_[0]->Height() * out_tensors_[0]->Width(), C4NUM);
im_dst_y = out_tensors_[0]->Channel();
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
@ -82,6 +82,7 @@ int TransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz
}
int TransposeOpenCLKernel::Run() {
// notice: input image2d size = {c/4, h * w}
MS_LOG(DEBUG) << this->name() << " Running!";
std::vector<int> shapex = in_tensors_[0]->shape();
int h = shapex[1];

@ -38,7 +38,8 @@ class TransposeOpenCLKernel : public OpenCLKernel {
private:
cl::Kernel kernel_;
bool is_image_out_ = false;
bool is_image_out_{false};
bool enable_fp16_{false};
};
} // namespace mindspore::kernel

@ -31,14 +31,14 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest {
};
void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, void *weight_data, void *bias_data,
void *output_data, bool fp16) {
void *output_data, bool enable_fp16) {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
size_t dtype_size = sizeof(float);
if (fp16) {
if (enable_fp16) {
ocl_runtime->SetFp16Enable(true);
dtype_size = sizeof(float16_t);
}
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
int pad = shape[0];
int n = shape[1];
@ -52,7 +52,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1;
std::vector<int> input_shape = {n, h, w, ci};
auto tensor_x_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape);
std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
@ -61,7 +61,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
std::vector<int> weight_shape = {co, kh, kw, ci};
auto tensor_w_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape);
std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape);
auto tensor_w = tensor_w_ptr.get();
if (tensor_w == nullptr) {
MS_LOG(ERROR) << "tensor_w create error.";
@ -71,7 +71,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
std::vector<int> bias_shape = {co};
auto tensor_bias_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape);
std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape);
auto tensor_bias = tensor_bias_ptr.get();
if (tensor_bias == nullptr) {
MS_LOG(ERROR) << "tensor_bias create error.";
@ -81,7 +81,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
std::vector<int> out_shape = {1, oh, ow, co};
auto tensor_out_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape);
std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
@ -126,7 +126,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
pGraph->Init();
memcpy(inputs[0]->Data(), input_data, n * h * w * ci * dtype_size);
pGraph->Run();
if (fp16) {
if (enable_fp16) {
CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float16_t>(1e-3), 2e-2);
} else {
CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float>(1e-5));
@ -137,7 +137,8 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
lite::opencl::OpenCLRuntime::DeleteInstance();
}
void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path,
bool enable_fp16) {
size_t input_size;
std::string input_path = file_path[0];
auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size);
@ -168,7 +169,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<
MS_LOG(ERROR) << "output_data load error.";
return;
}
RunTestCaseConv2dTranspose(shape, input_data, weight_data, bias_data, output_data, fp16);
RunTestCaseConv2dTranspose(shape, input_data, weight_data, bias_data, output_data, enable_fp16);
}
TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {

@ -29,32 +29,21 @@ class TestMatMulOpenCL : public mindspore::CommonTest {
TestMatMulOpenCL() {}
};
void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
void RunTestCaseMatMul(const std::vector<int> &shape, void *input_data, void *weight_data, void *output_data,
bool enable_fp16) {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
if (fp16) {
size_t dtype_size = sizeof(float);
if (enable_fp16) {
ocl_runtime->SetFp16Enable(true);
dtype_size = sizeof(float16_t);
}
auto allocator = ocl_runtime->GetAllocator();
size_t input_size;
int ci = shape[0];
int co = shape[1];
std::string input_path = file_path[0];
auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size);
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data load error.";
return;
}
size_t weight_size;
std::string weight_path = file_path[1];
auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size);
if (weight_data == nullptr) {
MS_LOG(ERROR) << "weight_data load error.";
return;
}
std::vector<int> input_shape = {1, ci};
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
input_shape, schema::Format_NC);
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(
TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NC);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
@ -63,7 +52,7 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
std::vector<int> w_shape = {co, ci};
auto tensor_w_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), w_shape);
std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), w_shape);
auto tensor_w = tensor_w_ptr.get();
if (tensor_w == nullptr) {
MS_LOG(ERROR) << "tensor_w create error.";
@ -72,8 +61,8 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
tensor_w->SetData(weight_data);
std::vector<int> out_shape = {1, co};
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
out_shape, schema::Format_NC);
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(
TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NC);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
@ -100,12 +89,12 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
return;
}
pGraph->Init();
memcpy(inputs[0]->Data(), input_data, input_size);
memcpy(inputs[0]->Data(), input_data, ci * dtype_size);
pGraph->Run();
if (fp16) {
CompareOutput(tensor_out, file_path[2], static_cast<float16_t>(1e-3), 2e-2);
if (enable_fp16) {
CompareOutput(outputs[0]->Data(), output_data, co, static_cast<float16_t>(1e-3), 2e-2);
} else {
CompareOutput(tensor_out, file_path[2], static_cast<float>(1e-5));
CompareOutput(outputs[0]->Data(), output_data, co, static_cast<float>(1e-5));
}
tensor_x->SetData(nullptr);
@ -114,6 +103,31 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
lite::opencl::OpenCLRuntime::DeleteInstance();
}
void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::string> file_path, bool enable_fp16) {
size_t input_size;
std::string input_path = file_path[0];
auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size);
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data load error.";
return;
}
size_t weight_size;
std::string weight_path = file_path[1];
auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size);
if (weight_data == nullptr) {
MS_LOG(ERROR) << "weight_data load error.";
return;
}
size_t output_size;
std::string output_path = file_path[2];
auto output_data = mindspore::lite::ReadFile(output_path.c_str(), &output_size);
if (output_data == nullptr) {
MS_LOG(ERROR) << "output_data load error.";
return;
}
RunTestCaseMatMul(shape, input_data, weight_data, output_data, enable_fp16);
}
TEST_F(TestMatMulOpenCL, MatMulFp32) {
int ci = 1280;
int co = 1001;
@ -133,4 +147,26 @@ TEST_F(TestMatMulOpenCL, MatMulFp16) {
"./test_data/matmul/matmul_fp16_output.bin"};
RunTestCaseMatMul(shape, file_path, true);
}
TEST_F(TestMatMulOpenCL, MatMulFp32_2) {
int ci = 5;
int co = 3;
std::vector<int> shape = {ci, co};
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f};
std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
std::vector<float> output_data = {10.f, 10.f, 10.f};
RunTestCaseMatMul(shape, input_data.data(), weight_data.data(), output_data.data(), false);
}
TEST_F(TestMatMulOpenCL, MatMulFp16_2) {
int ci = 5;
int co = 3;
std::vector<int> shape = {ci, co};
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f};
std::vector<float16_t> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
std::vector<float16_t> output_data = {10.f, 10.f, 10.f};
RunTestCaseMatMul(shape, input_data.data(), weight_data.data(), output_data.data(), true);
}
} // namespace mindspore

@ -21,6 +21,7 @@
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
class TestReshapeOpenCL : public mindspore::CommonTest {
@ -28,29 +29,27 @@ class TestReshapeOpenCL : public mindspore::CommonTest {
TestReshapeOpenCL() {}
};
TEST_F(TestReshapeOpenCL, ReshapeFp32) {
void RunTestCaseReshape(const std::vector<int> &shape, void *input_data, void *output_data, bool enable_fp16) {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
int c = 63;
size_t input_size;
std::string input_path = "./test_data/reshape/reshape_fp32_input.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data load error.";
return;
size_t dtype_size = sizeof(float);
if (enable_fp16) {
ocl_runtime->SetFp16Enable(true);
dtype_size = sizeof(float16_t);
}
auto allocator = ocl_runtime->GetAllocator();
int c = shape[0];
std::vector<int> input_shape = {1, 1, 1, c};
auto tensor_x_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC);
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(
TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NHWC);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
return;
}
std::vector<int> out_shape = {1, c};
auto tensor_out_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NC);
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(
TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NC);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
@ -76,36 +75,36 @@ TEST_F(TestReshapeOpenCL, ReshapeFp32) {
return;
}
pGraph->Init();
memcpy(inputs[0]->Data(), input_data, input_size);
memcpy(inputs[0]->Data(), input_data, c * dtype_size);
pGraph->Run();
size_t output_size;
std::string output_path = "./test_data/reshape/reshape_fp32_output.bin";
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
if (correct_data == nullptr) {
MS_LOG(ERROR) << "correct_data create error.";
return;
}
printf("==================output data=================\n");
float *output_data = reinterpret_cast<float *>(tensor_out->Data());
std::cout << std::endl;
int size_n = c;
size_n = size_n > 100 ? 100 : size_n;
for (int i = 0; i < size_n; i++) {
std::cout << output_data[i] << " ";
if ((i + 1) % c == 0) {
std::cout << std::endl;
if (enable_fp16) {
CompareOutput(outputs[0]->Data(), output_data, c, static_cast<float16_t>(1e-3), 2e-2);
} else {
CompareOutput(outputs[0]->Data(), output_data, c, static_cast<float>(1e-5));
}
}
std::cout << std::endl;
// compare
CompareOutputData(output_data, correct_data, c, 0.00001);
inputs[0]->SetData(nullptr);
outputs[0]->SetData(nullptr);
MS_LOG(INFO) << "Test ReshapeFp32 passed";
lite::opencl::OpenCLRuntime::DeleteInstance();
}
TEST_F(TestReshapeOpenCL, ReshapeFp32) {
int c = 7;
std::vector<int> shape = {c};
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
std::vector<float> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
RunTestCaseReshape(shape, input_data.data(), output_data.data(), false);
}
TEST_F(TestReshapeOpenCL, ReshapeFp16) {
int c = 7;
std::vector<int> shape = {c};
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
std::vector<float16_t> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
RunTestCaseReshape(shape, input_data.data(), output_data.data(), true);
}
} // namespace mindspore

@ -21,6 +21,7 @@
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
class TestTransposeOpenCL : public mindspore::CommonTest {
@ -28,31 +29,29 @@ class TestTransposeOpenCL : public mindspore::CommonTest {
TestTransposeOpenCL() {}
};
TEST_F(TestTransposeOpenCL, TransposeFp32) {
void RunTestTranspose(const std::vector<int> &shape, void *input_data, void *output_data, bool enable_fp16) {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
int h = 64;
int w = 1;
int c = 7360;
size_t input_size;
std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data load error.";
return;
size_t dtype_size = sizeof(float);
if (enable_fp16) {
ocl_runtime->SetFp16Enable(true);
dtype_size = sizeof(float16_t);
}
auto allocator = ocl_runtime->GetAllocator();
int h = shape[0];
int w = shape[1];
int c = shape[2];
std::vector<int> input_shape = {1, h, w, c};
auto tensor_x_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC);
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(
TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NHWC);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
return;
}
std::vector<int> out_shape = {1, c, h, w};
auto tensor_out_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NCHW);
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(
TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NCHW);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
@ -78,9 +77,35 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
return;
}
pGraph->Init();
memcpy(inputs[0]->Data(), input_data, input_size);
memcpy(inputs[0]->Data(), input_data, h * w * c * dtype_size);
pGraph->Run();
if (enable_fp16) {
CompareOutput(outputs[0]->Data(), output_data, h * w * c, static_cast<float16_t>(1e-3), 2e-2);
} else {
CompareOutput(outputs[0]->Data(), output_data, h * w * c, static_cast<float>(1e-5));
}
inputs[0]->SetData(nullptr);
outputs[0]->SetData(nullptr);
MS_LOG(INFO) << "Test TransposeFp32 passed";
lite::opencl::OpenCLRuntime::DeleteInstance();
}
TEST_F(TestTransposeOpenCL, TransposeFp32) {
int h = 64;
int w = 1;
int c = 7360;
std::vector<int> shape = {h, w, c};
size_t input_size;
std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data load error.";
return;
}
size_t output_size;
std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
@ -88,26 +113,17 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
MS_LOG(ERROR) << "correct_data create error.";
return;
}
printf("==================output data=================\n");
float *output_data = reinterpret_cast<float *>(tensor_out->Data());
std::cout << std::endl;
int size_n = h * w * c;
size_n = size_n > 100 ? 100 : size_n;
for (int i = 0; i < size_n; i++) {
std::cout << output_data[i] << " ";
if ((i + 1) % c == 0) {
std::cout << std::endl;
RunTestTranspose(shape, input_data, correct_data, false);
}
}
std::cout << std::endl;
// compare
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
inputs[0]->SetData(nullptr);
outputs[0]->SetData(nullptr);
TEST_F(TestTransposeOpenCL, TransposeFp16) {
int h = 4;
int w = 1;
int c = 3;
std::vector<int> shape = {h, w, c};
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
std::vector<float16_t> output_data = {0.0f, 3.0f, 6.0f, 9.0f, 1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f, 8.0f, 11.0f};
MS_LOG(INFO) << "Test TransposeFp32 passed";
lite::opencl::OpenCLRuntime::DeleteInstance();
RunTestTranspose(shape, input_data.data(), output_data.data(), true);
}
} // namespace mindspore

Loading…
Cancel
Save