|
|
|
@ -37,8 +37,9 @@ class ConcatKernel : public framework::OpKernel<T> {
|
|
|
|
|
size_t output_offset = 0;
|
|
|
|
|
for (auto* in : ins) {
|
|
|
|
|
auto in_stride = framework::stride_numel(in->dims());
|
|
|
|
|
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>() + output_offset,
|
|
|
|
|
out_stride, in->data<T>(), in_stride);
|
|
|
|
|
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
|
|
|
|
|
out->data<T>() + output_offset, out_stride,
|
|
|
|
|
in->data<T>(), in_stride);
|
|
|
|
|
output_offset += in_stride[axis];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -57,8 +58,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (auto& out : outs) {
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto out_stride = framework::stride_numel(out->dims());
|
|
|
|
|
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
|
|
|
|
|
in->data<T>() + input_offset, in_stride);
|
|
|
|
|
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
|
|
|
|
|
out_stride, in->data<T>() + input_offset,
|
|
|
|
|
in_stride);
|
|
|
|
|
input_offset += out_stride[axis];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|