!5996 [MS][LITE][DEVELOP]Concat ops support MutiDimension

Merge pull request !5996 from pengyongrong/concatAnyDimension
pull/5996/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0bac67e819

@ -10,15 +10,7 @@ __kernel void Concat2input_NHWC4(__read_only image2d_t input0, __read_only image
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (axis == 0) {
if (X < input_shape0.x * input_shape0.y) {
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
} else {
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.x * input_shape0.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
}
} else if (axis == 1) {
if (axis == 1) {
if (X < input_shape0.y) {
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
@ -54,21 +46,7 @@ __kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (axis == 0) {
if (X < input_shape0.x * input_shape0.y) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) {
FLT4 result1 =
READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.x * input_shape0.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else {
FLT4 result2 = READ_IMAGE(
input2, smp_none,
(int2)((Y)*input_shape2.w + Z, (X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
}
} else if (axis == 1) {
if (axis == 1) {
if (X < input_shape0.y) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
@ -121,18 +99,7 @@ __kernel void Concat2input_NC4HW4(__read_only image2d_t input0, __read_only imag
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 0) {
if (X < (input_shape0.x * input_shape0.y)) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y +
Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 1) {
if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
@ -181,25 +148,7 @@ __kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only imag
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 0) {
if (X < (input_shape0.x * input_shape0.y)) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) {
in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y +
Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) / input_shape2.y) *
input_shape2.w * input_shape2.y +
Z * input_shape2.y +
(X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) % input_shape2.y;
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 1) {
if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));

@ -59,10 +59,10 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_float(__global float4 *src_data, __wr
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) {
return;
}
int offset = (X * shape.z + Y) * shape.w + Z * 4;
int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4;
__global float *src_addr = (__global float *)src_data;
src_addr += offset;
FLT4 data = (FLT4)(0.f);
@ -79,17 +79,18 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_float(__global float4 *src_data, __wr
data.z = (FLT)src_addr[2];
}
}
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data);
int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y;
WRITE_IMAGE(dst_data, (int2)(Y, pos_ix), data);
}
__kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) {
return;
}
int offset = (X * shape.z + Y) * shape.w + Z * 4;
int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4;
__global half *src_addr = (__global half *)src_data;
src_addr += offset;
FLT4 data = (FLT4)(0.f);
@ -106,7 +107,8 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __writ
data.z = (FLT)src_addr[2];
}
}
WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data);
int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y;
WRITE_IMAGE(dst_data, (int2)(Y, pos_ix), data);
}
__kernel void to_format_NHWC4_to_NHWC4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
@ -227,11 +229,12 @@ __kernel void to_format_NC4HW4_to_NHWC_BUF_float(__read_only image2d_t src_data,
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) {
return;
}
float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X)));
int offset = (X * shape.z + Y) * shape.w + Z * 4;
int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y;
float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, pos_ix)));
int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4;
__global float *dst_addr = (__global float *)dst_data;
dst_addr += offset;
if ((Z + 1) * 4 <= shape.w) {
@ -253,11 +256,12 @@ __kernel void to_format_NC4HW4_to_NHWC_BUF_half(__read_only image2d_t src_data,
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= size.x || Y >= size.y || Z >= size.z) {
if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) {
return;
}
half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X)));
int offset = (X * shape.z + Y) * shape.w + Z * 4;
int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y;
half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, pos_ix)));
int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4;
__global half *dst_addr = (__global half *)dst_data;
dst_addr += offset;
if ((Z + 1) * 4 <= shape.w) {

@ -49,6 +49,26 @@ int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size)
*img_size = vec;
return RET_OK;
}
int ConcatOpenCLKernel::RunAxis0() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto allocator_ = ocl_runtime->GetAllocator();
std::vector<size_t> img_size;
auto dst_data = out_tensors_[0]->MutableData();
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
for (int i = 0; i < in_tensors_.size(); i++) {
auto src_data = in_tensors_[i]->MutableData();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
cl::Image2D *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region);
dst_origin[1] += region[1];
}
return RET_OK;
}
int ConcatOpenCLKernel::Init() {
if (in_tensors_[0]->shape().size() != 4) {
MS_LOG(ERROR) << " only support dim = 4 ";
@ -98,6 +118,19 @@ int ConcatOpenCLKernel::Init() {
int ConcatOpenCLKernel::ReSize() { return RET_OK; }
int ConcatOpenCLKernel::GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape) {
std::vector<int> temp_sum = {0, 0, 0, 0};
for (int i = 0; i < in_tensors_.size(); ++i) {
auto temp = in_tensors_[i]->shape();
for (int j = 0; j < temp.size(); ++j) {
in_shape->push_back(temp[j]);
temp_sum.at(j) += temp[j];
sum_shape->push_back(temp_sum.at(j));
}
}
return RET_OK;
}
int ConcatGetBiggestDividerWithPriority(int number, int max_divider) {
if (number % 8 == 0 && max_divider >= 8) {
return number / 8;
@ -133,6 +166,9 @@ void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *
int ConcatOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto param = reinterpret_cast<ConcatParameter *>(this->op_parameter_);
if (param->axis_ == 0) {
return RunAxis0();
}
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto input1_shape = in_tensors_[0]->shape();
@ -151,6 +187,7 @@ int ConcatOpenCLKernel::Run() {
std::vector<size_t> local = {1, 1, 1}; // init local
std::vector<size_t> global = {OH, OW, OC};
ConcatGetWorkGroup(global, &local, max_global[0]);
GetSumShape(&sum_shape, &in_shape);
int arg_cn = 0;
if (in_tensors_.size() == 2) {

@ -38,10 +38,17 @@ class ConcatOpenCLKernel : public OpenCLKernel {
int ReSize() override;
int Run() override;
int RunAxis0();
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
int GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape);
private:
cl::Kernel kernel_;
std::vector<int> sum_shape;
std::vector<int> in_shape;
};
} // namespace mindspore::kernel

@ -62,19 +62,19 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
constexpr int INPUT_NUM = 2;
std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 19, 19, 96},
std::vector<int>{1, 19, 19, 96}};
std::vector<int> output_shape = {2, 19, 19, 96};
std::vector<int> output_shape = {1, 19, 19, 192};
auto data_type = kNumberTypeFloat16;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
std::vector<lite::Tensor *> inputs;
for (auto &shape : input_shapes) {
auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC4, tensor_type);
auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type);
inputs.push_back(input_temp);
if (input_temp == nullptr) {
MS_LOG(INFO) << " new input_tensor failed ";
return;
}
}
auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type);
auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
if (output_tensor == nullptr) {
MS_LOG(INFO) << " new output_tensor failed ";
for (auto tensor : inputs) {
@ -97,7 +97,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
}
return;
}
param->axis_ = 0;
param->axis_ = 3;
auto *concat_kernel =
new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (concat_kernel == nullptr) {
@ -111,6 +111,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
delete param;
return;
}
concat_kernel->SetFormatType(schema::Format_NC4HW4);
concat_kernel->Init();
// to do allocate memory for inputs and outputs
for (auto &input_tensor : inputs) {
@ -229,8 +230,9 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
delete param;
return;
}
concat_kernel->SetFormatType(schema::Format_NC4HW4);
concat_kernel->Init();
// to do allocate memory for inputs and outputs
// to do allocate memory for inputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}

Loading…
Cancel
Save