|
|
|
@ -48,16 +48,16 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
// computation
|
|
|
|
|
for (int k = 0; k < out_rows; ++k) {
|
|
|
|
|
T* dst_ptr = output->data<T>() + k * out_cols;
|
|
|
|
|
auto output_data = output->data<T>();
|
|
|
|
|
int col_idx = 0;
|
|
|
|
|
for (int j = 0; j < num; ++j) {
|
|
|
|
|
int col_len = input_cols[j];
|
|
|
|
|
const T* src_prt = input[j].data<T>() + k * col_len;
|
|
|
|
|
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
|
|
|
|
|
sizeof(T) * col_len);
|
|
|
|
|
col_idx += col_len;
|
|
|
|
|
auto input_data = input[j].data<T>();
|
|
|
|
|
for (int k = 0; k < out_rows; ++k) {
|
|
|
|
|
memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
|
|
|
|
|
input_data + k * col_len, sizeof(T) * col_len);
|
|
|
|
|
}
|
|
|
|
|
col_idx += col_len;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|