@@ -473,10 +473,19 @@ class AdamOpKernel : public framework::OpKernel<T> {
             lr.template data<T>(), grad_data, param.template data<T>(),
             param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
             grad_merge.rows().size(), lazy_mode);
-        // multi thread speedup
-        if (FLAGS_inner_op_parallelism > 1 &&
-            FLAGS_min_param_size_to_use_multithread > 0 &&
-            param.numel() > FLAGS_min_param_size_to_use_multithread) {
+        if (lazy_mode) {
+          VLOG(3) << "run cpu lazy mode";
+          size_t row_count = grad_merge.rows().size();
+          std::vector<int64_t> cpu_rows(grad_merge.rows());
+          for (size_t row_index = 0; row_index < row_count; ++row_index) {
+            for (size_t offset = 0; offset < row_numel; ++offset) {
+              size_t i = cpu_rows[row_index] * row_numel + offset;
+              functor.adam_update(i, grad_data[row_index * row_numel + offset]);
+            }
+          }
+        } else if (FLAGS_inner_op_parallelism > 1 &&
+                   FLAGS_min_param_size_to_use_multithread > 0 &&
+                   param.numel() > FLAGS_min_param_size_to_use_multithread) {
           VLOG(3) << "use multi thread, inner_op_parallelism="
                   << FLAGS_inner_op_parallelism
                   << " min_param_size_to_use_multithread="
@@ -508,20 +517,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
           for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
         } else {
-          if (lazy_mode) {
-            VLOG(3) << "run cpu lazy mode";
-            size_t row_count = grad_merge.rows().size();
-            std::vector<int64_t> cpu_rows(grad_merge.rows());
-            for (size_t row_index = 0; row_index < row_count; ++row_index) {
-              for (size_t offset = 0; offset < row_numel; ++offset) {
-                size_t i = cpu_rows[row_index] * row_numel + offset;
-                functor.adam_update(i,
-                                    grad_data[row_index * row_numel + offset]);
-              }
-            }
-          } else {
-            functor(param.numel());
-          }
+          functor(param.numel());
         }
       } else if (platform::is_gpu_place(ctx.GetPlace())) {
         SparseAdamFunctor<T, GPUAdam> functor(