|
|
|
@ -11,11 +11,15 @@
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <fstream>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/reduce_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/rpc_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
|
|
|
@ -26,9 +30,6 @@
|
|
|
|
|
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
// Command-line flag holding the path used for the SSA graph dump. The help
// text must stay in sync with the default value given as the second argument.
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
              "the ssa graph path only print with GLOG_v=10, "
              "default /tmp/ssa_graph.dot");
|
|
|
|
@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
const ProgramDesc &program) const {
|
|
|
|
|
std::unordered_map<std::string, proto::VarType::Type> var_types;
|
|
|
|
|
std::unordered_map<std::string, VarDesc *> all_vars;
|
|
|
|
|
for (auto *var : program.Block(0).AllVars()) {
|
|
|
|
|
var_types[var->Name()] = var->GetType();
|
|
|
|
|
all_vars[var->Name()] = var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto graph = new SSAGraph();
|
|
|
|
@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
auto send_vars = FindDistTrainSendVars(program);
|
|
|
|
|
auto recv_vars = FindDistTrainRecvVars(program);
|
|
|
|
|
|
|
|
|
|
size_t cur_device_id = 0;
|
|
|
|
|
std::vector<std::unordered_set<std::string>> var_name_on_devices;
|
|
|
|
|
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
|
|
|
|
|
var_name_on_devices.resize(places_.size());
|
|
|
|
|
bcast_var_name_set.resize(places_.size());
|
|
|
|
|
|
|
|
|
|
size_t cur_device_id = 0;
|
|
|
|
|
std::vector<int64_t> balance_grads(places_.size(), 0);
|
|
|
|
|
|
|
|
|
|
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
|
|
|
|
|
auto var_desc = all_vars.at(g_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var_desc);
|
|
|
|
|
auto dim = framework::make_ddim(var_desc->GetShape());
|
|
|
|
|
int64_t numel = framework::product(dim);
|
|
|
|
|
PADDLE_ENFORCE_GE(numel, 0);
|
|
|
|
|
auto smallest =
|
|
|
|
|
std::min_element(std::begin(balance_grads), std::end(balance_grads));
|
|
|
|
|
size_t dev_id =
|
|
|
|
|
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
|
|
|
|
|
balance_grads[dev_id] += numel;
|
|
|
|
|
return dev_id;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
bool is_forwarding = true;
|
|
|
|
|
for (auto *op : program.Block(0).AllOps()) {
|
|
|
|
|
if (boost::get<int>(
|
|
|
|
@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
|
|
|
|
|
switch (strategy_.reduce_) {
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|
cur_device_id = get_appropriate_dev(g_name);
|
|
|
|
|
CreateReduceOp(&result, g_name, cur_device_id);
|
|
|
|
|
var_name_on_devices[cur_device_id].emplace(g_name);
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(p_name);
|
|
|
|
|
cur_device_id = (cur_device_id + 1) % places_.size();
|
|
|
|
|
break;
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kAllReduce:
|
|
|
|
|
if (IsSparseGradient(var_types, g_name)) {
|
|
|
|
|
if (IsSparseGradient(all_vars, g_name)) {
|
|
|
|
|
CreateReduceOp(&result, g_name, 0);
|
|
|
|
|
CreateBroadcastOp(&result, g_name, 0);
|
|
|
|
|
} else {
|
|
|
|
@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsSparseGradient(
|
|
|
|
|
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
|
|
|
|
|
const std::unordered_map<std::string, VarDesc *> &all_vars,
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
PADDLE_ENFORCE(var_types.count(og) != 0);
|
|
|
|
|
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
PADDLE_ENFORCE(all_vars.count(og) != 0);
|
|
|
|
|
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|