Enable the runtime_context_cache pass in train phase (#16640)

* Try to enable the runtime_context_cache pass in train phase.

* Put the append of runtime_context_cache pass ahead of multi_dev passes.
test=develop
revert-16734-refine/test_imperative_transformer
Yiqun Liu 6 years ago committed by GitHub
parent 4048a2681f
commit 3fe8cb0dd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -142,6 +142,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("memory_optimize_pass");
}
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(10) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) {
@ -328,3 +336,4 @@ USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);

@ -107,6 +107,8 @@ struct BuildStrategy {
std::vector<std::string> trainers_endpoints_;
bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false};
// NOTE:
// Before you add new options, think if it's a general strategy that works
// with other strategy. If not, the strategy should be created through

@ -23,7 +23,7 @@ namespace ir {
void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
if (n->IsOp() && n->Op()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
}
}

@ -1356,6 +1356,10 @@ All parameter, weight, gradient are variables in Paddle.
"fuse_all_reduce_ops",
[](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; },
[](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; })
.def_property(
"cache_runtime_context",
[](const BuildStrategy &self) { return self.cache_runtime_context_; },
[](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; })
.def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true);

Loading…
Cancel
Save