enhance the forward of concat op

createGenDocLib
luotao1 7 years ago
parent 557be6fc58
commit 2b4edacca0

@ -48,16 +48,16 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace()); auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation // computation
for (int k = 0; k < out_rows; ++k) { auto output_data = output->data<T>();
T* dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0; int col_idx = 0;
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
int col_len = input_cols[j]; int col_len = input_cols[j];
const T* src_prt = input[j].data<T>() + k * col_len; auto input_data = input[j].data<T>();
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt, for (int k = 0; k < out_rows; ++k) {
sizeof(T) * col_len); memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
col_idx += col_len; input_data + k * col_len, sizeof(T) * col_len);
} }
col_idx += col_len;
} }
} }
}; };

Loading…
Cancel
Save