@@ -17,31 +17,56 @@
+
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/memory/memcpy.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
-class MultiplexCPUKernel : public framework::OpKernel {
+template <typename Place, typename T>
+class MultiplexKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto* out = ctx.Output<framework::LoDTensor>("Out");
 
     out->mutable_data<T>(ctx.GetPlace());
 
-    auto index = ins[0]->data<T>();
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
-      memcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
-             cols * sizeof(T));
+    if (platform::is_cpu_place(ctx.GetPlace())) {
+      auto* index = ins[0]->data<T>();
+      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        PADDLE_ENFORCE_LT(k, ins.size(),
+                          "index exceeds the number of candidate tensors.");
+        memory::Copy(place, out->data<T>() + i * cols, place,
+                     ins[k]->data<T>() + i * cols, cols * sizeof(T));
+      }
+    } else {
+#ifndef PADDLE_ONLY_CPU
+      // copy index to cpu
+      framework::Tensor index_t_cpu;
+      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+      auto* index = index_t_cpu.data<T>();
+      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                        ctx.device_context())
+                        .stream();
+      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        PADDLE_ENFORCE_LT(k, ins.size(),
+                          "index exceeds the number of candidate tensors.");
+        memory::Copy(place, out->data<T>() + i * cols, place,
+                     ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
+      }
+#endif
     }
   }
 };
 
-template <typename T>
-class MultiplexGradCPUKernel : public framework::OpKernel {
+template <typename Place, typename T>
+class MultiplexGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
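The forward hunk above replaces the CPU-only memcpy loop with a Place-templated kernel that branches on the place at runtime: the CPU path uses memory::Copy directly, while the GPU path first stages the index tensor to host memory (device pointers cannot be dereferenced on the host) and then issues per-row copies on the device stream. A standalone sketch of the selection rule the kernel implements, not Paddle code, with names chosen here for illustration:

    // Row i of the output is copied from candidate tensor ins[index[i] + 1];
    // ins[0] is the index tensor, so candidate tensors start at ins[1].
    #include <cstring>
    #include <vector>

    void multiplex_rows(const float* index, int rows, int cols,
                        const std::vector<const float*>& ins, float* out) {
      for (int i = 0; i < rows; i++) {
        int k = static_cast<int>(index[i]) + 1;  // skip ins[0], the index
        std::memcpy(out + i * cols, ins[k] + i * cols, cols * sizeof(float));
      }
    }

The gradient hunk below applies the same restructuring to the backward pass.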
@@ -51,20 +76,42 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
     for (size_t i = 1; i < d_ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->mutable_data<T>(ctx.GetPlace());
-        auto dims = d_ins[i]->dims();
-        memset(d_ins[i]->data<T>(), 0, framework::product(dims) * sizeof(T));
+        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
+        t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
       }
     }
 
-    auto index = ins[0]->data<T>();
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
-      if (d_ins[k]) {
-        memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
-               cols * sizeof(T));
+    if (platform::is_cpu_place(ctx.GetPlace())) {
+      auto* index = ins[0]->data<T>();
+      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        if (d_ins[k]) {
+          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                       d_out->data<T>() + i * cols, cols * sizeof(T));
+        }
       }
+    } else {
+#ifndef PADDLE_ONLY_CPU
+      // copy index to cpu
+      framework::Tensor index_t_cpu;
+      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+      auto* index = index_t_cpu.data<T>();
+
+      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                        ctx.device_context())
+                        .stream();
+      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
+      for (auto i = 0; i < rows; i++) {
+        int k = (int)index[i] + 1;
+        if (d_ins[k]) {
+          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                       d_out->data<T>() + i * cols, cols * sizeof(T), stream);
+        }
+      }
+#endif
     }
   }
 };
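The gradient hunk mirrors the forward one, and additionally swaps the host-only memset zero-fill for an Eigen t.constant(0) assignment, which runs on whichever device GetEigenDevice<Place>() names. With both kernels templated on Place, one class can serve both devices; a minimal sketch of the registrations this enables, assuming the op names multiplex/multiplex_grad, the conventional "namespace ops = paddle::operators;" alias, and the REGISTER_OP_*_KERNEL macros of this Paddle version (all assumptions, not part of the diff):

    // multiplex_op.cc, CPU registration (sketch; names are assumptions)
    REGISTER_OP_CPU_KERNEL(
        multiplex, ops::MultiplexKernel<paddle::platform::CPUPlace, float>);
    REGISTER_OP_CPU_KERNEL(
        multiplex_grad,
        ops::MultiplexGradKernel<paddle::platform::CPUPlace, float>);

    // multiplex_op.cu, GPU registration, built only when PADDLE_ONLY_CPU
    // is not defined (sketch; names are assumptions)
    REGISTER_OP_GPU_KERNEL(
        multiplex, ops::MultiplexKernel<paddle::platform::GPUPlace, float>);
    REGISTER_OP_GPU_KERNEL(
        multiplex_grad,
        ops::MultiplexGradKernel<paddle::platform::GPUPlace, float>);

Registering the same template for both places is what makes the runtime is_cpu_place branch inside Compute necessary: a single compiled kernel body must handle whichever place the executor hands it.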