concat ops support 4input

pull/6288/head
Pengyongrong 5 years ago
parent 59a63d2566
commit 11762e59db

@ -62,7 +62,7 @@ __kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image
if (Y < input_shape0.z) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Y < (input_shape0.z + input_shape0.z)) {
} else if (Y < (input_shape0.z + input_shape1.z)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else {
@ -74,7 +74,7 @@ __kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image
if (Z < input_shape0.w) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Z < (input_shape0.w + input_shape0.w)) {
} else if (Z < (input_shape0.w + input_shape1.w)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else {
@ -196,3 +196,144 @@ __kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only imag
}
}
}
__kernel void Concat4input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __read_only image2d_t input3,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1,
int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) {
return;
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < input_shape0.y + input_shape1.y) {
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
((X - input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < input_shape0.y + input_shape1.y + input_shape2.y) {
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x =
((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y +
Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y);
FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Y < input_shape0.z + input_shape1.z) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Y < input_shape0.z + input_shape1.z + input_shape2.z) {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y);
FLT4 result =
READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else {
if (Z < input_shape0.w) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Z < input_shape0.w + input_shape1.w) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
(X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Z < input_shape0.w + input_shape1.w + input_shape2.w) {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y +
(Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y);
FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
}
}
__kernel void Concat4input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __read_only image2d_t input3,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2,
int4 input_shape3, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (axis == 1) {
if (X < input_shape0.y) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (X < (input_shape0.y + input_shape1.y)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
} else {
FLT4 result3 = READ_IMAGE(input3, smp_none,
(int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Y < (input_shape0.z + input_shape1.z)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
} else {
FLT4 result3 = READ_IMAGE(
input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
}
} else {
if (Z < input_shape0.w) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Z < (input_shape0.w + input_shape1.w)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
} else {
FLT4 result3 = READ_IMAGE(input3, smp_none,
(int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
}
}
}

@ -97,8 +97,10 @@ int ConcatOpenCLKernel::Init() {
kernel_name += "2input";
} else if (in_tensors_.size() == 3) {
kernel_name += "3input";
} else if (in_tensors_.size() == 4) {
kernel_name += "4input";
} else {
MS_LOG(ERROR) << " input must be 2 or 3";
MS_LOG(ERROR) << " input must be 2 3 or 4";
return RET_ERROR;
}
if (in_format == schema::Format_NC4HW4) {
@ -193,11 +195,25 @@ int ConcatOpenCLKernel::Run() {
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape3_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_);
} else if (in_tensors_.size() < 2) {
MS_LOG(ERROR) << " input sizes must >= 2 ";
return RET_ERROR;
} else if (in_tensors_.size() == 4) {
auto input3_shape = in_tensors_[2]->shape();
auto input4_shape = in_tensors_[3]->shape();
cl_int4 input_shape3_ = {input3_shape[0], input3_shape[1], input3_shape[2], UP_DIV(input3_shape[3], C4NUM)};
cl_int4 input_shape4_ = {input4_shape[0], input4_shape[1], input4_shape[2], UP_DIV(input4_shape[3], C4NUM)};
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->MutableData());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->MutableData());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->MutableData());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[3]->MutableData());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->MutableData());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape1_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape2_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape3_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape4_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_);
} else {
MS_LOG(ERROR) << " only support inputs <= 3 ";
MS_LOG(ERROR) << " input sizes must 2 or 3 or 4";
return RET_ERROR;
}
ocl_runtime->RunKernel(kernel_, global, local, nullptr);

@ -47,22 +47,25 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
auto allocator = ocl_runtime->GetAllocator();
// get the input from .bin
size_t input1_size, input2_size, input3_size, output_size;
size_t input1_size, input2_size, input3_size, input4_size, output_size;
std::string input1Ppath = "./test_data/concatfp16_input1.bin";
std::string input2Ppath = "./test_data/concatfp16_input2.bin";
std::string input3Ppath = "./test_data/concatfp16_input3.bin";
std::string input4Ppath = "./test_data/concatfp16_input4.bin";
std::string correctOutputPath = "./test_data/concatfp16_output.bin";
auto input_data1 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto input_data3 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size));
auto input_data4 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input4Ppath.c_str(), &input4_size));
auto correctOutput =
reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
MS_LOG(INFO) << " init tensors ";
constexpr int INPUT_NUM = 2;
std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 19, 19, 96},
std::vector<int>{1, 19, 19, 96}};
std::vector<int> output_shape = {1, 19, 19, 192};
constexpr int INPUT_NUM = 4;
std::array<std::vector<int>, INPUT_NUM> input_shapes = {
std::vector<int>{1, 19, 19, 96}, std::vector<int>{1, 19, 19, 96}, std::vector<int>{1, 19, 19, 96},
std::vector<int>{1, 19, 19, 96}};
std::vector<int> output_shape = {1, 76, 19, 96};
auto data_type = kNumberTypeFloat16;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
std::vector<lite::Tensor *> inputs;
@ -97,7 +100,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
}
return;
}
param->axis_ = 3;
param->axis_ = 1;
auto *concat_kernel =
new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (concat_kernel == nullptr) {
@ -141,8 +144,13 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
memcpy(inputs[0]->MutableData(), input_data1, input1_size);
memcpy(inputs[1]->MutableData(), input_data2, input2_size);
memcpy(inputs[2]->MutableData(), input_data3, input3_size);
} else if (inputs.size() == 4) {
memcpy(inputs[0]->MutableData(), input_data1, input1_size);
memcpy(inputs[1]->MutableData(), input_data2, input2_size);
memcpy(inputs[2]->MutableData(), input_data3, input3_size);
memcpy(inputs[3]->MutableData(), input_data4, input4_size);
} else {
MS_LOG(ERROR) << " input size must be 2 or 3";
MS_LOG(ERROR) << " input size must be 2 or 3 or 4";
}
std::cout << "==================output data================" << std::endl;

Loading…
Cancel
Save