|
|
|
@ -177,6 +177,9 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
|
|
|
|
|
out_row, out_col, output->data<T>());
|
|
|
|
|
}
|
|
|
|
|
// Wait() must be called because `inputs_data` may be destructed before
|
|
|
|
|
// kernel ends
|
|
|
|
|
context.Wait();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -252,6 +255,9 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
input.data<T>(), in_row, in_col, dev_outs_col_data,
|
|
|
|
|
static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
|
|
|
|
|
}
|
|
|
|
|
// Wait() must be called because `outputs_data` may be destructed before
|
|
|
|
|
// kernel ends
|
|
|
|
|
context.Wait();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|