|
|
|
@ -24,6 +24,7 @@ limitations under the License. */
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "gflags/gflags.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/scope.h"
|
|
|
|
|
#include "paddle/fluid/framework/variable.h"
|
|
|
|
@ -37,6 +38,8 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
#include "paddle/fluid/platform/place.h"
|
|
|
|
|
|
|
|
|
|
DECLARE_bool(communicator_is_sgd_optimizer);
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace distributed {
|
|
|
|
@ -138,8 +141,10 @@ inline void MergeVars(const std::string& var_name,
|
|
|
|
|
auto in = EigenVector<float>::Flatten(in_t);
|
|
|
|
|
result.device(*cpu_ctx.eigen_device()) = result + in;
|
|
|
|
|
}
|
|
|
|
|
result.device(*cpu_ctx.eigen_device()) =
|
|
|
|
|
result / static_cast<float>(vars.size());
|
|
|
|
|
if (!FLAGS_communicator_is_sgd_optimizer) {
|
|
|
|
|
result.device(*cpu_ctx.eigen_device()) =
|
|
|
|
|
result / static_cast<float>(vars.size());
|
|
|
|
|
}
|
|
|
|
|
} else if (var0->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto& slr0 = var0->Get<framework::SelectedRows>();
|
|
|
|
|
auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
|
|
|
|
@ -151,9 +156,16 @@ inline void MergeVars(const std::string& var_name,
|
|
|
|
|
inputs.push_back(&var->Get<framework::SelectedRows>());
|
|
|
|
|
}
|
|
|
|
|
auto dev_ctx = paddle::platform::CPUDeviceContext();
|
|
|
|
|
math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
|
|
|
|
|
merge_average;
|
|
|
|
|
merge_average(dev_ctx, inputs, out_slr);
|
|
|
|
|
if (FLAGS_communicator_is_sgd_optimizer) {
|
|
|
|
|
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
|
|
|
|
|
merge_add;
|
|
|
|
|
merge_add(dev_ctx, inputs, out_slr);
|
|
|
|
|
} else {
|
|
|
|
|
math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
|
|
|
|
|
merge_average;
|
|
|
|
|
merge_average(dev_ctx, inputs, out_slr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
|
|
|
|
|
<< " dims: " << slr0.value().dims();
|
|
|
|
|
} else {
|
|
|
|
|