fix errors in multiplex_op

scopeFix
fengjiayi 7 years ago
parent 2e617334eb
commit baa9f50da5

@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();
platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();

Loading…
Cancel
Save