|
|
|
@ -23,7 +23,7 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class MultiplexKernel : public framework::OpKernel {
|
|
|
|
|
class MultiplexCPUKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
|
|
|
@ -33,40 +33,20 @@ class MultiplexKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
auto rows = ins[1]->dims()[0];
|
|
|
|
|
auto cols = ins[1]->dims()[1];
|
|
|
|
|
if (platform::is_cpu_place(ctx.GetPlace())) {
|
|
|
|
|
auto* index = ins[0]->data<T>();
|
|
|
|
|
platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
PADDLE_ENFORCE_LT(k, ins.size(),
|
|
|
|
|
"index exceeds the number of candidate tensors.");
|
|
|
|
|
memory::Copy(place, out->data<T>() + i * cols, place,
|
|
|
|
|
ins[k]->data<T>() + i * cols, cols * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
// copy index to cpu
|
|
|
|
|
framework::Tensor index_t_cpu;
|
|
|
|
|
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
|
|
|
|
|
auto* index = index_t_cpu.data<T>();
|
|
|
|
|
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream();
|
|
|
|
|
platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
PADDLE_ENFORCE_LT(k, ins.size(),
|
|
|
|
|
"index exceeds the number of candidate tensors.");
|
|
|
|
|
memory::Copy(place, out->data<T>() + i * cols, place,
|
|
|
|
|
ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
auto* index = ins[0]->data<T>();
|
|
|
|
|
Place place = boost::get<Place>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
PADDLE_ENFORCE_LT(k, ins.size(),
|
|
|
|
|
"index exceeds the number of candidate tensors.");
|
|
|
|
|
memory::Copy(place, out->data<T>() + i * cols, place,
|
|
|
|
|
ins[k]->data<T>() + i * cols, cols * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class MultiplexGradKernel : public framework::OpKernel {
|
|
|
|
|
class MultiplexGradCPUKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
@ -83,35 +63,14 @@ class MultiplexGradKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
auto rows = ins[1]->dims()[0];
|
|
|
|
|
auto cols = ins[1]->dims()[1];
|
|
|
|
|
if (platform::is_cpu_place(ctx.GetPlace())) {
|
|
|
|
|
auto* index = ins[0]->data<T>();
|
|
|
|
|
platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
if (d_ins[k]) {
|
|
|
|
|
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
|
|
|
|
|
d_out->data<T>() + i * cols, cols * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
// copy index to cpu
|
|
|
|
|
framework::Tensor index_t_cpu;
|
|
|
|
|
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
|
|
|
|
|
auto* index = index_t_cpu.data<T>();
|
|
|
|
|
|
|
|
|
|
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream();
|
|
|
|
|
platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
if (d_ins[k]) {
|
|
|
|
|
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
|
|
|
|
|
d_out->data<T>() + i * cols, cols * sizeof(T), stream);
|
|
|
|
|
}
|
|
|
|
|
auto* index = ins[0]->data<T>();
|
|
|
|
|
Place place = boost::get<Place>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int k = (int)index[i] + 1;
|
|
|
|
|
if (d_ins[k]) {
|
|
|
|
|
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
|
|
|
|
|
d_out->data<T>() + i * cols, cols * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|