|
|
|
@ -490,9 +490,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
<< FLAGS_inner_op_parallelism
|
|
|
|
|
<< " min_param_size_to_use_multithread="
|
|
|
|
|
<< FLAGS_min_param_size_to_use_multithread;
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
FLAGS_inner_op_parallelism, 8,
|
|
|
|
|
"FLAGS_inner_op_parallelism should not be larger then 8");
|
|
|
|
|
auto& grad_rows = grad_merge.rows();
|
|
|
|
|
std::unordered_map<size_t, int> row_id_to_grad_row_offset;
|
|
|
|
|
size_t param_row_count = param.numel() / row_numel;
|
|
|
|
|
if (param_row_count < 1000) {
|
|
|
|
|
LOG(WARNING) << "param_row_count should be larger then 1000 to use "
|
|
|
|
|
"multi thread, currently "
|
|
|
|
|
<< param_row_count;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < param_row_count; ++i) {
|
|
|
|
|
row_id_to_grad_row_offset[i] = -1;
|
|
|
|
|
}
|
|
|
|
@ -501,10 +509,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::future<void>> fs;
|
|
|
|
|
int64_t line_in_each_thread =
|
|
|
|
|
param_row_count / FLAGS_inner_op_parallelism;
|
|
|
|
|
param_row_count / FLAGS_inner_op_parallelism + 1;
|
|
|
|
|
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
|
|
|
|
|
int64_t start = i * line_in_each_thread;
|
|
|
|
|
int64_t end = (i + 1) * line_in_each_thread;
|
|
|
|
|
if (start >= param_row_count) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
if (end > param_row_count) {
|
|
|
|
|
end = param_row_count;
|
|
|
|
|
}
|
|
|
|
|