|
|
|
@ -28,17 +28,18 @@ class ConcatKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
|
|
|
|
auto* out = ctx.Output<framework::Tensor>("Out");
|
|
|
|
|
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
|
|
|
|
|
const size_t n = ins.size();
|
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
|
out->mutable_data<T>(place);
|
|
|
|
|
|
|
|
|
|
auto out_stride = framework::stride_numel(out->dims());
|
|
|
|
|
|
|
|
|
|
size_t output_offset = 0;
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto out_stride = framework::stride(out->dims());
|
|
|
|
|
for (size_t i = 0; i < n; i++) {
|
|
|
|
|
auto& in = ins[i];
|
|
|
|
|
auto axis_dim = in->dims()[axis];
|
|
|
|
|
auto in_stride = framework::stride(in->dims());
|
|
|
|
|
StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
|
|
|
|
|
in->dims(), out_stride, out->data<T>() + output_offset);
|
|
|
|
|
output_offset += axis_dim * in_stride[axis];
|
|
|
|
|
for (auto* in : ins) {
|
|
|
|
|
auto in_stride = framework::stride_numel(in->dims());
|
|
|
|
|
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
|
|
|
|
|
out->data<T>() + output_offset, out_stride,
|
|
|
|
|
in->data<T>(), in_stride);
|
|
|
|
|
output_offset += in_stride[axis];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -50,17 +51,16 @@ class ConcatGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
|
|
|
|
|
const size_t n = outs.size();
|
|
|
|
|
size_t input_offset = 0;
|
|
|
|
|
auto in_stride = framework::stride(in->dims());
|
|
|
|
|
for (size_t i = 0; i < n; i++) {
|
|
|
|
|
auto& out = outs[i];
|
|
|
|
|
auto in_stride = framework::stride_numel(in->dims());
|
|
|
|
|
|
|
|
|
|
for (auto& out : outs) {
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
size_t axis_dim = out->dims()[axis];
|
|
|
|
|
auto out_stride = framework::stride(out->dims());
|
|
|
|
|
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
|
|
|
|
|
in_stride, out->dims(), out_stride, out->data<T>());
|
|
|
|
|
input_offset += axis_dim * in_stride[axis];
|
|
|
|
|
auto out_stride = framework::stride_numel(out->dims());
|
|
|
|
|
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
|
|
|
|
|
out_stride, in->data<T>() + input_offset,
|
|
|
|
|
in_stride);
|
|
|
|
|
input_offset += out_stride[axis];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|