@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
@@ -26,6 +27,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace details {
@@ -55,6 +58,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
    AppendPass("record_skip_memory_opt_vars_pass");
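
    // mkldnn_placement_pass marks operators to use MKLDNN acceleration;
    // strategy_.mkldnn_enabled_op_types_ restricts this to a whitelist, and
    // an empty set means every operator supported by MKLDNN is eligible.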
#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      VLOG(5) << "Add mkldnn_placement_pass";
      AppendPass("mkldnn_placement_pass");
    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
      LOG(WARNING)
          << "mkldnn_enabled_op_types specifies the list of operator types "
             "that use MKLDNN acceleration. It is empty by default, which "
             "means that all operators supported by MKLDNN will be "
             "accelerated. It should not be set when "
             "FLAGS_use_mkldnn=false.";
    }
#else
    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
                   "Please compile with MKLDNN first to use MKLDNN");
#endif
    if (strategy_.enable_sequential_execution_) {
      VLOG(5) << "Add sequential_execution_pass";
      AppendPass("sequential_execution_pass");
@@ -313,6 +332,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
    } else if (pass->Type() == "inplace_pass") {
      pass->Erase(ir::kUseCuda);
      pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
    } else if (pass->Type() == "mkldnn_placement_pass") {
      pass->Set("mkldnn_enabled_op_types",
                new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
    }
    VLOG(3) << "Start Apply Pass " << pass->Type();
    graph = pass->Apply(graph);
@@ -351,3 +373,6 @@ USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
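
For context, a minimal usage sketch (not part of the patch): assuming mkldnn_enabled_op_types_ is an ordinary public member of BuildStrategy, like the other strategy flags read above, a caller could restrict MKL-DNN placement to an explicit whitelist of operator types roughly as follows.

#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/details/build_strategy.h"

// Hypothetical helper, for illustration only: limit mkldnn_placement_pass to
// a couple of operator types. Leaving the set empty keeps the default
// behaviour, i.e. every operator with an MKLDNN kernel is eligible.
void RestrictMkldnnOps(paddle::framework::details::BuildStrategy *strategy) {
  strategy->mkldnn_enabled_op_types_ =
      std::unordered_set<std::string>{"conv2d", "pool2d"};
}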