Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into port_python3_syntax

revert-12646-feature/jit/xbyak
minqiyang 7 years ago
commit 000ba1ac5f

@ -46,6 +46,7 @@
| tianbingsz | Tian-Bing Xu |
| tpatejko | Tomasz Patejko |
| typhoonzero | Yi Wu |
| velconia | Qi-Yang Min |
| wanghaoshuang | Hao-Shuang Wang |
| wangyang59 | Yang Wang |
| wangzhen-nlp | Zhen Wang |

@ -85,8 +85,7 @@ def dist_transpile(trainer_id, args):
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=not args.async_mode,
slice_var_up=not args.no_split_var)
sync_mode=not args.async_mode)
if training_role == "PSERVER":
pserver_program = t.get_pserver_program(current_endpoint)
pserver_startup_program = t.get_startup_program(current_endpoint,

@ -50,7 +50,7 @@ ExternalProject_Add(
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
PATCH_COMMAND git apply ${PADDLE_SOURCE_DIR}/patches/grpc/fix_too_early_destory.patch
PATCH_COMMAND cp ${PADDLE_SOURCE_DIR}/patches/grpc/grpc_library.h ${GRPC_SOURCES_DIR}/src/extern_grpc/include/grpcpp/impl/codegen/grpc_library.h && cp ${PADDLE_SOURCE_DIR}/patches/grpc/completion_queue.h ${GRPC_SOURCES_DIR}/src/extern_grpc/include/grpcpp/impl/codegen/completion_queue.h
# NOTE(yuyang18):
# Disable -Werror, otherwise the compile will fail in MacOS.
# It seems that we cannot configure that by make command.

@ -263,7 +263,7 @@ function(cc_test TARGET_NAME)
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY SERIAL 1)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
endif()
endif()
@ -328,7 +328,7 @@ function(nv_test TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY SERIAL 1)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
endif()
endif()

@ -148,18 +148,11 @@ if (WITH_ANAKIN AND WITH_GPU)
list(APPEND inference_deps anakin_inference_lib)
endif()
copy(inference_api_lib DEPS paddle_inference_api paddle_inference_api_shared
SRCS ${src_dir}/${module}/paddle_inference_api.h
${src_dir}/${module}/demo_ci
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/libpaddle_inference_api*
DSTS ${dst_dir}/inference ${dst_dir}/inference ${dst_dir}/inference
)
list(APPEND inference_deps inference_api_lib)
set(module "inference")
copy(inference_lib DEPS ${inference_deps}
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
DSTS ${dst_dir}/${module} ${dst_dir}/${module}
${src_dir}/${module}/api/paddle_inference_api.h ${src_dir}/${module}/api/demo_ci
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
)
set(module "platform")

@ -64,6 +64,41 @@ can also contain other things that describe some properties of
the `Graph` or `Graph` nodes. `Attribute` can be passed
across `Pass`. However, it should be used with care.
```cpp
class Graph {
public:
explicit Graph(const ProgramDesc &program);
bool Has(const std::string &attr_name) const;
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const;
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr);
const std::unordered_set<ir::Node *> &Nodes() const;
// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc);
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc);
// Create a control dependency var that connects 2 operations. The
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir::Node *CreateControlDepVar();
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type);
// Clear all node information of the graph and return the ownership of the
// nodes.
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes();
};
```
#### Pass
`Pass` represents a transformation of `Graph`. Its input
@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example,
a `Pass` can simply print out the `Graph`. A `Pass`
can also fuse some `Graph`'s `Node`s.
```cpp
class Pass {
public:
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
// Some correctness check.
auto new_graph = ApplyImpl(std::move(graph));
// Some correctness check.
return new_graph;
}
// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const;
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) ;
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr);
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
};
// In my_pass.cc
class MyPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
// do something.
return graph;
}
}
REGISTER_PASS(my_pass, MyPass)
.RequirePassAttr("places")
.RequireGraphAttr("dep_vars");
// To use the pass.
auto my_pass = ir::PassRegistry::Instance().Get("my_pass");
graph = my_pass->Apply(std::move(graph));
// Note: to force link my_pass.cc, in the code:
USE_PASS(my_pass);
```
#### Optimize
`Optimize` contains a series of `Pass` with defined order.
@ -86,4 +169,17 @@ maintaining the original modeling logic.
* Graph is transformed from raw model logic to a
form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
```
// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
auto graph = Graph(program);
graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
// For more complex Pass, Optimize Process can provide Pass attributes.
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
mem_opt_pass->Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
Executor exe;
exe.Run(graph);
```

@ -8,9 +8,9 @@ cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
if(WITH_GPU)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context)
else()
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type)
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type device_context)
endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
@ -99,7 +99,7 @@ else()
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)

@ -31,39 +31,27 @@ class Scope;
namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs,
const BuildStrategy &strategy);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy);
#endif
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
private:
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
void Init() const;
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> &local_scopes_;
std::unordered_set<std::string> grad_names_;
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable std::unordered_set<std::string> grad_names_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
mutable platform::NCCLContextMap *nccl_ctxs_;
#endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(ir::Node *node) const;
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::string> &var_names) const;
private:
BuildStrategy strategy_;
mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle,

@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
const ir::Graph& Graph() const { return underlying_executor_->Graph(); }
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:

@ -18,7 +18,7 @@ namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
}
}
@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr;
if (var_holder.empty()) {
@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
auto &vars =
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
size_t version = vars.size();
auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
}
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) {
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}

