|
|
|
@ -305,13 +305,6 @@ struct SparseAdamFunctor<T, CPUAdam> {
|
|
|
|
|
param_out_[i] = p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void update_row(size_t row_id, int grad_row_offset) const {
|
|
|
|
|
for (size_t i = 0U; i < row_numel_; ++i) {
|
|
|
|
|
T g = grad_row_offset >= 0 ? grad_[grad_row_offset * row_numel_ + i] : 0;
|
|
|
|
|
adam_update(row_id * row_numel_ + i, g);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void operator()(size_t numel) const {
|
|
|
|
|
// lr could be reused
|
|
|
|
|
T lr = *lr_;
|
|
|
|
@ -502,9 +495,6 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
"multi thread, currently "
|
|
|
|
|
<< param_row_count;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < param_row_count; ++i) {
|
|
|
|
|
row_id_to_grad_row_offset[i] = -1;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < grad_rows.size(); ++i) {
|
|
|
|
|
row_id_to_grad_row_offset[grad_rows[i]] = i;
|
|
|
|
|
}
|
|
|
|
@ -520,10 +510,24 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (end > param_row_count) {
|
|
|
|
|
end = param_row_count;
|
|
|
|
|
}
|
|
|
|
|
fs.push_back(framework::Async(
|
|
|
|
|
[&functor, &row_id_to_grad_row_offset, start, end]() {
|
|
|
|
|
for (int64_t i = start; i < end; ++i) {
|
|
|
|
|
functor.update_row(i, row_id_to_grad_row_offset[i]);
|
|
|
|
|
// Queue one task that applies the Adam update to every parameter row in
// [start, end). A row that has an entry in row_id_to_grad_row_offset is fed
// the values of the mapped row of grad_data; any other row is fed a gradient
// of 0 for each of its row_numel elements.
// NOTE(review): functor, row_id_to_grad_row_offset and grad_data are captured
// by reference; presumably the results collected in fs are waited on before
// those objects go out of scope - confirm against the surrounding code.
fs.push_back(framework::Async(
    [&functor, &row_id_to_grad_row_offset, &grad_data, row_numel, start,
     end]() {
      for (int64_t row = start; row < end; ++row) {
        auto hit = row_id_to_grad_row_offset.find(row);
        const bool has_grad = hit != row_id_to_grad_row_offset.end();
        for (size_t col = 0U; col < row_numel; ++col) {
          if (has_grad) {
            // hit->second is the row index inside grad_data.
            functor.adam_update(row * row_numel + col,
                                grad_data[hit->second * row_numel + col]);
          } else {
            functor.adam_update(row * row_numel + col, 0);
          }
        }
      }
    }));
|
|
|
|
|
}
|
|
|
|
|