|
|
|
|
@ -202,7 +202,7 @@ struct SparseAdamFunctor {
|
|
|
|
|
row_count_(row_count),
|
|
|
|
|
sparse_mode_(sparse_mode) {}
|
|
|
|
|
|
|
|
|
|
inline HOSTDEVICE void sparse_update(size_t i, T g) const {
|
|
|
|
|
inline HOSTDEVICE void adam_update(size_t i, T g) const {
|
|
|
|
|
// The following code is the same as dense
|
|
|
|
|
T mom1 = moment1_[i];
|
|
|
|
|
T mom2 = moment2_[i];
|
|
|
|
|
@ -228,7 +228,7 @@ struct SparseAdamFunctor {
|
|
|
|
|
auto row_idx =
|
|
|
|
|
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
|
|
|
|
|
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
|
|
|
|
|
sparse_update(i, g);
|
|
|
|
|
adam_update(i, g);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
@ -364,7 +364,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (size_t offset = 0; offset < row_numel; ++offset) {
|
|
|
|
|
size_t i = rows[row_index] * row_numel + offset;
|
|
|
|
|
T g = grad_data[row_index * row_numel + offset];
|
|
|
|
|
functor.sparse_update(i, g);
|
|
|
|
|
functor.adam_update(i, g);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
|