|
|
|
@ -71,7 +71,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
|
|
|
|
|
auto* index = index_t_cpu.data<int32_t>();
|
|
|
|
|
|
|
|
|
|
auto stream = ctx.device_context().stream();
|
|
|
|
|
auto stream = ctx.cuda_device_context().stream();
|
|
|
|
|
Place place = boost::get<Place>(ctx.GetPlace());
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
size_t k = static_cast<size_t>(index[i]);
|
|
|
|
|