Enable the detection of subgraphs composed of grad ops (#21223)

* Add the first implementation of fusion_group op #19621 (#3)

* Add the dynamic load of nvrtc, and support runtime compiling of CUDA kernel using nvrtc.
test=develop

* Call CUDA driver API to launch the kernel compiled by nvrtc (a standalone sketch of this NVRTC-compile-and-driver-launch flow appears after this list).
test=develop

* Disable for mac and windows.
test=develop

* Refine the codes to support manually specified num_threads and workload_per_thread.
test=develop

* Refine the CUDA kernel to support large dims.
test=develop

* Add DeviceCodePool to manage all device codes.

* Add the first implementation of fusion_group op.

* Add unit-test for fusion_group op.

* Add the check of result.

* Add the check of nvrtc in unit-test.
test=develop

* Add comment to explain the inputs, outputs and features of fusion_group op.
test=develop

* Disable fusion_group op for mac and windows.
test=develop

* Make the compilation of device code return a status instead of hanging up.
test=develop

* Add a check for whether the CUDA driver library is present, and do not core dump when failing to call the CUDA driver API.

* Unify fusion_group_op's input and output names.
test=develop

* Add the check of CUDA driver library in unittest.
test=develop

* Enable generating code for a given subgraph. #21126 (#4)

* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearrangement of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop

* Enable the detection of subgraphs of grad ops.

* Generate code for detected subgraph in fusion_group_pass.

* Add an option in BuildStrategy to enable fusion_group_pass and add unittest.
test=develop

* Fix a bug when checking whether the shape of all inputs are the same.

* Add debug information.

* Move subgraph_detector from inference/analysis to the common framework/ir directory. (#5)

test=develop

* Call subgraph_detector in fusion_group pass.
test=develop

* Disable fusion_group when WITH_GPU is OFF.
test=develop

* Refine all PADDLE_ENFORCE message.
test=develop

* Fix the case where some inputs are not defined in grad ops, and set op_role for the fused op.
test=develop

* Follow review comments.
test=develop
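
For readers unfamiliar with this runtime-compilation path, a minimal standalone sketch of the flow described in the NVRTC and CUDA-driver notes above is given below. It is an illustration only, assuming a hand-written kernel string and direct NVRTC/driver calls; Paddle's actual implementation goes through its dynamically loaded wrappers and the DeviceCode/DeviceCodePool abstractions.

```cpp
// Minimal sketch (not Paddle code): compile a CUDA kernel string with NVRTC
// at run time, then launch it through the CUDA driver API.
#include <cuda.h>
#include <nvrtc.h>
#include <iostream>
#include <string>

int main() {
  // A trivial elementwise kernel; fusion_group generates such strings.
  const char* source = R"(
    extern "C" __global__ void scale(float* x, float a, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] = a * x[i];
    })";

  // 1. Compile CUDA C++ to PTX with NVRTC.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, source, "scale.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--std=c++11"};
  if (nvrtcCompileProgram(prog, 1, opts) != NVRTC_SUCCESS) {
    // Return a status instead of hanging up or core dumping.
    size_t log_size = 0;
    nvrtcGetProgramLogSize(prog, &log_size);
    std::string log(log_size, '\0');
    nvrtcGetProgramLog(prog, &log[0]);
    std::cerr << "NVRTC compilation failed:\n" << log << std::endl;
    return 1;
  }
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, &ptx[0]);
  nvrtcDestroyProgram(&prog);

  // 2. Load the PTX and launch the kernel via the driver API.
  cuInit(0);
  CUdevice device;
  CUcontext context;
  cuDeviceGet(&device, 0);
  cuCtxCreate(&context, 0, device);

  const int n = 1 << 20;
  CUdeviceptr x;
  cuMemAlloc(&x, n * sizeof(float));

  CUmodule module;
  CUfunction kernel;
  cuModuleLoadData(&module, ptx.c_str());
  cuModuleGetFunction(&kernel, module, "scale");

  float a = 2.0f;
  int num = n;
  void* args[] = {&x, &a, &num};
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  cuLaunchKernel(kernel, blocks, 1, 1, threads, 1, 1, 0, nullptr, args,
                 nullptr);
  cuCtxSynchronize();

  cuMemFree(x);
  cuModuleUnload(module);
  cuCtxDestroy(context);
  return 0;
}
```

Build with something like `nvcc sketch.cc -lnvrtc -lcuda` (or any host compiler with the CUDA include and library paths).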
Yiqun Liu 5 years ago committed by GitHub
parent 50af6b5d79
commit dcfb603897

@@ -64,7 +64,14 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
sequential_execution_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
reference_count_pass
eager_deletion_pass
buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
@@ -91,23 +98,22 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass
multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass)
if(WITH_GPU)
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif()
if(WITH_NGRAPH)
set(NGRAPH_BS_DEPS ngraph)
else()
set(NGRAPH_BS_DEPS)
set(IR_PASS_DEPS ${IR_PASS_DEPS} ngraph)
endif()
cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass
pass_builder
${NGRAPH_BS_DEPS})
cc_library(build_strategy SRCS build_strategy.cc DEPS pass_builder ${IR_PASS_DEPS})
if (WITH_MKLDNN)
target_link_libraries(build_strategy mkldnn_placement_pass)

@@ -165,9 +165,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendOpFusePasses() {
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
#endif
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
// for single card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
@@ -370,6 +373,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "fusion_group_pass") {
pass->Set<bool>("use_gpu", new bool(use_cuda));
if (!use_cuda) {
LOG(WARNING) << "fusion_group_pass is only supported on GPU, skipped.";
continue;
}
} else if (pass->Type() == "fuse_bn_act_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_bn_act_pass is only supported on "
@@ -427,3 +436,6 @@ USE_PASS(mkldnn_placement_pass);
#ifdef PADDLE_WITH_NGRAPH
USE_PASS(ngraph_subgraph_pass);
#endif
#ifdef PADDLE_WITH_CUDA
USE_PASS(fusion_group_pass);
#endif

@@ -86,8 +86,9 @@ struct BuildStrategy {
// Operator fusion
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
bool fuse_elewise_add_act_ops_{false};
bool fuse_bn_act_ops_{false};
bool fuse_elewise_add_act_ops_{false};
bool enable_auto_fusion_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types
boost::optional<bool> fuse_all_optimizer_ops_{false};

@@ -6,7 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
if(NOT APPLE AND NOT WIN32)
if(NOT APPLE AND NOT WIN32 AND WITH_GPU)
add_subdirectory(fusion_group)
endif()

@@ -1,9 +1,11 @@
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
cc_library(code_generator
SRCS operation.cc code_generator.cc code_generator_helper.cc
DEPS graph subgraph_detector)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif()
cc_library(fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS graph_pattern_detector pass code_generator)
DEPS subgraph_detector fuse_pass_base code_generator device_code)
cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)

@@ -33,7 +33,7 @@ CodeGenerator::CodeGenerator() {
std::string CodeGenerator::Generate(SubGraph* subgraph) {
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
return Generate(subgraph->func_name, expressions);
return Generate(subgraph->GetFuncName(), expressions);
}
static bool HasInput(Node* n, std::string name) {

@@ -227,7 +227,7 @@ std::vector<fusion_group::OperationExpression> TestMain(
std::string code_str = code_generator.Generate(subgraph);
VLOG(3) << code_str;
TestMainImpl(subgraph->func_name, code_str, cpu_tensors, n, input_ids,
TestMainImpl(subgraph->GetFuncName(), code_str, cpu_tensors, n, input_ids,
output_ids);
// Need to check the accuracy according to expressions.

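The kernels that CodeGenerator::Generate emits for a fused elementwise subgraph keep load statements, computation expressions, and store statements separate (see the commit notes above). The snippet below is a hand-written illustration of that shape, assuming a two-op subgraph tmp = relu(x); out = tmp * y; the real code is produced from the sorted subgraph and the OperationMap, so the name and signature here are assumptions, not actual generator output.

```cpp
// Illustration only: the assumed shape of a generated fusion_group kernel,
// not actual CodeGenerator output.
extern "C" __global__ void fused_elementwise_0(int N, const float* x,
                                               const float* y, float* tmp,
                                               float* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += gridDim.x * blockDim.x) {
    // Load statements: each input variable is read once.
    float x_i = x[i];
    float y_i = y[i];
    // Computation expressions: one statement per fused operation.
    float tmp_i = x_i > 0.0f ? x_i : 0.0f;  // relu
    float out_i = tmp_i * y_i;              // elementwise_mul
    // Store statements: intermediate (when saved) and final outputs.
    tmp[i] = tmp_i;
    out[i] = out_i;
  }
}
```

A grid-stride loop like this is one common way for a single kernel to cover large dims; writing `tmp` back corresponds to the `save_intermediate_out` flag used by the pass.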
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
namespace paddle {
namespace framework {
@@ -26,20 +29,22 @@ static std::unordered_set<std::string> unary_op_types;
static std::unordered_set<std::string>& GetBinaryOpTypes() {
if (binary_op_types.empty()) {
binary_op_types = OperationMap::Instance().Find(0, 2);
binary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
}
return binary_op_types;
}
static std::unordered_set<std::string>& GetUnaryOpTypes() {
if (unary_op_types.empty()) {
unary_op_types = OperationMap::Instance().Find(0, 1);
unary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
}
return unary_op_types;
}
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
Node* n) {
const Node* n) {
if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) {
auto iter = op_types.find(n->Op()->Type());
if (iter != op_types.end()) {
@@ -49,114 +54,63 @@ static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
return false;
}
static bool IsBinaryOp(Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n) && n->inputs.size() == 2U) {
auto* x = n->inputs[0];
auto* y = n->inputs[1];
static bool IsGradOp(const Node* n) {
PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
platform::errors::InvalidArgument(
"Expected node %p to be an operator node.", n));
std::string suffix = "_grad";
std::string op_type = n->Op()->Type();
size_t pos = op_type.rfind(suffix);
return pos != std::string::npos &&
pos == (op_type.length() - suffix.length());
}
std::vector<int64_t> x_shape;
std::vector<int64_t> y_shape;
if (x && x->IsVar() && x->Var()) {
x_shape = x->Var()->GetShape();
}
if (y && y->IsVar() && y->Var()) {
y_shape = y->Var()->GetShape();
}
if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) {
static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
const std::vector<int64_t>& r) {
return l.size() != 0U && r.size() != 0U && l == r;
}
static bool IsBinaryOp(const Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
return false;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i]) {
// The shape of all inputs should be the same.
std::vector<int64_t> shape_0;
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto* in_i = n->inputs[i];
if (!(in_i && in_i->IsVar() && in_i->Var())) {
return false;
}
}
return true;
}
return false;
}
static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(GetUnaryOpTypes(), n); }
bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}
bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n,
std::string name) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->outputs) {
if (IsElementwiseOp(op)) {
if (name.empty()) {
return true;
} else if (IsNthInput(n, op, name, 0)) {
return true;
std::vector<int64_t> shape_i = in_i->Var()->GetShape();
if (i == 0U) {
shape_0 = shape_i;
} else {
if (!IsEqualAndNotEmpty(shape_0, shape_i)) {
return false;
}
}
}
return true;
}
return false;
}
bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op)) {
return true;
}
}
}
return false;
static bool IsUnaryOp(const Node* n) {
return IsSpecifiedOp(GetUnaryOpTypes(), n);
}
int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) {
except_nodes_set.insert(except_nodes[i]);
}
int num_operations = 0;
if (IsElementwiseOp(n)) {
subgraph_.Insert(n);
num_operations += 1;
for (auto* var : n->inputs) {
subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
for (auto* var : n->outputs) {
subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
} else if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
for (auto* op : n->outputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
}
return num_operations;
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}
int ElementwiseGroupDetector::operator()(Node* n) {
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
name_ = n->Name();
subgraph_.Insert(n);
num_operations_ = Search(n, n->inputs);
VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
<< num_operations_ << " operations, " << GetSubgraph().GetNumNodes()
<< " nodes";
}
return num_operations_;
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) {
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
return SubgraphDetector(graph, teller)();
}
} // namespace fusion_group

@@ -14,10 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
@@ -27,21 +25,10 @@ namespace fusion_group {
class ElementwiseGroupDetector {
public:
int operator()(Node* n);
SubGraph GetSubgraph() const { return subgraph_; }
private:
bool IsElementwiseOp(Node* n);
bool IsInputOfElementwiseOp(Node* n, std::string name = "");
bool IsOutputOfElementwiseOp(Node* n);
int Search(Node* n, std::vector<Node*> except_nodes = {});
std::vector<std::vector<Node*>> operator()(Graph* graph);
private:
std::string name_;
int num_operations_{0};
SubGraph subgraph_;
bool IsElementwiseOp(const Node* n);
};
} // namespace fusion_group

@@ -13,57 +13,88 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h"
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/platform/device_code.h"
namespace paddle {
namespace framework {
namespace ir {
void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
int num_elementwise_groups = DetectFusionGroup(graph, 0);
LOG(INFO) << "Detect " << num_elementwise_groups
FusePassBase::Init("fusion_group_pass", graph);
if (Get<bool>("use_gpu")) {
fusion_group::OperationMap::Init();
int num_elementwise_groups = DetectFusionGroup(graph, 0);
VLOG(3) << "Detect " << num_elementwise_groups
<< " elementwise fusion groups.";
}
}
int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
std::vector<fusion_group::SubGraph> subgraphs;
std::unordered_set<Node*> all_nodes = graph->Nodes();
for (Node* n : all_nodes) {
bool is_found = false;
for (auto& subgraph : subgraphs) {
if (subgraph.Has(n)) {
is_found = true;
break;
}
}
if (is_found) {
continue;
// TODO(liuyiqun): supported different places
platform::CUDAPlace place = platform::CUDAPlace(0);
int index = platform::DeviceCodePool::Init({place}).size(place);
std::vector<std::vector<Node*>> subgraphs =
fusion_group::ElementwiseGroupDetector()(graph);
int num_subgraphs = 0;
size_t min_subgraph_size = 2;
bool save_intermediate_out = true;
for (auto& vec : subgraphs) {
if (vec.size() >= min_subgraph_size) {
std::string func_name = "fused_elementwise_" + std::to_string(index++);
fusion_group::SubGraph subgraph(
type, func_name, save_intermediate_out,
std::unordered_set<Node*>(vec.begin(), vec.end()));
VLOG(3) << "subgraph: {\n"
<< DebugString(subgraph.SortedNodes()) << "}\n";
GenerateCode(&subgraph);
InsertFusionGroupOp(graph, &subgraph);
num_subgraphs++;
}
}
return num_subgraphs;
}
fusion_group::SubGraph subgraph;
if (type == 0) {
fusion_group::ElementwiseGroupDetector detector;
int num_operations = detector(n);
if (num_operations >= 2) {
subgraph = detector.GetSubgraph();
}
}
void FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.Generate(subgraph);
VLOG(3) << code_str;
// TODO(liuyiqun): supported different places
platform::CUDAPlace place = platform::CUDAPlace(0);
std::unique_ptr<platform::CUDADeviceCode> device_code(
new platform::CUDADeviceCode(place, subgraph->GetFuncName(), code_str));
device_code->Compile();
platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place});
pool.Set(std::move(device_code));
}
if (!subgraph.IsEmpty()) {
subgraphs.push_back(subgraph);
static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
std::unordered_set<int> op_roles;
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
for (auto* n : subgraph->Nodes()) {
if (n && n->IsOp() && n->Op()) {
if (n->Op()->HasAttr(attr_name)) {
op_roles.insert(boost::get<int>(n->Op()->GetAttr(attr_name)));
}
}
}
// TODO(liuyiqun): check whether there are intersection between subgraphs
for (size_t i = 0; i < subgraphs.size(); ++i) {
InsertFusionGroupOp(graph, &subgraphs[i]);
if (op_roles.size() == 1U) {
return *(op_roles.begin());
} else {
return static_cast<int>(OpRole::kNotSpecified);
}
return subgraphs.size();
}
void FusionGroupPass::InsertFusionGroupOp(
@@ -90,10 +121,12 @@ void FusionGroupPass::InsertFusionGroupOp(
external_nodes.insert(n);
}
op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("type", subgraph->type);
op_desc.SetAttr("func_name", subgraph->func_name);
op_desc.SetAttr("type", subgraph->GetType());
op_desc.SetAttr("func_name", subgraph->GetFuncName());
op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(subgraph));
auto fusion_group_node = graph->CreateOpNode(&op_desc);
Node* fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) {
IR_NODE_LINK_TO(in, fusion_group_node);
}
@@ -114,4 +147,5 @@ void FusionGroupPass::InsertFusionGroupOp(
} // namespace framework
} // namespace paddle
REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass);
REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass)
.RequirePassAttr("use_gpu");

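GenerateCode above wraps each subgraph's kernel string in a CUDADeviceCode, compiles it, and registers it in the process-wide DeviceCodePool so the fusion_group op can look the kernel up by function name at execution time. The sketch below is a deliberately simplified guess at what such a pool boils down to, based only on the calls visible in this diff (Init, size, Set); it is not Paddle's actual platform::DeviceCodePool.

```cpp
// Simplified sketch (assumption, not Paddle's implementation): a per-place
// registry of compiled device code, keyed by kernel (function) name.
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

struct DeviceCodeSketch {          // stands in for platform::CUDADeviceCode
  virtual ~DeviceCodeSketch() = default;
  virtual bool Compile() = 0;
  virtual std::string Name() const = 0;
};

class DeviceCodePoolSketch {
 public:
  using CodeMap =
      std::unordered_map<std::string, std::unique_ptr<DeviceCodeSketch>>;

  // Register a compiled kernel under its place and function name.
  void Set(int place_id, std::unique_ptr<DeviceCodeSketch> code) {
    std::string name = code->Name();
    pool_[place_id][name] = std::move(code);
  }

  // Number of kernels registered for one place; the pass uses a count like
  // this to pick a unique suffix for the next generated function name.
  size_t size(int place_id) const {
    auto it = pool_.find(place_id);
    return it == pool_.end() ? 0 : it->second.size();
  }

  // Look a compiled kernel up by name when the fusion_group op runs.
  DeviceCodeSketch* Get(int place_id, const std::string& name) {
    auto& codes = pool_[place_id];
    auto it = codes.find(name);
    return it == codes.end() ? nullptr : it->second.get();
  }

 private:
  std::map<int, CodeMap> pool_;  // place id -> (function name -> device code)
};
```
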
@@ -16,19 +16,20 @@ limitations under the License. */
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class FusionGroupPass : public Pass {
class FusionGroupPass : public FusePassBase {
protected:
void ApplyImpl(Graph* graph) const override;
private:
int DetectFusionGroup(Graph* graph, int type = 0) const;
void GenerateCode(fusion_group::SubGraph* subgraph) const;
void InsertFusionGroupOp(Graph* graph,
fusion_group::SubGraph* subgraph) const;

@@ -138,19 +138,15 @@ int TestMain(std::unique_ptr<Graph> graph, std::string prefix) {
}
TEST(FusionGroupPass, elementwise_list) {
fusion_group::OperationMap::Init();
std::unique_ptr<Graph> graph = BuildElementwiseListGraph(false);
std::unique_ptr<Graph> graph = BuildElementwiseListGraph(true);
int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_list");
EXPECT_EQ(num_fusion_group_ops, 1);
EXPECT_EQ(num_fusion_group_ops, 2);
}
TEST(FusionGroupPass, elementwise_tree) {
fusion_group::OperationMap::Init();
std::unique_ptr<Graph> graph = BuildElementwiseTreeGraph(false);
std::unique_ptr<Graph> graph = BuildElementwiseTreeGraph(true);
int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_tree");
EXPECT_EQ(num_fusion_group_ops, 2);
EXPECT_EQ(num_fusion_group_ops, 4);
}
} // namespace ir

File diff suppressed because it is too large.

@@ -2017,6 +2017,27 @@ All parameter, weight, gradient are variables in Paddle.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_act_ops = True
)DOC")
.def_property(
"enable_auto_fusion",
[](const BuildStrategy &self) { return self.enable_auto_fusion_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
platform::errors::PreconditionNotMet(
"BuildStrategy is finlaized."));
self.enable_auto_fusion_ = b;
},
R"DOC((bool, optional): Whether to enable fusing subgraph to a
fusion_group. Now we only support fusing subgraph that composed
of elementwise-like operators, such as elementwise_add/mul
without broadcast and activations.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.enable_auto_fusion = True
)DOC")
.def_property(
"fuse_relu_depthwise_conv",
[](const BuildStrategy &self) {

@@ -200,6 +200,10 @@ if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler)
endif()
if(NOT WITH_GPU OR WIN32 OR APPLE)
list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass)
endif()
# Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
set(TEST_OPS_WITH_GC

@@ -0,0 +1,39 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
class FusionGroupPaddingRNNTest(PaddingRNNTestBase):
def set_customed_config(self):
self.build_strategy.enable_auto_fusion = True
# Use CUDA executor
if core.is_compiled_with_cuda():
self.exe = fluid.Executor(fluid.CUDAPlace(0))
def test_train_enable_fusion_group(self):
rnn_model = "static"
config = RNNConfig("test", rnn_model)
with fluid.scope_guard(fluid.Scope()):
self.train(config, parallel=True, use_program_cache=False)
if __name__ == '__main__':
unittest.main()