@@ -15,8 +15,8 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
+#include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
 namespace operators {
@@ -28,17 +28,38 @@ class ConcatKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto* out = ctx.Output<framework::Tensor>("Out");
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    const size_t n = ins.size();
+    auto place = ctx.GetPlace();
+    out->mutable_data<T>(place);
+
+    auto out_stride = framework::stride_numel(out->dims());
+    int64_t before = out_stride[0] / out_stride[axis];
+    int64_t out_after = out_stride[axis];
+
     size_t output_offset = 0;
-    out->mutable_data<T>(ctx.GetPlace());
-    auto out_stride = framework::stride(out->dims());
-    for (size_t i = 0; i < n; i++) {
-      auto& in = ins[i];
-      auto axis_dim = in->dims()[axis];
-      auto in_stride = framework::stride(in->dims());
-      StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
-                       in->dims(), out_stride, out->data<T>() + output_offset);
-      output_offset += axis_dim * in_stride[axis];
+    for (auto* in : ins) {
+      auto in_stride = framework::stride_numel(in->dims());
+      int64_t in_after = in_stride[axis];
+      for (int64_t i = 0; i < before; ++i) {
+        if (platform::is_cpu_place(place)) {
+          auto& cpu_place = boost::get<platform::CPUPlace>(place);
+          memory::Copy(
+              cpu_place, out->data<T>() + output_offset + i * out_after,
+              cpu_place, in->data<T>() + i * in_after, sizeof(T) * in_after);
+        } else {
+#ifdef PADDLE_WITH_CUDA
+          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
+          auto& cuda_ctx = reinterpret_cast<const platform::CUDADeviceContext&>(
+              ctx.device_context());
+          memory::Copy(gpu_place,
+                       out->data<T>() + output_offset + i * out_after,
+                       gpu_place, in->data<T>() + i * in_after,
+                       sizeof(T) * in_after, cuda_ctx.stream());
+#else
+          PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+        }
+      }
+      output_offset += in_after;
     }
   }
 };
@@ -50,17 +71,37 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    const size_t n = outs.size();
     size_t input_offset = 0;
-    auto in_stride = framework::stride(in->dims());
-    for (size_t i = 0; i < n; i++) {
-      auto& out = outs[i];
+    auto in_stride = framework::stride_numel(in->dims());
+    auto place = ctx.GetPlace();
+
+    // numel before the specified axis
+    int64_t before = in_stride[0] / in_stride[axis];
+    int64_t in_after = in_stride[axis];
+    for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
-      size_t axis_dim = out->dims()[axis];
-      auto out_stride = framework::stride(out->dims());
-      StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
-                       in_stride, out->dims(), out_stride, out->data<T>());
-      input_offset += axis_dim * in_stride[axis];
+      auto out_stride = framework::stride_numel(out->dims());
+      int64_t out_after = out_stride[axis];
+      for (int64_t i = 0; i < before; ++i) {
+        if (platform::is_cpu_place(place)) {
+          auto& cpu_place = boost::get<platform::CPUPlace>(place);
+          memory::Copy(cpu_place, out->data<T>() + i * out_after, cpu_place,
+                       in->data<T>() + input_offset + i * in_after,
+                       sizeof(T) * out_after);
+        } else {
+#ifdef PADDLE_WITH_CUDA
+          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
+          auto& cuda_ctx = reinterpret_cast<const platform::CUDADeviceContext&>(
+              ctx.device_context());
+          memory::Copy(gpu_place, out->data<T>() + i * out_after, gpu_place,
+                       in->data<T>() + input_offset + i * in_after,
+                       sizeof(T) * out_after, cuda_ctx.stream());
+#else
+          PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+        }
+      }
+      input_offset += out_after;
     }
   }
 };
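
Note: the sketch below is an illustration only, not Paddle code. The new kernels view each tensor as a row-major matrix with `before` rows (stride_numel(dims)[0] / stride_numel(dims)[axis], i.e. the product of the dimensions in front of axis) and `after` columns (stride_numel(dims)[axis], the number of elements from axis onward). Concat then reduces to one contiguous memory::Copy per (input, row) pair instead of a generic strided copy, and the gradient kernel is the same loop with source and destination swapped. The sketch reproduces that offset arithmetic with plain std::vector buffers and std::memcpy, fixes the concat axis to the last dimension of a 2-D shape, and hard-codes two inputs.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Concatenate two row-major blocks of shape (before x a_cols) and
// (before x b_cols) along the last axis into (before x (a_cols + b_cols)).
std::vector<float> ConcatLastAxis(const std::vector<float>& a, int64_t a_cols,
                                  const std::vector<float>& b, int64_t b_cols,
                                  int64_t before) {
  const int64_t out_after = a_cols + b_cols;  // elements per output row
  std::vector<float> out(before * out_after);
  int64_t output_offset = 0;  // column where the current input starts
  // Same structure as the patched ConcatKernel: one contiguous copy per
  // (input, row) pair.
  for (auto p : {std::make_pair(&a, a_cols), std::make_pair(&b, b_cols)}) {
    const std::vector<float>& in = *p.first;
    const int64_t in_after = p.second;  // elements this input adds per row
    for (int64_t i = 0; i < before; ++i) {
      std::memcpy(out.data() + output_offset + i * out_after,
                  in.data() + i * in_after, sizeof(float) * in_after);
    }
    output_offset += in_after;
  }
  return out;
}

int main() {
  // Inputs of shape (2 x 2) and (2 x 3) concatenated into (2 x 5).
  std::vector<float> a = {1, 2, 3, 4};
  std::vector<float> b = {5, 6, 7, 8, 9, 10};
  std::vector<float> out = ConcatLastAxis(a, 2, b, 3, 2);
  assert(out.size() == 10);
  for (float v : out) std::printf("%g ", v);  // prints: 1 2 5 6 7 3 4 8 9 10
  std::printf("\n");
  return 0;
}

Compiled with any C++11 compiler, the sketch prints 1 2 5 6 7 3 4 8 9 10, the expected (2 x 5) result.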