|
|
|
@ -27,7 +27,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
|
|
|
|
|
public:
|
|
|
|
|
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
|
|
|
|
|
: ir::PassBuilder(), strategy_(strategy) {
|
|
|
|
|
// Apply a graph viz pass to record a graph.
|
|
|
|
|
// Add a graph viz pass to record a graph.
|
|
|
|
|
if (!strategy_.debug_graphviz_path_.empty()) {
|
|
|
|
|
auto viz_pass = AppendPass("graph_viz_pass");
|
|
|
|
|
const std::string graph_path = string::Sprintf(
|
|
|
|
@ -35,10 +35,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
|
|
|
|
|
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Apply op fusion.
|
|
|
|
|
// Add op fusion.
|
|
|
|
|
if (strategy.fuse_elewise_add_act_ops_) {
|
|
|
|
|
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
|
|
|
|
|
// Apply a graph viz pass to record a graph.
|
|
|
|
|
// Add a graph viz pass to record a graph.
|
|
|
|
|
if (!strategy.debug_graphviz_path_.empty()) {
|
|
|
|
|
auto viz_pass = AppendPass("graph_viz_pass");
|
|
|
|
|
const std::string graph_path = string::Sprintf(
|
|
|
|
@ -53,7 +53,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
|
|
|
|
|
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
|
|
|
|
|
&strategy_);
|
|
|
|
|
|
|
|
|
|
// Apply a graph print pass to record a graph with device info.
|
|
|
|
|
// Add a graph print pass to record a graph with device info.
|
|
|
|
|
if (!strategy_.debug_graphviz_path_.empty()) {
|
|
|
|
|
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
|
|
|
|
|
multi_devices_print_pass->SetNotOwned<const std::string>(
|
|
|
|
@ -86,7 +86,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
|
|
|
|
|
#else
|
|
|
|
|
const bool use_cuda) const {
|
|
|
|
|
#endif
|
|
|
|
|
// Create a default one if not intialized by user.
|
|
|
|
|
// Create a default one if not initialized by user.
|
|
|
|
|
if (!pass_builder_) {
|
|
|
|
|
CreatePassesFromStrategy();
|
|
|
|
|
}
|
|
|
|
|