@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
   const int64_t* rows_;
   int64_t row_numel_;
+  int64_t row_count_;

   SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                     const T* beta2_pow, const T* mom1, T* mom1_out,
                     const T* mom2, T* mom2_out, const T* lr, const T* grad,
                     const T* param, T* param_out, const int64_t* rows,
-                    int64_t row_numel)
+                    int64_t row_numel, int64_t row_count)
       : beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
         param_(param),
         param_out_(param_out),
         rows_(rows),
-        row_numel_(row_numel) {}
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
+    int64_t beg = 0, end = row_count_ - 1;
+    while (beg <= end) {
+      auto mid = ((beg + end) >> 1);
+      if (rows_[mid] == row)
+        return mid;
+      else if (rows_[mid] < row)
+        beg = mid + 1;
+      else
+        end = mid - 1;
+    }
+    return -1;
+  }

   inline HOSTDEVICE void operator()(size_t i) const {
-    for (int64_t j = 0; j < row_numel_; ++j) {
-      T g = grad_[i * row_numel_ + j];
-      T mom1 = moment1_[rows_[i] * row_numel_ + j];
-      T mom2 = moment2_[rows_[i] * row_numel_ + j];
-      T lr = *lr_;
-      T beta1_pow = *beta1_pow_;
-      T beta2_pow = *beta2_pow_;
-      T p = param_[rows_[i] * row_numel_ + j];
-
-      lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-
-      mom1 = beta1_ * mom1 + (1 - beta1_) * g;
-      mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-      p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-
-      moment1_out_[rows_[i] * row_numel_ + j] = mom1;
-      moment2_out_[rows_[i] * row_numel_ + j] = mom2;
-      param_out_[rows_[i] * row_numel_ + j] = p;
-    }  // for col id
+    int64_t row = i / row_numel_;
+    auto row_idx = BinarySearchInRows(row);
+    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
+
+    // The following code is the same as dense
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    T p = param_[i];
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
   }
 };
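Note on the functor rework above: `operator()` is now invoked once per parameter element (the launch size changes to `param.numel()` in the last hunk) instead of once per gradient row. Each element derives its row and column from the flat index, binary-searches the merged, sorted row list of the SelectedRows gradient, and treats a missing row as a zero gradient, so every row still receives the moment decay and bias-corrected step. Below is a minimal standalone sketch of that index math outside Paddle; the sizes and gradient values are made up for illustration and plain arrays stand in for the functor's raw pointers.

```cpp
// Standalone sketch (not Paddle code): how the reworked functor maps a flat
// parameter-element index onto the sorted rows of a row-sparse gradient.
#include <cstdint>
#include <cstdio>
#include <vector>

// Same logic as SparseAdamFunctor::BinarySearchInRows; rows must be sorted.
int64_t BinarySearchInRows(const std::vector<int64_t>& rows, int64_t row) {
  int64_t beg = 0, end = static_cast<int64_t>(rows.size()) - 1;
  while (beg <= end) {
    int64_t mid = (beg + end) >> 1;
    if (rows[mid] == row) return mid;
    if (rows[mid] < row)
      beg = mid + 1;
    else
      end = mid - 1;
  }
  return -1;  // this parameter row received no gradient
}

int main() {
  const int64_t height = 9, row_numel = 3;  // a 9x3 parameter
  std::vector<int64_t> rows = {1, 4, 7};    // sorted rows present in the grad
  std::vector<float> grad(rows.size() * row_numel, 0.5f);  // made-up values

  // Visit every parameter element, as ForRange(param.numel()) now does.
  for (int64_t i = 0; i < height * row_numel; ++i) {
    int64_t row = i / row_numel;
    int64_t col = i % row_numel;
    int64_t row_idx = BinarySearchInRows(rows, row);
    float g = row_idx >= 0 ? grad[row_idx * row_numel + col] : 0.0f;
    std::printf("element %2lld -> row %lld col %lld g=%g\n",
                static_cast<long long>(i), static_cast<long long>(row),
                static_cast<long long>(col), g);
  }
  return 0;
}
```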
@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
         return;
       }
       // merge duplicated rows if any.
+      // The rows of grad_merge have been sorted inside MergeAdd functor
       scatter::MergeAdd<DeviceContext, T> merge_func;
-      auto grad_merge =
-          merge_func(ctx.template device_context<DeviceContext>(), grad);
+      auto& grad_merge = *(ctx.scope()
+                               .NewScope()
+                               .Var("sparse_adam_grad_merge")
+                               ->GetMutable<framework::SelectedRows>());
+      merge_func(ctx.template device_context<DeviceContext>(), grad,
+                 &grad_merge);
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
       int64_t* rows = nullptr;
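The new comment in this hunk states the precondition that makes `BinarySearchInRows` valid: `MergeAdd` hands back `grad_merge` with duplicate rows summed and the row list sorted. The hunk also keeps `grad_merge` alive as a variable in a fresh child scope and switches `merge_func` to an output-parameter form. As a rough illustration of the merge contract only (not Paddle's `scatter::MergeAdd` implementation, which operates on `framework::SelectedRows`), the behavior amounts to:

```cpp
// Illustration only: merge duplicate rows of a row-sparse gradient and emit
// them in ascending row order -- the precondition for the binary search.
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

struct RowSparse {
  std::vector<int64_t> rows;               // one entry per stored row
  std::vector<std::vector<float>> values;  // values[i] belongs to rows[i]
};

RowSparse MergeAdd(const RowSparse& in) {
  std::map<int64_t, std::vector<float>> acc;  // std::map keeps keys sorted
  for (size_t i = 0; i < in.rows.size(); ++i) {
    auto& dst = acc[in.rows[i]];
    if (dst.empty()) dst.assign(in.values[i].size(), 0.0f);
    for (size_t j = 0; j < in.values[i].size(); ++j) dst[j] += in.values[i][j];
  }
  RowSparse out;
  for (const auto& kv : acc) {
    out.rows.push_back(kv.first);
    out.values.push_back(kv.second);
  }
  return out;  // rows are unique and ascending
}

int main() {
  RowSparse grad{{7, 1, 7}, {{1.0f, 1.0f}, {2.0f, 2.0f}, {3.0f, 3.0f}}};
  RowSparse merged = MergeAdd(grad);
  for (size_t i = 0; i < merged.rows.size(); ++i)
    std::printf("row %lld: %g %g\n", static_cast<long long>(merged.rows[i]),
                merged.values[i][0], merged.values[i][1]);
  return 0;
}
```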
@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           mom2.template data<T>(),
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
+          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+          grad_merge.rows().size());
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
-          grad_merge.rows().size());
+          param.numel());
       for_range(functor);
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
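With the extra `row_count` argument wired through (`grad_merge.rows().size()`), the functor is now driven over `param.numel()` elements rather than over the gradient rows, so rows that received no gradient are updated as well. For those elements `g` is zero and the update reduces to moment decay plus a step from the decayed moments, matching the dense Adam path. A tiny sketch of that degenerate case, with arbitrary example constants:

```cpp
// Illustration: for a row with no gradient (g == 0) the element update
// degenerates to moment decay plus a step from the decayed moments.
#include <cmath>
#include <cstdio>

int main() {
  float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f;
  float beta1_pow = 0.9f, beta2_pow = 0.999f;  // beta^t after one step
  float lr = 0.001f, mom1 = 0.5f, mom2 = 0.25f, p = 1.0f, g = 0.0f;

  lr *= std::sqrt(1 - beta2_pow) / (1 - beta1_pow);
  mom1 = beta1 * mom1 + (1 - beta1) * g;      // == beta1 * mom1
  mom2 = beta2 * mom2 + (1 - beta2) * g * g;  // == beta2 * mom2
  p -= lr * (mom1 / (std::sqrt(mom2) + epsilon));
  std::printf("p after a no-gradient update: %f\n", p);
  return 0;
}
```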