@@ -177,13 +177,13 @@ struct SparseAdamFunctor {
   const int64_t* rows_;
   int64_t row_numel_;
   int64_t row_count_;
-  bool sparse_mode_;
+  bool lazy_mode_;
 
   SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                     const T* beta2_pow, const T* mom1, T* mom1_out,
                     const T* mom2, T* mom2_out, const T* lr, const T* grad,
                     const T* param, T* param_out, const int64_t* rows,
-                    int64_t row_numel, int64_t row_count, bool sparse_mode)
+                    int64_t row_numel, int64_t row_count, bool lazy_mode)
       : beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
@@ -200,7 +200,7 @@ struct SparseAdamFunctor {
         rows_(rows),
         row_numel_(row_numel),
         row_count_(row_count),
-        sparse_mode_(sparse_mode) {}
+        lazy_mode_(lazy_mode) {}
 
   inline HOSTDEVICE void adam_update(size_t i, T g) const {
     // The following code is the same as dense
@@ -245,7 +245,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
     using paddle::framework::LoDTensor;
     using paddle::operators::detail::Ref;
 
-    bool sparse_mode = ctx.Attr<bool>("sparse_mode");
+    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
     T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
     T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
@@ -357,8 +357,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           lr.template data<T>(), grad_data, param.template data<T>(),
           param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
-          grad_merge.rows().size(), sparse_mode);
-      if (sparse_mode) {
+          grad_merge.rows().size(), lazy_mode);
+      if (lazy_mode) {
         size_t row_count = grad_merge.rows().size();
         for (size_t row_index = 0; row_index < row_count; ++row_index) {
           for (size_t offset = 0; offset < row_numel; ++offset) {
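The rename does not change behavior: with lazy_mode enabled, the kernel walks only the rows present in the merged sparse gradient rather than every parameter element. The standalone sketch below illustrates that access pattern; it is not Paddle code, and the plain SGD update, names, and values are illustrative stand-ins for the real moment math in SparseAdamFunctor::adam_update.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t row_numel = 3;                    // elements per row
  std::vector<float> param(4 * row_numel, 1.0f);  // 4 parameter rows

  // Merged sparse gradient: only rows 1 and 3 carry values.
  std::vector<int64_t> rows = {1, 3};
  std::vector<float> grad = {0.1f, 0.2f, 0.3f,   // values for row 1
                             0.4f, 0.5f, 0.6f};  // values for row 3
  const float lr = 0.01f;

  // Lazy mode: the nested loops mirror the diff above, touching only the
  // rows listed in `rows`; all other rows keep their previous values.
  for (size_t row_index = 0; row_index < rows.size(); ++row_index) {
    for (int64_t offset = 0; offset < row_numel; ++offset) {
      const size_t i = rows[row_index] * row_numel + offset;
      param[i] -= lr * grad[row_index * row_numel + offset];
    }
  }

  for (float v : param) printf("%g ", v);
  printf("\n");
  return 0;
}

Skipping the untouched rows is what makes lazy mode attractive for large, sparsely updated parameters such as embedding tables, where a dense pass over the whole parameter would mostly rewrite unchanged values.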