fix compile

mobile_baidu
typhoonzero 7 years ago
parent fd7ed3b9c6
commit 579c92abc3

@ -71,7 +71,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.device_context().stream();
auto stream = ctx.cuda_device_context().stream();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
size_t k = static_cast<size_t>(index[i]);

Loading…
Cancel
Save