|
|
|
@ -26,7 +26,7 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
auto in_list = context.MultiInput<framework::Tensor>("X");
|
|
|
|
|
auto* trainer_id_t = context.Input<framework::Tensor>("TrainerId");
|
|
|
|
|
int64_t trainer_id;
|
|
|
|
|
int64_t trainer_id = 0;
|
|
|
|
|
auto* trainer_id_data = trainer_id_t->data<int64_t>();
|
|
|
|
|
if (platform::is_gpu_place(context.GetPlace())) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -38,7 +38,6 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
|
|
|
|
|
} else {
|
|
|
|
|
trainer_id = *trainer_id_data;
|
|
|
|
|
}
|
|
|
|
|
printf("after get trainer_id %lu\n", trainer_id);
|
|
|
|
|
PADDLE_ENFORCE_LT(trainer_id, in_list.size());
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
out->ShareDataWith(*(in_list[trainer_id]));
|
|
|
|
|