|
|
|
@ -17,6 +17,7 @@ limitations under the License. */
|
|
|
|
|
#include <Eigen/Dense>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/threadpool.h"
|
|
|
|
|
#include "paddle/fluid/operators/detail/safe_ref.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/algorithm.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
@ -353,6 +354,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
using paddle::framework::LoDTensor;
|
|
|
|
|
using paddle::operators::detail::Ref;
|
|
|
|
|
|
|
|
|
|
int64_t min_row_size_to_use_multithread =
|
|
|
|
|
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
|
|
|
|
|
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
|
|
|
|
|
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
|
|
|
|
|
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
|
|
|
|
@ -473,8 +476,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
lr.template data<T>(), grad_data, param.template data<T>(),
|
|
|
|
|
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
|
|
|
|
|
grad_merge.rows().size(), lazy_mode);
|
|
|
|
|
|
|
|
|
|
if (lazy_mode) {
|
|
|
|
|
VLOG(3) << "run cpu lazy mode";
|
|
|
|
|
size_t row_count = grad_merge.rows().size();
|
|
|
|
|
std::vector<int64_t> cpu_rows(grad_merge.rows());
|
|
|
|
|
for (size_t row_index = 0; row_index < row_count; ++row_index) {
|
|
|
|
@ -483,6 +486,62 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
functor.adam_update(i, grad_data[row_index * row_numel + offset]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (FLAGS_inner_op_parallelism > 1 &&
|
|
|
|
|
min_row_size_to_use_multithread > 0 &&
|
|
|
|
|
param.dims()[0] > min_row_size_to_use_multithread) {
|
|
|
|
|
VLOG(3) << "use multi thread, inner_op_parallelism="
|
|
|
|
|
<< FLAGS_inner_op_parallelism
|
|
|
|
|
<< " min_row_size_to_use_multithread="
|
|
|
|
|
<< min_row_size_to_use_multithread;
|
|
|
|
|
if (FLAGS_inner_op_parallelism > 10) {
|
|
|
|
|
VLOG(1) << "FLAGS_inner_op_parallelism "
|
|
|
|
|
<< FLAGS_inner_op_parallelism << " is two large!";
|
|
|
|
|
}
|
|
|
|
|
auto& grad_rows = grad_merge.rows();
|
|
|
|
|
std::unordered_map<size_t, int> row_id_to_grad_row_offset;
|
|
|
|
|
size_t param_row_count = param.numel() / row_numel;
|
|
|
|
|
if (param_row_count < 1000) {
|
|
|
|
|
VLOG(1) << "param_row_count should be larger then 1000 to use "
|
|
|
|
|
"multi thread, currently "
|
|
|
|
|
<< param_row_count;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < grad_rows.size(); ++i) {
|
|
|
|
|
row_id_to_grad_row_offset[grad_rows[i]] = i;
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::future<void>> fs;
|
|
|
|
|
int64_t line_in_each_thread =
|
|
|
|
|
param_row_count / FLAGS_inner_op_parallelism + 1;
|
|
|
|
|
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
|
|
|
|
|
int64_t start = i * line_in_each_thread;
|
|
|
|
|
int64_t end = (i + 1) * line_in_each_thread;
|
|
|
|
|
if (start >= param_row_count) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
if (end > param_row_count) {
|
|
|
|
|
end = param_row_count;
|
|
|
|
|
}
|
|
|
|
|
fs.push_back(
|
|
|
|
|
framework::Async([&functor, &row_id_to_grad_row_offset,
|
|
|
|
|
&grad_data, row_numel, start, end]() {
|
|
|
|
|
for (int64_t row_id = start; row_id < end; ++row_id) {
|
|
|
|
|
auto iter = row_id_to_grad_row_offset.find(row_id);
|
|
|
|
|
if (iter != row_id_to_grad_row_offset.end()) {
|
|
|
|
|
for (size_t row_offset = 0U; row_offset < row_numel;
|
|
|
|
|
++row_offset) {
|
|
|
|
|
functor.adam_update(
|
|
|
|
|
row_id * row_numel + row_offset,
|
|
|
|
|
grad_data[iter->second * row_numel + row_offset]);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t row_offset = 0U; row_offset < row_numel;
|
|
|
|
|
++row_offset) {
|
|
|
|
|
functor.adam_update(row_id * row_numel + row_offset, 0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
|
|
|
|
|
} else {
|
|
|
|
|
functor(param.numel());
|
|
|
|
|
}
|
|
|
|
|