|
|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
|
|
|
|
#include <fstream>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
|
|
|
|
@ -181,8 +182,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
// always use the first device
|
|
|
|
|
CreateRPCOp(&result, *op);
|
|
|
|
|
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
|
|
|
|
|
// CreateComputationalOps(&result, *op, 1);
|
|
|
|
|
CreateComputationalOp(&result, *op, 0);
|
|
|
|
|
CreateDistTrainOp(&result, *op);
|
|
|
|
|
} else if (IsScaleLossOp(*op)) {
|
|
|
|
|
// user can customize loss@grad if not use_default_grad_scale_
|
|
|
|
|
if (strategy_.gradient_scale_ !=
|
|
|
|
|
@ -247,9 +247,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
AddOutputToLeafOps(&result);
|
|
|
|
|
|
|
|
|
|
if (VLOG_IS_ON(10)) {
|
|
|
|
|
std::ostringstream sout;
|
|
|
|
|
PrintGraphviz(*graph, sout);
|
|
|
|
|
VLOG(10) << sout.str();
|
|
|
|
|
std::ofstream fout("/tmp/graph.dot");
|
|
|
|
|
PrintGraphviz(*graph, fout);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::unique_ptr<SSAGraph>(graph);
|
|
|
|
|
@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Adds a distributed-training op to the graph under construction.
//
// The op is first handed to CreateComputationalOp with 0 as the third
// argument (presumably the device index, i.e. the first device -- confirm
// against CreateComputationalOp). When the op's type is "concat", the last
// entry of result->ops_ (assumed to be the handle that call just appended --
// TODO confirm) is additionally passed to ConnectOp together with the name
// "fetch_barrier".
//
// result: graph being built; it is forwarded to the helpers and its ops_
//         container is read here.
// op:     description of the op to add.
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
                                                const OpDesc &op) const {
  CreateComputationalOp(result, op, 0);

  // Only "concat" ops need the extra ConnectOp step.
  if (!(op.Type() == "concat")) {
    return;
  }

  auto *latest_op_handle = result->ops_.back().get();
  ConnectOp(result, latest_op_handle, "fetch_barrier");
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
|
|
|
|
|
const OpDesc &op) const {
|
|
|
|
|
auto &p = places_[0];
|
|
|
|
|
|