@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>  // for sqrt in CPU and CUDA
 #include <Eigen/Dense>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
@@ -311,17 +312,17 @@ struct SparseAdamFunctor<T, CPUAdam> {
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
     lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-    size_t row_count = numel / row_numel_;
+    int64_t row_count = static_cast<int64_t>(numel / row_numel_);
 
-    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
+    for (int64_t i = 0, j = 0; i != row_count; ++i) {
       if (i == *(rows_ + j)) {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T g = grad_[j * row_numel_ + k];
           adam_update(i * row_numel_ + k, g);
         }
         ++j;
       } else {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T mom1 = moment1_[i * row_numel_ + k];
           T mom2 = moment2_[i * row_numel_ + k];
           T p = param_[i * row_numel_ + k];
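
Note on the hunk above: the sparse CPU functor's indices move from size_t to int64_t so every comparison against the signed row ids in rows_ and against row_count stays signed-vs-signed and no longer trips -Wsign-compare. A minimal standalone sketch of that pattern (illustrative only, not Paddle code; the row ids and counts are made up):

    // sign_compare_sketch.cc -- illustrative only, not part of adam_op.h.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int64_t> rows = {0, 2, 5};  // sorted sparse row ids (signed)
      const int64_t row_count = 6;

      // With `size_t i`, the comparisons below mix signed and unsigned and
      // trigger -Wsign-compare; using int64_t keeps every operand signed.
      for (int64_t i = 0, j = 0; i != row_count; ++i) {
        if (j < static_cast<int64_t>(rows.size()) && i == rows[j]) {
          std::printf("row %lld has a gradient entry\n",
                      static_cast<long long>(i));
          ++j;
        }
      }
      return 0;
    }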
@@ -427,43 +428,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
       }
 
-      framework::SelectedRows cpu_grad_merge;
+      framework::SelectedRows tmp_grad_merge;
       const framework::SelectedRows* grad_merge_ptr;
       if (is_strict_sorted) {
         grad_merge_ptr = &grad;
       } else {
         // merge duplicated rows if any.
         // The rows of grad_merge have been sorted inside MergeAdd functor
-        framework::SelectedRows* grad_merge_var;
         scatter::MergeAdd<DeviceContext, T> merge_func;
-        if (platform::is_cpu_place(ctx.GetPlace())) {
-          grad_merge_var = &cpu_grad_merge;
-        } else {
-          // FIXME(qiao): GPU also need to fix this
-          grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
-                               .Var()
-                               ->GetMutable<framework::SelectedRows>();
-        }
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var, true);
-        grad_merge_ptr = grad_merge_var;
+                   &tmp_grad_merge, true);
+        grad_merge_ptr = &tmp_grad_merge;
       }
 
       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = nullptr;
-// When compiled without CUDA, the CUDAData() interface should not be
-// provided.
-#if defined(PADDLE_WITH_CUDA)
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = grad_merge.rows().CUDAData(ctx.GetPlace());
-      } else {
-#endif
-        rows = grad_merge.rows().data();
-#if defined(PADDLE_WITH_CUDA)
-      }
-#endif
+      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
 
       if (platform::is_cpu_place(ctx.GetPlace())) {
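
Note on the hunk above: the #if defined(PADDLE_WITH_CUDA) ladder around grad_merge.rows() collapses into a single rows().Data(ctx.GetPlace()) call. The sketch below shows the kind of place-aware accessor such a call presumably relies on; RowVector, Place, and the member names are illustrative assumptions, not the real paddle::framework::Vector implementation:

    // data_place_sketch.h -- illustrative only.
    #include <cstdint>
    #include <utility>
    #include <vector>

    enum class Place { kCPU, kGPU };

    class RowVector {
     public:
      explicit RowVector(std::vector<int64_t> rows) : cpu_(std::move(rows)) {}

      // One call site for both devices: branch on the place inside the
      // container instead of wrapping every caller in #ifdef PADDLE_WITH_CUDA.
      const int64_t* Data(Place place) const {
    #if defined(PADDLE_WITH_CUDA)
        if (place == Place::kGPU) {
          return CUDAData(place);  // device pointer (copy-to-GPU elided here)
        }
    #endif
        return cpu_.data();  // host pointer
      }

     private:
    #if defined(PADDLE_WITH_CUDA)
      const int64_t* CUDAData(Place place) const;  // defined elsewhere
    #endif
      std::vector<int64_t> cpu_;
    };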
@@ -488,7 +469,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
         }
 #ifndef _WIN32
-        else if (FLAGS_inner_op_parallelism > 1 &&
+        else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
                  min_row_size_to_use_multithread > 0 &&
                  param.dims()[0] > min_row_size_to_use_multithread) {
           VLOG(3) << "use multi thread, inner_op_parallelism="
@@ -516,11 +497,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
             int64_t start = i * line_in_each_thread;
             int64_t end = (i + 1) * line_in_each_thread;
-            if (start >= param_row_count) {
+            if (start >= static_cast<int64_t>(param_row_count)) {
               break;
             }
-            if (end > param_row_count) {
-              end = param_row_count;
+            if (end > static_cast<int64_t>(param_row_count)) {
+              end = static_cast<int64_t>(param_row_count);
             }
             fs.push_back(
                 framework::Async([&functor, &row_id_to_grad_row_offset,
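
Note on the hunk above: the static_casts keep the start/end clamping against param_row_count in signed arithmetic; the partitioning itself is unchanged. A small standalone sketch of how the rows are split across threads (the variable names mirror the kernel, but the chunk-size formula and the numbers are assumptions, since that line is outside this hunk):

    // thread_partition_sketch.cc -- illustrative only.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int inner_op_parallelism = 4;  // stand-in for FLAGS_inner_op_parallelism
      const int64_t param_row_count = 10;
      // Assumed chunk size; the real formula is outside the hunk above.
      const int64_t line_in_each_thread =
          param_row_count / inner_op_parallelism + 1;

      for (int i = 0; i < inner_op_parallelism; ++i) {
        int64_t start = i * line_in_each_thread;
        int64_t end = (i + 1) * line_in_each_thread;
        if (start >= param_row_count) break;  // this thread has no rows
        if (end > param_row_count) end = param_row_count;  // clamp last chunk
        std::printf("thread %d updates rows [%lld, %lld)\n", i,
                    static_cast<long long>(start), static_cast<long long>(end));
      }
      return 0;
    }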
@@ -545,8 +526,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
           for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
         }
-#endif  // !_WIN32
-        else {
+#endif  // !_WIN32
+        else {  // NOLINT
           functor(param.numel());
         }
       } else if (platform::is_gpu_place(ctx.GetPlace())) {