Fix the interface of Pass::Apply ()

* modify the interface of Pass::Allay
test=develop

* Polish code
test=develop

* Fix Travis CI
test=develop

* fix Pass::Apply interface
test=develop

* Fix Travis CI
test=develop
move-code
chengduo 6 years ago committed by GitHub
parent 59f75ec76e
commit ed61d67c73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
return nullptr;
}
std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// get vars order
@ -131,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
VLOG(10) << "pre_op:" << pre_op->DebugString()
<< ", op:" << op->DebugString();
}
return graph;
}
} // namespace details

@ -24,8 +24,7 @@ namespace details {
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details

@ -46,8 +46,7 @@ static framework::proto::VarType::Type kDefaultDtype =
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@ -65,7 +64,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return std::move(graph);
return;
}
std::unordered_map<std::string, ir::Node *> vars;
@ -124,8 +123,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
fused_var_name, params_grads);
return std::move(graph);
}
template <typename AttrType>

@ -204,15 +204,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const {
#else
const bool use_cuda) const {
const bool use_cuda) const {
#endif
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
@ -265,7 +266,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
}
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(std::move(graph));
graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type();
}
return graph;

@ -120,16 +120,15 @@ struct BuildStrategy {
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
#else
const bool use_cuda) const;
const bool use_cuda) const;
#endif
// If set true, ParallelExecutor would build the main_program into multiple

@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
};
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
return while_op_eager_deletion_pass->Apply(std::move(graph));
while_op_eager_deletion_pass->Apply(graph);
}
} // namespace details

@ -28,8 +28,7 @@ namespace details {
class FuseAllReduceOpPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
if (all_reduce_ops.size() == 0) {
return std::move(graph);
return;
}
PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
group_all_reduce_ops, &result);
#endif
}
return std::move(graph);
}
void InsertFusedAllReduce(const std::vector<platform::Place> &places,

@ -144,10 +144,9 @@ void InplacePass::InitSSAGraphNodes() const {
}
}
std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void InplacePass::ApplyImpl(ir::Graph* graph) const {
var_nodes_.clear();
view_.Build(graph.get());
view_.Build(graph);
InitSSAGraphNodes();
auto cnt = 0;
@ -155,11 +154,9 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue;
TryInplaceOpInputOutput(op, graph.get());
TryInplaceOpInputOutput(op, graph);
}
// graph->ResolveHazard(var_nodes_);
return graph;
}
void InplacePass::InplaceModifyDesc(const std::string& var,

@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
InplacePass();
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
void InitSSAGraphNodes() const;

@ -44,8 +44,7 @@ namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
auto nodes = graph->Nodes();
CollectSkipVarsSet(nodes);
@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
RenameVarInGraphDesc(var_name, cache_name, idx);
RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
RenameVarInGraphNode(var_name, cache_name, idx, graph);
pool_.Erase(cache_name);
}
}
@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
graph->ResolveHazard(var_nodes_);
return graph;
}
void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {

@ -21,6 +21,7 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@ -35,8 +36,7 @@ namespace details {
class MemoryOptimizePass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
// fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const;

@ -34,8 +34,7 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
@ -49,7 +48,6 @@ std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details

@ -23,8 +23,7 @@ namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details

@ -23,10 +23,8 @@ namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
void ApplyImpl(ir::Graph *graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph));
}
bool IsValidGraph(const ir::Graph *graph) const {

@ -153,8 +153,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
Init();
CheckGraph(*graph);
std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
@ -236,7 +235,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
return graph;
}
void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(

@ -36,8 +36,7 @@ namespace details {
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
virtual void Init() const;

@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <fstream>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
@ -40,13 +41,11 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class SSAGraghBuilderWithPrinter : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph;
}
};

@ -96,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
}
// set the correct size of thread pool to each device.

@ -266,8 +266,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
@ -342,8 +341,6 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Just skip this corner case
}
}
return graph;
}
} // namespace details

@ -23,8 +23,7 @@ namespace details {
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details

@ -29,8 +29,7 @@ static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
@ -98,7 +97,6 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details

@ -23,8 +23,7 @@ namespace details {
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details

@ -23,8 +23,7 @@ namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
@ -50,7 +49,6 @@ class WhileOpEagerDeletionPass : public ir::Pass {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};

@ -29,10 +29,9 @@ namespace ir {
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out);
std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
@ -69,12 +68,11 @@ std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
IR_NODE_LINK_TO(scale_op, elementwise_mul_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
GraphSafeRemoveNodes(graph,
{fill_constant, fill_constant_out, elementwise_mul});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save