|
|
|
|
@ -177,12 +177,13 @@ struct SparseAdamFunctor {
|
|
|
|
|
const int64_t* rows_;
|
|
|
|
|
int64_t row_numel_;
|
|
|
|
|
int64_t row_count_;
|
|
|
|
|
bool sparse_mode_;
|
|
|
|
|
|
|
|
|
|
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
|
|
|
|
|
const T* beta2_pow, const T* mom1, T* mom1_out,
|
|
|
|
|
const T* mom2, T* mom2_out, const T* lr, const T* grad,
|
|
|
|
|
const T* param, T* param_out, const int64_t* rows,
|
|
|
|
|
int64_t row_numel, int64_t row_count)
|
|
|
|
|
int64_t row_numel, int64_t row_count, bool sparse_mode)
|
|
|
|
|
: beta1_(beta1),
|
|
|
|
|
beta2_(beta2),
|
|
|
|
|
epsilon_(epsilon),
|
|
|
|
|
@ -198,13 +199,10 @@ struct SparseAdamFunctor {
|
|
|
|
|
param_out_(param_out),
|
|
|
|
|
rows_(rows),
|
|
|
|
|
row_numel_(row_numel),
|
|
|
|
|
row_count_(row_count) {}
|
|
|
|
|
|
|
|
|
|
inline HOSTDEVICE void operator()(size_t i) const {
|
|
|
|
|
auto row_idx =
|
|
|
|
|
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
|
|
|
|
|
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
|
|
|
|
|
row_count_(row_count),
|
|
|
|
|
sparse_mode_(sparse_mode) {}
|
|
|
|
|
|
|
|
|
|
inline HOSTDEVICE void sparse_update(size_t i, T g) const {
|
|
|
|
|
// The following code is the same as dense
|
|
|
|
|
T mom1 = moment1_[i];
|
|
|
|
|
T mom2 = moment2_[i];
|
|
|
|
|
@ -225,6 +223,13 @@ struct SparseAdamFunctor {
|
|
|
|
|
moment2_out_[i] = mom2;
|
|
|
|
|
param_out_[i] = p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline HOSTDEVICE void operator()(size_t i) const {
|
|
|
|
|
auto row_idx =
|
|
|
|
|
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
|
|
|
|
|
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
|
|
|
|
|
sparse_update(i, g);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
@ -240,6 +245,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
using paddle::framework::LoDTensor;
|
|
|
|
|
using paddle::operators::detail::Ref;
|
|
|
|
|
|
|
|
|
|
bool sparse_mode = ctx.Attr<bool>("sparse_mode");
|
|
|
|
|
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
|
|
|
|
|
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
|
|
|
|
|
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
|
|
|
|
|
@ -351,11 +357,22 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
mom2_out.template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
lr.template data<T>(), grad_data, param.template data<T>(),
|
|
|
|
|
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
|
|
|
|
|
grad_merge.rows().size());
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(
|
|
|
|
|
static_cast<const DeviceContext&>(ctx.device_context()),
|
|
|
|
|
param.numel());
|
|
|
|
|
for_range(functor);
|
|
|
|
|
grad_merge.rows().size(), sparse_mode);
|
|
|
|
|
if (sparse_mode) {
|
|
|
|
|
size_t row_count = grad_merge.rows().size();
|
|
|
|
|
for (size_t row_index = 0; row_index < row_count; ++row_index) {
|
|
|
|
|
for (size_t offset = 0; offset < row_numel; ++offset) {
|
|
|
|
|
size_t i = rows[row_index] * row_numel + offset;
|
|
|
|
|
T g = grad_data[row_index * row_numel + offset];
|
|
|
|
|
functor.sparse_update(i, g);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(
|
|
|
|
|
static_cast<const DeviceContext&>(ctx.device_context()),
|
|
|
|
|
param.numel());
|
|
|
|
|
for_range(functor);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Variable type not supported by adam_op");
|
|
|
|
|
}
|
|
|
|
|
|