|
|
|
@ -17,6 +17,7 @@ limitations under the License. */
|
|
|
|
|
#include <Eigen/Dense>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/threadpool.h"
|
|
|
|
|
#include "paddle/fluid/operators/detail/safe_ref.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/algorithm.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
@ -352,10 +353,31 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
lr.template data<T>(), grad_data, param.template data<T>(),
|
|
|
|
|
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
|
|
|
|
|
grad_merge.rows().size());
|
|
|
|
|
int inner_op_parallelism = FLAGS_inner_op_parallelism;
|
|
|
|
|
if (inner_op_parallelism > 1 &&
|
|
|
|
|
FLAGS_min_param_size_to_use_multithread > 0 &&
|
|
|
|
|
param.numel() > FLAGS_min_param_size_to_use_multithread) {
|
|
|
|
|
std::vector<std::future<void>> fs;
|
|
|
|
|
int64_t block_size = param.numel() / inner_op_parallelism;
|
|
|
|
|
for (int i = 0; i < inner_op_parallelism; ++i) {
|
|
|
|
|
int64_t start = i * block_size;
|
|
|
|
|
int64_t end = (i + 1) * block_size;
|
|
|
|
|
if (end > param.numel()) {
|
|
|
|
|
end = param.numel();
|
|
|
|
|
}
|
|
|
|
|
fs.push_back(framework::Async([&functor, start, end]() {
|
|
|
|
|
for (int64_t i = start; i < end; ++i) {
|
|
|
|
|
functor(i);
|
|
|
|
|
}
|
|
|
|
|
}));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
|
|
|
|
|
} else {
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(
|
|
|
|
|
static_cast<const DeviceContext&>(ctx.device_context()),
|
|
|
|
|
param.numel());
|
|
|
|
|
for_range(functor);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Variable type not supported by adam_op");
|
|
|
|
|
}
|
|
|
|
|