|
|
|
|
@ -147,21 +147,18 @@ void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *
|
|
|
|
|
local->push_back(z);
|
|
|
|
|
}
|
|
|
|
|
int ConcatOpenCLKernel::Run() {
|
|
|
|
|
MS_LOG(DEBUG) << this->Name() << " Running!";
|
|
|
|
|
auto param = reinterpret_cast<ConcatParameter *>(this->opParameter);
|
|
|
|
|
if (param->axis_ == 0) {
|
|
|
|
|
return Run_axis0();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
|
|
|
|
MS_LOG(INFO) << " judge the numbers of input vector";
|
|
|
|
|
auto input0_shape = inputs_[0]->shape();
|
|
|
|
|
auto input1_shape = inputs_[1]->shape();
|
|
|
|
|
auto input2_shape = inputs_[2]->shape();
|
|
|
|
|
auto output_shape = outputs_[0]->shape();
|
|
|
|
|
|
|
|
|
|
cl_int2 input0_shape2_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4)}; // change
|
|
|
|
|
cl_int3 input0_shape3_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4),
|
|
|
|
|
DivideRoundUp(input2_shape[3], 4)};
|
|
|
|
|
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], DivideRoundUp(output_shape[3], 4)};
|
|
|
|
|
|
|
|
|
|
uint32_t OH = output_shape[0] * output_shape[1]; // N*H
|
|
|
|
|
@ -173,14 +170,15 @@ int ConcatOpenCLKernel::Run() {
|
|
|
|
|
|
|
|
|
|
int arg_cn = 0;
|
|
|
|
|
if (inputs_.size() == 2) {
|
|
|
|
|
MS_LOG(INFO) << " SetKernelArg";
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data());
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data());
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data());
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape2_);
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_);
|
|
|
|
|
} else if (inputs_.size() == 3) {
|
|
|
|
|
MS_LOG(INFO) << " SetKernelArg";
|
|
|
|
|
auto input2_shape = inputs_[2]->shape();
|
|
|
|
|
cl_int3 input0_shape3_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4),
|
|
|
|
|
DivideRoundUp(input2_shape[3], 4)};
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data());
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data());
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data());
|
|
|
|
|
|