new ops named pad

arithmetic_self ops support dim2
pull/7141/head
Pengyongrong 5 years ago
parent 9e48d6527a
commit 7a536ad917

@@ -3,29 +3,34 @@ __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP
__kernel void gather_NHWC4(__read_only image2d_t src_data, __global int *indices, __write_only image2d_t dst_data,
int4 src_size, int4 dst_size, int indices_num, int axis) {
int X = get_global_id(0); // w
int Y = get_global_id(1); // h
int Y = get_global_id(1); // n*h
int Z = get_global_id(2); // c
if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {
if (X >= dst_size.x || Y >= dst_size.y * dst_size.w || Z >= dst_size.z || dst_size.y == 0) {
return;
}
FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
if (axis == 1) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + Z, indices[Y]));
int batch = Y / dst_size.y;
int height = Y % dst_size.y;
if (axis == 0) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + Z, indices[batch] * src_size.y + height));
} else if (axis == 1) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + Z, batch * src_size.y + indices[height]));
} else if (axis == 2) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(indices[X] * src_size.z + Z, Y));
res_data = READ_IMAGE(src_data, smp_zero, (int2)(indices[X] * src_size.z + Z, batch * src_size.y + height));
} else if (axis == 3) {
int offset[4] = {indices[Z * 4] / 4, indices[Z * 4 + 1] / 4, indices[Z * 4 + 2] / 4, indices[Z * 4 + 3] / 4};
FLT tmp[4];
FLT res_tmp[4];
for (int i = 0; i < 4; ++i) {
for (int i = 0; i < indices_num; ++i) {
FLT4 rd_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + offset[i], batch * src_size.y + height));
if (i >= 1 && offset[i] != offset[i - 1]) {
FLT4 rd_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + offset[i], Y));
tmp[0] = rd_data.x;
tmp[1] = rd_data.y;
tmp[2] = rd_data.z;
tmp[3] = rd_data.w;
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + offset[i], batch * src_size.y + height));
}
tmp[0] = rd_data.x;
tmp[1] = rd_data.y;
tmp[2] = rd_data.z;
tmp[3] = rd_data.w;
res_tmp[i] = tmp[indices[Z * 4 + i] % 4];
}
res_data.x = res_tmp[0];
@@ -33,34 +38,44 @@ __kernel void gather_NHWC4(__read_only image2d_t src_data, __global int *indices
res_data.z = res_tmp[2];
res_data.w = res_tmp[3];
}
WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res_data);
WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, batch * dst_size.y + height), res_data);
}
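For reference, a minimal host-side sketch (illustrative, not part of the commit) of the NHWC4 coordinate mapping the reworked kernel relies on: get_global_id(1) now spans the flattened n*h range, and the kernel splits it back into batch and height before addressing the image.

#include <cassert>

// Hypothetical helpers, for illustration only: image2d_t coordinates of a
// logical (n, h, w, c4) element in an NHWC4 image of size (W*C4) x (N*H).
static inline int nhwc4_x(int w, int c4, int C4) { return w * C4 + c4; }
static inline int nhwc4_y(int n, int h, int H) { return n * H + h; }

int main() {
  const int H = 8;
  const int Y = 19;          // flattened n*h work-item id, as in the kernel
  const int batch = Y / H;   // -> 2
  const int height = Y % H;  // -> 3
  assert(nhwc4_y(batch, height, H) == Y);              // split and re-flatten agree
  assert(nhwc4_x(/*w=*/2, /*c4=*/3, /*C4=*/4) == 11);  // X * src_size.z + Z
  return 0;
}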
__kernel void gather_NC4HW4(__read_only image2d_t src_data, __global int *indices, __write_only image2d_t dst_data,
int4 src_size, int4 dst_size, int indices_num, int axis) {
int4 src_size, int4 dst_size, int indices_num, int axis) {
int X = get_global_id(0); // w
int Y = get_global_id(1); // h
int Y = get_global_id(1); // n*h
int Z = get_global_id(2); // c
if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {
if (X >= dst_size.x || Y >= dst_size.y * dst_size.w || Z >= dst_size.z || dst_size.y == 0) {
return;
}
FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
if (axis == 1) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X, Z * dst_size.y + indices[Y]));
int batch = Y / dst_size.y;
int height = Y % dst_size.y;
if (axis == 0) {
int index_y = indices[batch] * src_size.y * src_size.z + Z * src_size.y + height;
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X, index_y));
} else if (axis == 1) {
int index_y = batch * src_size.y * src_size.z + Z * src_size.y + indices[height];
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X, index_y));
} else if (axis == 2) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(indices[X], Z * dst_size.y + Y));
int index_y = batch * src_size.y * src_size.z + Z * src_size.y + height;
res_data = READ_IMAGE(src_data, smp_zero, (int2)(indices[X], index_y));
} else if (axis == 3) {
int offset[4] = {indices[Z * 4] / 4, indices[Z * 4 + 1] / 4, indices[Z * 4 + 2] / 4, indices[Z * 4 + 3] / 4};
FLT tmp[4];
FLT res_tmp[4];
for (int i = 0; i < 4; ++i) {
for (int i = 0; i < indices_num; ++i) {
FLT4 rd_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int index_y = batch * src_size.y * src_size.z + offset[i] * src_size.y + height;
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X, index_y));
if (i >= 1 && offset[i] != offset[i - 1]) {
FLT4 rd_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X, offset[i] * dst_size.y + Y));
tmp[0] = rd_data.x;
tmp[1] = rd_data.y;
tmp[2] = rd_data.z;
tmp[3] = rd_data.w;
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X, index_y));
}
tmp[0] = rd_data.x;
tmp[1] = rd_data.y;
tmp[2] = rd_data.z;
tmp[3] = rd_data.w;
res_tmp[i] = tmp[indices[Z * 4 + i] % 4];
}
res_data.x = res_tmp[0];
@@ -68,5 +83,5 @@ __kernel void gather_NC4HW4(__read_only image2d_t src_data, __global int *indice
res_data.z = res_tmp[2];
res_data.w = res_tmp[3];
}
WRITE_IMAGE(dst_data, (int2)(X, (Z * dst_size.y + Y)), res_data);
WRITE_IMAGE(dst_data, (int2)(X, (batch * dst_size.y * dst_size.z + Z * dst_size.y + height)), res_data);
}
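Likewise, a small sketch (assumed dimension names, illustration only) of the row index the NC4HW4 variant computes: the image is W wide and N*C4*H tall, so index_y = batch * src_size.y * src_size.z + Z * src_size.y + height is just n * C4 * H + c4 * H + h with src_size.y = H and src_size.z = C4.

#include <cassert>

// Illustration only: row (y) coordinate of element (n, c4, h) in an
// NC4HW4 image laid out as W x (N * C4 * H).
static inline int nc4hw4_row(int n, int c4, int h, int H, int C4) {
  return n * C4 * H + c4 * H + h;
}

int main() {
  const int H = 4, C4 = 2;
  // Consecutive h for a fixed (n, c4) land on consecutive rows:
  assert(nc4hw4_row(1, 1, 3, H, C4) - nc4hw4_row(1, 1, 2, H, C4) == 1);
  // Batches are C4 * H rows apart:
  assert(nc4hw4_row(1, 0, 0, H, C4) == C4 * H);
  return 0;
}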

@@ -40,14 +40,25 @@ using mindspore::schema::PrimitiveType_Square;
namespace mindspore::kernel {
int ArithmeticSelfOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
auto out_shape = out_tensors_[0]->shape();
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t im_dst_x, im_dst_y;
if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height() * out_tensors_[0]->Batch();
if (in_tensors_[0]->shape().size() == 4) {
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height() * out_tensors_[0]->Batch();
} else {
im_dst_x = UP_DIV(out_shape[1], C4NUM);
im_dst_y = out_tensors_[0]->Batch();
}
} else {
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
if (in_tensors_[0]->shape().size() == 4) {
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
} else {
im_dst_y = out_tensors_[0]->Batch() * UP_DIV(out_shape[1], C4NUM);
im_dst_x = 1;
}
}
size_t img_dtype = CL_FLOAT;
auto enable_fp16_ = ocl_runtime_->GetFp16Enable();
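UP_DIV here is ceiling division (assuming the usual macro definition and C4NUM == 4); a quick sketch of the image sizes the new dim-2 branches produce for a {N, C} tensor:

#include <cassert>

// Assumed definition, consistent with how UP_DIV is used above.
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main() {
  // A {1, 512} tensor in the NHWC4 branch:
  assert(UP_DIV(512, 4) == 128);  // im_dst_x = UP_DIV(out_shape[1], C4NUM)
  // im_dst_y = Batch() = 1, so the image is 128 x 1.
  // The NC4HW4 branch stacks the slices instead: 1 x (1 * 128).
  return 0;
}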
@@ -107,14 +118,14 @@ void ArithmeticSelfOpenCLKernel::GetKernelName(std::string *kernel_name, Arithme
}
int ArithmeticSelfOpenCLKernel::Init() {
if (in_tensors_[0]->shape().size() != 4) {
MS_LOG(ERROR) << " only support dim = 4 ";
if (in_tensors_[0]->shape().size() != 4 && in_tensors_[0]->shape().size() != 2) {
MS_LOG(ERROR) << " only support dim = 4 or 2 but your dim = " << in_tensors_[0]->shape().size();
return RET_ERROR;
}
auto param = reinterpret_cast<ArithmeticSelfParameter *>(this->op_parameter_);
auto in_format = op_format_;
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4 && in_format != schema::Format_NC4) {
MS_LOG(ERROR) << "input format(" << in_format << ") "
<< "format not support!";
return RET_ERROR;
@@ -161,12 +172,19 @@ int ArithmeticSelfOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto output_shape = out_tensors_[0]->shape();
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], UP_DIV(output_shape[3], C4NUM)};
uint32_t OH = output_shape[0] * output_shape[1]; // N*H
uint32_t OW = output_shape[2];
uint32_t OC = UP_DIV(output_shape[3], C4NUM);
cl_int4 output_shape_ = {};
uint32_t OH = 1, OW = 1, OC = 1;
if (output_shape.size() == 4) {
output_shape_ = {output_shape[0], output_shape[1], output_shape[2], UP_DIV(output_shape[3], C4NUM)};
OH = output_shape[0] * output_shape[1];
OW = output_shape[2];
OC = UP_DIV(output_shape[3], C4NUM);
} else if (output_shape.size() == 2) {
output_shape_ = {output_shape[0], 1, 1, UP_DIV(output_shape[1], C4NUM)};
OH = output_shape[0];
OW = 1;
OC = UP_DIV(output_shape[1], C4NUM);
}
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
std::vector<size_t> local = {1, 1, 1}; // init local
std::vector<size_t> global = {OH, OW, OC};
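As a sanity check on the dim-2 path in Run(): a {N, C} output maps to {OH, OW, OC} = {N, 1, UP_DIV(C, 4)}, one work-item per FLT4 slot. A sketch (assumed UP_DIV definition):

#include <cassert>
#include <vector>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main() {
  std::vector<int> output_shape = {1, 512};      // the dim-2 case above
  size_t OH = output_shape[0];                   // N
  size_t OW = 1;
  size_t OC = UP_DIV(output_shape[1], 4);        // number of C4 slots
  std::vector<size_t> global = {OH, OW, OC};
  assert(global == (std::vector<size_t>{1, 1, 128}));
  return 0;
}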

@@ -68,6 +68,7 @@ int GatherOpenCLKernel::Init() {
}
return RET_OK;
}
int GatherOpenCLKernel::InitBuffer() {
auto indices_tensor = in_tensors_.at(1);
int indices_num = indices_tensor->ElementsNum();
@@ -94,13 +95,15 @@ int GatherOpenCLKernel::InitBuffer() {
}
return RET_OK;
}
int GatherOpenCLKernel::ReSize() { return RET_OK; }
int GatherOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t im_dst_x, im_dst_y;
if (in_tensors_[0]->GetFormat() == schema::Format::Format_NHWC4) {
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height();
im_dst_y = out_tensors_[0]->Height() * out_tensors_[0]->Batch();
} else {
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
@@ -115,6 +118,7 @@ int GatherOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size)
*img_size = std::move(vec);
return RET_OK;
}
int GatherOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto param = reinterpret_cast<GatherParameter *>(this->op_parameter_);
@@ -122,7 +126,6 @@ int GatherOpenCLKernel::Run() {
if (InitBuffer() != RET_OK) {
return RET_ERROR;
}
auto input_shape = in_tensors_[0]->shape();
auto output_shape = out_tensors_[0]->shape();
int indices_num = in_tensors_[1]->ElementsNum();
@@ -132,7 +135,8 @@ int GatherOpenCLKernel::Run() {
cl_int4 dst_size = {(cl_int)out_tensors_[0]->Width(), (cl_int)out_tensors_[0]->Height(), (cl_int)CO4,
(cl_int)out_tensors_[0]->Batch()};
std::vector<size_t> local = {1, 1, 1};
std::vector<size_t> global = {(size_t)out_tensors_[0]->Width(), (size_t)out_tensors_[0]->Height(), CO4};
std::vector<size_t> global = {(size_t)out_tensors_[0]->Width(),
(size_t)out_tensors_[0]->Batch() * (size_t)out_tensors_[0]->Height(), CO4};
int arg_cn = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c(), lite::opencl::MemType::IMG);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, indices_data_, lite::opencl::MemType::BUF);
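The host-side launch now mirrors the kernel change: dimension 1 of the NDRange spans batch * height, and the kernel recovers batch via Y / dst_size.y. A sketch with stubbed dimensions (illustration only):

#include <cassert>
#include <vector>

int main() {
  // Stubbed output dims: N = 2, H = 10, W = 10, CO4 = 2.
  const size_t N = 2, H = 10, W = 10, CO4 = 2;
  int dst_size[4] = {(int)W, (int)H, (int)CO4, (int)N};  // {W, H, CO4, N}
  std::vector<size_t> global = {W, N * H, CO4};          // dim 1 spans N*H
  assert(global[1] == (size_t)(dst_size[3] * dst_size[1]));
  return 0;
}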

@@ -82,7 +82,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) {
}
return;
}
param->op_parameter_.type_ = schema::PrimitiveType_Round;
param->op_parameter_.type_ = schema::PrimitiveType_Sin;
auto *arithmeticself_kernel =
new (std::nothrow) kernel::ArithmeticSelfOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (arithmeticself_kernel == nullptr) {
@@ -225,4 +225,99 @@ TEST_F(TestArithmeticSelfOpenCLCI, ArithmeticSelfRound) {
}
delete sub_graph;
}
TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfdim2Fp16) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
ocl_runtime->SetFp16Enable(true);
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
// get the input from .bin
size_t input1_size, output_size;
std::string input1Ppath = "./test_data/in_arithmetic_selffp16.bin";
std::string correctOutputPath = "./test_data/out_arithmetic_selffp16.bin";
auto input_data1 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto correctOutput =
reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape = {1, 512};
auto data_type = kNumberTypeFloat16;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
auto *input_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NC, tensor_type);
auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NC, tensor_type);
if (input_tensor == nullptr || output_tensor == nullptr) {
MS_LOG(INFO) << " new input_tensor or output_tensor failed ";
return;
}
std::vector<lite::Tensor *> inputs{input_tensor};
std::vector<lite::Tensor *> outputs{output_tensor};
MS_LOG(INFO) << " initialize param ";
auto param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new ConcatParameter failed ";
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
return;
}
param->op_parameter_.type_ = schema::PrimitiveType_Sin;
auto *arithmeticself_kernel =
new (std::nothrow) kernel::ArithmeticSelfOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (arithmeticself_kernel == nullptr) {
MS_LOG(INFO) << " new kernel::ArithmeticSelfOpenCLKernel failed ";
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete param;
return;
}
arithmeticself_kernel->SetFormatType(schema::Format_NC4HW4);
arithmeticself_kernel->Init();
// to do allocate memory for inputs and outputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{arithmeticself_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete param;
delete arithmeticself_kernel;
return;
}
sub_graph->Init();
MS_LOG(INFO) << " initialize input data ";
memcpy(inputs[0]->data_c(), input_data1, input1_size);
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->data_c());
CompareOutputData1(input_data1, output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete sub_graph;
}
} // namespace mindspore

@@ -17,6 +17,7 @@
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
@@ -27,6 +28,7 @@ class TestGatherOpenCL : public mindspore::CommonTest {
public:
TestGatherOpenCL() {}
};
template <typename T>
void test_main_gather(void *input_data, void *correct_data, const std::vector<int> &input_shape,
const std::vector<int> &indices, GatherParameter *param, TypeId data_type,
@@ -41,7 +43,7 @@ void test_main_gather(void *input_data, void *correct_data, const std::vector<in
output_shape[param->axis_] = indices.size();
auto tensor_a = lite::Tensor(TypeId(data_type), input_shape, format);
auto tensor_b = lite::Tensor(TypeId(data_type), indices_shape, schema::Format_NC);
auto tensor_b = lite::Tensor(kNumberTypeInt32, indices_shape, schema::Format_NC);
auto tensor_c = lite::Tensor(TypeId(data_type), output_shape, format);
std::vector<lite::Tensor *> inputs{&tensor_a, &tensor_b};
std::vector<lite::Tensor *> outputs{&tensor_c};
@@ -53,6 +55,7 @@ void test_main_gather(void *input_data, void *correct_data, const std::vector<in
MS_LOG(INFO) << "new GatherOpenCLKernel failed ";
return;
}
pkernel->SetFormatType(schema::Format_NC4HW4);
pkernel->Init();
// to do allocate memory for inputs and outputs
@@ -72,18 +75,56 @@ void test_main_gather(void *input_data, void *correct_data, const std::vector<in
MS_LOG(INFO) << " init tensors ";
memcpy(inputs[0]->data_c(), input_data, input_size);
auto input1_tensor = reinterpret_cast<int *>(inputs[1]->data_c());
for (int i = 0; i < inputs[1]->ElementsNum(); ++i) {
input1_tensor[i] = indices.at(i);
}
sub_graph->Run();
std::cout << "==================output data================" << std::endl;
auto *output_data = reinterpret_cast<T *>(outputs[0]->data_c());
CommonTest::CompareOutputData<T>(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001);
delete sub_graph;
CommonTest::CompareOutputData(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001);
}
TEST_F(TestGatherOpenCL, Axis0Fp16) {
std::vector<int> input_shape{5, 10, 10, 5};
std::vector<int> indices{1, 0, 3, 4};
GatherParameter *param = std::make_unique<GatherParameter>().release();
param->axis_ = 0;
size_t input_size, output_size;
std::string inputPpath = "./test_data/gatherfp16_input.bin";
std::string correctOutputPath = "./test_data/gatherfp16_output.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(inputPpath.c_str(), &input_size));
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
if (param == nullptr) {
return;
}
TypeId data_type = kNumberTypeFloat16;
schema::Format format = schema::Format_NHWC;
test_main_gather<float16_t>(input_data, correct_data, input_shape, indices, param, data_type, format);
}
TEST_F(TestGatherOpenCL, Axis0Fp32) {
std::vector<int> input_shape{5, 10, 10, 5};
std::vector<int> indices{1, 2, 3, 4};
GatherParameter *param = std::make_unique<GatherParameter>().release();
param->axis_ = 0;
size_t input_size, output_size;
std::string inputPpath = "./test_data/gatherfp32_input.bin";
std::string correctOutputPath = "./test_data/gatherfp32_output.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(inputPpath.c_str(), &input_size));
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
if (param == nullptr) {
return;
}
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_gather<float>(input_data, correct_data, input_shape, indices, param, data_type, format);
}
TEST_F(TestGatherOpenCL, Axis1Fp32) {
std::vector<int> input_shape{1, 5, 4, 4};
std::vector<int> indices{1, 3};
GatherParameter *param = std::make_unique<GatherParameter>().release();
GatherParameter *param = reinterpret_cast<GatherParameter *>(malloc(sizeof(GatherParameter)));
param->axis_ = 1;
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
@@ -98,11 +139,12 @@ TEST_F(TestGatherOpenCL, Axis1Fp32) {
schema::Format format = schema::Format_NHWC;
test_main_gather<float>(input_data, correct_data, input_shape, indices, param, data_type, format);
}
TEST_F(TestGatherOpenCL, Axis2Int32) {
TEST_F(TestGatherOpenCL, Axis2Fp32) {
std::vector<int> input_shape{1, 5, 4, 4};
std::vector<int> indices{1, 3};
GatherParameter *param = std::make_unique<GatherParameter>().release();
param->axis_ = 1;
param->axis_ = 2;
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
@@ -114,6 +156,25 @@ TEST_F(TestGatherOpenCL, Axis2Int32) {
}
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_gather<int>(input_data, correct_data, input_shape, indices, param, data_type, format);
test_main_gather<float>(input_data, correct_data, input_shape, indices, param, data_type, format);
}
TEST_F(TestGatherOpenCL, Axis3Fp32) {
std::vector<int> input_shape{1, 5, 4, 4};
std::vector<int> indices{1, 3};
GatherParameter *param = std::make_unique<GatherParameter>().release();
param->axis_ = 3;
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79};
float correct_data[] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39,
41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79};
if (param == nullptr) {
return;
}
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_gather<float>(input_data, correct_data, input_shape, indices, param, data_type, format);
}
} // namespace mindspore
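The Axis3Fp32 expectations can be reproduced with a tiny reference gather (an illustrative CPU implementation, not the kernel under test): picking indices {1, 3} from each group of 4 along the last axis of 0..79 yields exactly the odd values.

#include <cassert>
#include <vector>

// Minimal reference gather over the last axis of a flat NHWC buffer.
std::vector<float> GatherLastAxis(const std::vector<float> &in, int last_dim,
                                  const std::vector<int> &indices) {
  std::vector<float> out;
  for (size_t base = 0; base < in.size(); base += last_dim) {
    for (int idx : indices) out.push_back(in[base + idx]);
  }
  return out;
}

int main() {
  std::vector<float> in(80);
  for (int i = 0; i < 80; ++i) in[i] = static_cast<float>(i);
  auto out = GatherLastAxis(in, /*last_dim=*/4, {1, 3});
  assert(out.size() == 40 && out[0] == 1.0f && out[39] == 79.0f);
  return 0;
}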
