@@ -157,8 +157,11 @@ struct AdamFunctor<T, CPUAdam> {
   }
 };

+template <typename T, typename Flavour>
+struct SparseAdamFunctor;
+
 template <typename T>
-struct SparseAdamFunctor {
+struct SparseAdamFunctor<T, GPUAdam> {
   T beta1_;
   T beta2_;
   T epsilon_;
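A minimal sketch, not part of the diff, of the tag-dispatch pattern introduced above: the primary template takes a Flavour tag so the GPU and CPU sparse Adam paths can be specialized separately and selected at compile time. The names below are hypothetical illustrations, not the Paddle identifiers (apart from the flavour-tag idea itself).

#include <iostream>

struct GPUAdam;  // flavour tags; forward declarations are enough
struct CPUAdam;

// Primary template is only declared, so each flavour must supply a specialization.
template <typename T, typename Flavour>
struct UpdateFunctor;

template <typename T>
struct UpdateFunctor<T, GPUAdam> {  // GPU-flavoured specialization
  void operator()() const { std::cout << "GPU path\n"; }
};

template <typename T>
struct UpdateFunctor<T, CPUAdam> {  // CPU-flavoured specialization
  void operator()() const { std::cout << "CPU path\n"; }
};

int main() {
  UpdateFunctor<float, CPUAdam>()();  // flavour chosen at compile time
  return 0;
}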
@@ -236,6 +239,106 @@ struct SparseAdamFunctor {
   }
 };

+template <typename T>
+struct SparseAdamFunctor<T, CPUAdam> {
+  T beta1_;
+  T beta2_;
+  T epsilon_;
+
+  const T* beta1_pow_;
+  const T* beta2_pow_;
+  const T* moment1_;
+  T* moment1_out_;
+  const T* moment2_;
+  T* moment2_out_;
+  const T* lr_;
+  const T* grad_;
+  const T* param_;
+  T* param_out_;
+
+  const int64_t* rows_;
+  int64_t row_numel_;
+  int64_t row_count_;
+
+  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
+                    const T* beta2_pow, const T* mom1, T* mom1_out,
+                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
+                    const T* param, T* param_out, const int64_t* rows,
+                    int64_t row_numel, int64_t row_count, bool lazy_mode)
+      : beta1_(beta1),
+        beta2_(beta2),
+        epsilon_(epsilon),
+        beta1_pow_(beta1_pow),
+        beta2_pow_(beta2_pow),
+        moment1_(mom1),
+        moment1_out_(mom1_out),
+        moment2_(mom2),
+        moment2_out_(mom2_out),
+        lr_(lr),
+        grad_(grad),
+        param_(param),
+        param_out_(param_out),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE void adam_update(size_t i, T g) const {
+    // The following code is the same as the dense case
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    T p = param_[i];
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
+  }
+
+  inline void operator()(size_t numel) const {
+    // lr could be reused
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+    size_t row_count = numel / row_numel_;
+
+    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
+      if (i == *(rows_ + j)) {
+        for (size_t k = 0U; k != row_numel_; ++k) {
+          T g = grad_[j * row_numel_ + k];
+          adam_update(i * row_numel_ + k, g);
+        }
+        ++j;
+      } else {
+        for (size_t k = 0U; k != row_numel_; ++k) {
+          T mom1 = moment1_[i * row_numel_ + k];
+          T mom2 = moment2_[i * row_numel_ + k];
+          T p = param_[i * row_numel_ + k];
+
+          mom1 = beta1_ * mom1;
+          mom2 = beta2_ * mom2;
+
+          p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+          // Write back to global memory
+          moment1_out_[i * row_numel_ + k] = mom1;
+          moment2_out_[i * row_numel_ + k] = mom2;
+          param_out_[i * row_numel_ + k] = p;
+        }
+      }
+    }
+  }
+};
+
 template <typename DeviceContext, typename T>
 class AdamOpKernel : public framework::OpKernel<T> {
  public:
@@ -331,7 +434,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
                                .Var()
                                ->GetMutable<framework::SelectedRows>();
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var);
+                   grad_merge_var, true);
         grad_merge_ptr = grad_merge_var;
       }
@@ -347,13 +450,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
       } else {
 #endif
         rows = grad_merge.rows().data();

 #if defined(PADDLE_WITH_CUDA)
       }
 #endif
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

-      SparseAdamFunctor<T> functor(
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        SparseAdamFunctor<T, CPUAdam> functor(
           beta1, beta2, epsilon, beta1_pow.template data<T>(),
           beta2_pow.template data<T>(), mom1.template data<T>(),
           mom1_out.template mutable_data<T>(ctx.GetPlace()),
@@ -362,8 +465,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
           lr.template data<T>(), grad_data, param.template data<T>(),
           param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
           grad_merge.rows().size(), lazy_mode);
       VLOG(3) << "lazy_mode :" << lazy_mode;
-      if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) {
+      if (lazy_mode) {
         size_t row_count = grad_merge.rows().size();
         std::vector<int64_t> cpu_rows(grad_merge.rows());
         for (size_t row_index = 0; row_index < row_count; ++row_index) {
@@ -373,6 +476,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
         }
       } else {
+        functor(param.numel());
       }
+    } else if (platform::is_gpu_place(ctx.GetPlace())) {
+      SparseAdamFunctor<T, GPUAdam> functor(
+          beta1, beta2, epsilon, beta1_pow.template data<T>(),
+          beta2_pow.template data<T>(), mom1.template data<T>(),
+          mom1_out.template mutable_data<T>(ctx.GetPlace()),
+          mom2.template data<T>(),
+          mom2_out.template mutable_data<T>(ctx.GetPlace()),
+          lr.template data<T>(), grad_data, param.template data<T>(),
+          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+          grad_merge.rows().size(), lazy_mode);
+
+      // FIXME(minqiyang): remove BinarySearch in GPU later
+      platform::ForRange<DeviceContext> for_range(
+          static_cast<const DeviceContext&>(ctx.device_context()),
+          param.numel());
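A minimal sketch, not part of the diff, of what SparseAdamFunctor<T, CPUAdam>::operator() does when lazy_mode is false: every row is visited, rows present in the sparse gradient get the standard Adam step, and absent rows are updated as if their gradient were zero (the moments decay and the parameter still moves); with lazy_mode the kernel instead calls adam_update only on the rows listed in the sparse gradient. The function name and container choices below are illustrative assumptions, not the Paddle interfaces.

#include <cmath>
#include <cstdint>
#include <vector>

// Full-sweep CPU sparse Adam step; `rows` is the sorted list of row indices
// that actually appear in the sparse gradient.
template <typename T>
void SparseAdamFullSweep(T beta1, T beta2, T epsilon, T beta1_pow, T beta2_pow,
                         T lr, const std::vector<int64_t>& rows,
                         const std::vector<T>& grad, int64_t row_numel,
                         std::vector<T>* param, std::vector<T>* mom1,
                         std::vector<T>* mom2) {
  lr *= std::sqrt(1 - beta2_pow) / (1 - beta1_pow);
  const int64_t row_count = static_cast<int64_t>(param->size()) / row_numel;
  size_t j = 0;  // cursor over the rows that carry a gradient
  for (int64_t i = 0; i < row_count; ++i) {
    const bool has_grad = j < rows.size() && rows[j] == i;
    for (int64_t k = 0; k < row_numel; ++k) {
      const size_t idx = static_cast<size_t>(i * row_numel + k);
      const T g = has_grad ? grad[j * row_numel + k] : static_cast<T>(0);
      (*mom1)[idx] = beta1 * (*mom1)[idx] + (1 - beta1) * g;
      (*mom2)[idx] = beta2 * (*mom2)[idx] + (1 - beta2) * g * g;
      (*param)[idx] -=
          lr * ((*mom1)[idx] / (std::sqrt((*mom2)[idx]) + epsilon));
    }
    if (has_grad) ++j;
  }
}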