|
|
|
@ -486,7 +486,9 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
functor.adam_update(i, grad_data[row_index * row_numel + offset]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (FLAGS_inner_op_parallelism > 1 &&
|
|
|
|
|
}
|
|
|
|
|
#ifndef _WIN32
|
|
|
|
|
else if (FLAGS_inner_op_parallelism > 1 &&
|
|
|
|
|
min_row_size_to_use_multithread > 0 &&
|
|
|
|
|
param.dims()[0] > min_row_size_to_use_multithread) {
|
|
|
|
|
VLOG(3) << "use multi thread, inner_op_parallelism="
|
|
|
|
@ -542,7 +544,9 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
#endif // !_WIN32
|
|
|
|
|
else {
|
|
|
|
|
functor(param.numel());
|
|
|
|
|
}
|
|
|
|
|
} else if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
|