|
|
|
@ -18,19 +18,20 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class MultiplexGPUKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
|
|
|
|
auto* out = ctx.Output<framework::LoDTensor>("Out");
|
|
|
|
|
|
|
|
|
|
auto ins = ctx.MultiInput<Tensor>("X");
|
|
|
|
|
auto* out = ctx.Output<Tensor>("Out");
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto rows = ins[1]->dims()[0];
|
|
|
|
|
auto cols = ins[1]->dims()[1];
|
|
|
|
|
// copy index to cpu
|
|
|
|
|
framework::Tensor index_t_cpu;
|
|
|
|
|
Tensor index_t_cpu;
|
|
|
|
|
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
|
|
|
|
|
auto* index = index_t_cpu.data<T>();
|
|
|
|
|
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
@ -38,7 +39,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
|
|
|
|
|
.stream();
|
|
|
|
|
Place place = boost::get<Place>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
size_t k = (size_t)index[i] + 1;
|
|
|
|
|
PADDLE_ENFORCE_LT(k, ins.size(),
|
|
|
|
|
"index exceeds the number of candidate tensors.");
|
|
|
|
|
memory::Copy(place, out->data<T>() + i * cols, place,
|
|
|
|
@ -51,10 +52,9 @@ template <typename Place, typename T>
|
|
|
|
|
class MultiplexGradGPUKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
|
|
|
|
auto d_ins =
|
|
|
|
|
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto ins = ctx.MultiInput<Tensor>("X");
|
|
|
|
|
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
for (size_t i = 1; i < d_ins.size(); i++) {
|
|
|
|
|
if (d_ins[i]) {
|
|
|
|
|
d_ins[i]->mutable_data<T>(ctx.GetPlace());
|
|
|
|
@ -66,7 +66,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
|
|
|
|
|
auto rows = ins[1]->dims()[0];
|
|
|
|
|
auto cols = ins[1]->dims()[1];
|
|
|
|
|
// copy index to cpu
|
|
|
|
|
framework::Tensor index_t_cpu;
|
|
|
|
|
Tensor index_t_cpu;
|
|
|
|
|
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
|
|
|
|
|
auto* index = index_t_cpu.data<T>();
|
|
|
|
|
|
|
|
|
@ -75,7 +75,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
|
|
|
|
|
.stream();
|
|
|
|
|
Place place = boost::get<Place>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
size_t k = (size_t)index[i] + 1;
|
|
|
|
|
if (d_ins[k]) {
|
|
|
|
|
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
|
|
|
|
|
d_out->data<T>() + i * cols, cols * sizeof(T), stream);
|
|
|
|
|