@@ -329,7 +329,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   std::unordered_map<std::string, int> sharded_var_device;
 
   for (ir::Node *node : sorted_ops) {
-    VLOG(5) << "op name: " << node->Op()->Type();
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
@@ -366,11 +365,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       // is true only for the op that scale the final scalar loss.
       // It also assumes backward op will always follow the forward op in
       // the block.
-      VLOG(5) << "this is loss scale op!";
       is_forwarding = false;
     } else {
       int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
-      VLOG(5) << "on device id: " << op_dev_id;
       if (op_dev_id != -1) {  // This op only runs on one specific device.
         CreateComputationalOp(&result, node, op_dev_id);
         for (ir::Node *n : node->outputs) {