@ -39,21 +39,25 @@ namespace details {
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
class SSAGraphBuilder : public ir::Pass {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected:

@ -1,50 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res(
#ifdef PADDLE_WITH_CUDA
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, nccl_ctxs_, strategy_)
#else
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, strategy_)
#endif
); // NOLINT
if (!strategy_.debug_graphviz_path_.empty()) {
std::unique_ptr<std::ostream> fout(
new std::ofstream(strategy_.debug_graphviz_path_));
PADDLE_ENFORCE(fout->good());
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
new GraphvizSSAGraphPrinter());
res.reset(new SSAGraghBuilderWithPrinter(
std::move(fout), std::move(graphviz_printer), std::move(res)));
}
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle

@ -1,71 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
class Scope;
namespace details {
class SSAGraphBuilderFactory {
public:
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
const std::string& loss_var_name,
const std::unordered_set<std::string>& param_names,
const std::vector<Scope*>& local_scopes,
const BuildStrategy& strategy)
: places_(places),
loss_var_name_(loss_var_name),
param_names_(param_names),
local_scopes_(local_scopes),
strategy_(strategy) {
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_ = nullptr;
#endif
}
#ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
nccl_ctxs_ = nccl_ctxs;
}
#endif
std::unique_ptr<SSAGraphBuilder> Create();
private:
std::vector<platform::Place> places_;
std::string loss_var_name_;
std::unordered_set<std::string> param_names_;
std::vector<Scope*> local_scopes_;
BuildStrategy strategy_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap* nccl_ctxs_;
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
};
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get());
}
for (auto &op : graph->Get<GraphOps>("ops")) {
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_device_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);

@ -23,26 +23,14 @@ namespace framework {
namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
explicit SSAGraghBuilderWithChecker(
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const ir::Graph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
};
} // namespace details

@ -32,7 +32,9 @@ class SSAGraphExecutor {
virtual ~SSAGraphExecutor();
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
virtual const ir::Graph& Graph() const = 0;
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
};
} // namespace details
} // namespace framework

@ -22,7 +22,7 @@ namespace details {
template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
callback(*var);
}
}
@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>("ops")) {
for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_device_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter);

@ -14,7 +14,9 @@
#pragma once
#include <fstream>
#include <iosfwd>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
SSAGraghBuilderWithPrinter(std::ostream& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ref_(sout) {}
SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_);
return new_graph;
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<const std::string>("debug_graphviz_path")));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
return builder_->GetVarDeviceID(var_name);
}
private:
std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_;
std::unique_ptr<std::ostream> stream_ptr_;
std::ostream& stream_ref_;
};
} // namespace details

@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
}
}
}
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get());
}
for (auto &op : graph_->Get<details::GraphOps>("ops")) {
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());

@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph);
const ir::Graph &Graph() const { return *graph_; }
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

@ -1,6 +1,9 @@
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)

@ -24,6 +24,68 @@ namespace paddle {
namespace framework {
namespace ir {
std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) {
std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto &node : nodes) {
auto op_vars = node->Op()->InputArgumentNames();
send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
}
return send_vars;
}
std::vector<std::string> FindDistTrainRecvVars(
const std::vector<ir::Node *> &nodes) {
std::vector<std::string> recv_vars;
for (auto &node : nodes) {
auto op_vars = node->Op()->OutputArgumentNames();
recv_vars.reserve(recv_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
}
return recv_vars;
}
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false;
}
/**
* Check any of opvars contains `.block` and in sendvars
*/
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
}
return false;
};
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
return checker(output_var_names, send_vars) ||
checker(input_var_names, recv_vars);
}
Graph::Graph(const ProgramDesc &program) : program_(program) {
VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars;
@ -61,6 +123,64 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
var->inputs.push_back(node);
}
}
std::vector<ir::Node *> send_ops;
ir::Node *send_bar = nullptr;
std::vector<ir::Node *> recv_ops;
ir::Node *fetch_bar = nullptr;
for (ir::Node *node : Nodes()) {
if (node->Name() == "send") {
send_ops.push_back(node);
} else if (node->Name() == "send_barrier") {
PADDLE_ENFORCE(!send_bar, "only has one send barrier");
send_bar = node;
} else if (node->Name() == "recv") {
recv_ops.push_back(node);
} else if (node->Name() == "fetch_barrier") {
PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier");
fetch_bar = node;
}
}
if (send_bar) {
for (ir::Node *send : send_ops) {
ir::Node *dep_var = CreateControlDepVar();
send->outputs.push_back(dep_var);
dep_var->inputs.push_back(send);
send_bar->inputs.push_back(dep_var);
dep_var->outputs.push_back(send_bar);
}
for (ir::Node *recv : recv_ops) {
ir::Node *dep_var = CreateControlDepVar();
recv->inputs.push_back(dep_var);
dep_var->outputs.push_back(recv);
send_bar->outputs.push_back(dep_var);
dep_var->inputs.push_back(send_bar);
}
}
if (fetch_bar) {
for (ir::Node *recv : recv_ops) {
ir::Node *dep_var = CreateControlDepVar();
recv->outputs.push_back(dep_var);
dep_var->inputs.push_back(recv);
fetch_bar->inputs.push_back(dep_var);
dep_var->outputs.push_back(fetch_bar);
}
}
std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
for (ir::Node *node : Nodes()) {
if (IsDistTrainOp(node, send_vars, recv_vars)) {
if (fetch_bar && node->Name() == "concat") {
ir::Node *dep_var = CreateControlDepVar();
fetch_bar->outputs.push_back(dep_var);
dep_var->inputs.push_back(fetch_bar);
node->inputs.push_back(dep_var);
dep_var->outputs.push_back(node);
}
}
}
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need

@ -40,14 +40,21 @@ class Graph {
attr_dels_.clear();
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
}
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
attr_name);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save