|
|
|
@ -490,9 +490,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
<< FLAGS_inner_op_parallelism
|
|
|
|
|
<< " min_param_size_to_use_multithread="
|
|
|
|
|
<< FLAGS_min_param_size_to_use_multithread;
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
FLAGS_inner_op_parallelism, 8,
|
|
|
|
|
"FLAGS_inner_op_parallelism should not be larger then 8");
|
|
|
|
|
if (FLAGS_inner_op_parallelism > 10) {
|
|
|
|
|
LOG(WARNING) << "FLAGS_inner_op_parallelism "
|
|
|
|
|
<< FLAGS_inner_op_parallelism << " is two large!";
|
|
|
|
|
}
|
|
|
|
|
auto& grad_rows = grad_merge.rows();
|
|
|
|
|
std::unordered_map<size_t, int> row_id_to_grad_row_offset;
|
|
|
|
|
size_t param_row_count = param.numel() / row_numel;
|
|
|
|
|