|
|
|
@ -81,7 +81,7 @@ ParameterOptimizer::TraverseCallback AverageOptimizer::needSpecialTraversal(
|
|
|
|
|
if (numUpdates_ % kMaxNumAccumulates == 0) {
|
|
|
|
|
// Move the sum to a different buffer to avoid loss of precision
|
|
|
|
|
// due to too many sums.
|
|
|
|
|
callbacks.emplace_back([this](const VectorPtr vecs[],
|
|
|
|
|
callbacks.emplace_back([](const VectorPtr vecs[],
|
|
|
|
|
const ParameterConfig& config,
|
|
|
|
|
size_t sparseId) {
|
|
|
|
|
vecs[PARAMETER_SUM2]->add(*vecs[PARAMETER_SUM1]);
|
|
|
|
@ -94,7 +94,7 @@ ParameterOptimizer::TraverseCallback AverageOptimizer::needSpecialTraversal(
|
|
|
|
|
if (auto callback = this->startCatchUpWith()) {
|
|
|
|
|
callbacks.emplace_back(callback);
|
|
|
|
|
}
|
|
|
|
|
callbacks.emplace_back([this](const VectorPtr vecs[],
|
|
|
|
|
callbacks.emplace_back([](const VectorPtr vecs[],
|
|
|
|
|
const ParameterConfig& config,
|
|
|
|
|
size_t sparseId) {
|
|
|
|
|
vecs[PARAMETER_SUM3]->add(*vecs[PARAMETER_SUM1], *vecs[PARAMETER_SUM2]);
|
|
|
|
|