add Fuse bn add act pass (#28196)

* add fuse_bn_add_act pass
revert-27871-prv-conv-grad-opt
Zhang Ting 4 years ago committed by GitHub
parent 813b2ade34
commit fdc06f2158
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -107,7 +107,7 @@ cc_test(exception_holder_test SRCS exception_holder_test.cc )
set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass
fuse_elewise_add_act_pass fuse_bn_act_pass fuse_bn_add_act_pass
multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass

@ -164,6 +164,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
AppendPassWithCheck(strategy_.fuse_bn_add_act_ops_, "fuse_bn_add_act_pass");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
#else
@ -390,6 +391,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "fuse_bn_add_act_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_bn_add_act_pass is only supported on "
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
@ -416,6 +423,7 @@ USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(fuse_bn_act_pass);
USE_PASS(fuse_bn_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);

@ -100,6 +100,7 @@ struct BuildStrategy {
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
bool fuse_bn_act_ops_{false};
bool fuse_bn_add_act_ops_{true};
bool fuse_elewise_add_act_ops_{false};
bool enable_auto_fusion_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients

@ -114,6 +114,7 @@ if(WITH_MKLDNN)
endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_bn_add_act_pass SRCS fuse_bn_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )

File diff suppressed because it is too large Load Diff

@ -0,0 +1,75 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the BatchNorm, add and activation.
*/
class Graph;
class Node;
class FuseBatchNormAddActPass : public FusePassBase {
public:
virtual ~FuseBatchNormAddActPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
ir::Graph *FuseBatchNormAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseBatchNormAddActGrad(
ir::Graph *graph,
const std::unordered_set<std::string> &act_grad_types) const;
void LinkOutputsToFuseOp(
Node *op_1, Node *op_2, Node *fused_op,
std::unordered_set<const Node *> *nodes2delete) const;
void LinkInputsToFuseOp(Node *op, Node *fused_op,
std::unordered_set<const Node *> *nodes2delete) const;
std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
const std::vector<Node *> &nodes) const;
void ReLinkNodes(Graph *graph, Node *op_1, Node *op_2, Node *op_3,
Node *fused_op) const;
Node *CreateFusedBatchNormAddActNode(
Graph *g, const Node *act, const Node *add, const Node *bn,
const std::string &bn_x_n, const std::string &add_y_n,
const std::string &bn_scale_n, const std::string &bn_bias_n,
const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
const std::string &bn_saved_variance_n,
const std::string &bn_saved_mean_n, const std::string &bn_reserve_space_n,
const std::string &act_out_n) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -93,6 +93,7 @@ void GraphPatternDetector::operator()(Graph *graph,
auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs);
SortSubgraphs(&subgraphs);
RemoveOverlappedMatch(&subgraphs);
ValidateByNodeRole(&subgraphs);
@ -302,6 +303,46 @@ void GraphPatternDetector::UniquePatterns(
*subgraphs = result;
}
void GraphPatternDetector::SortSubgraphs(
std::vector<GraphPatternDetector::subgraph_t> *subgraphs) {
if (subgraphs->empty()) return;
bool has_bn_add_act = false;
for (auto &subgraph : *subgraphs) {
for (auto &item : subgraph) {
if (item.first->name().find("bn_add_act") != std::string::npos) {
has_bn_add_act = true;
break;
}
}
}
if (!has_bn_add_act) {
return;
}
std::sort(
subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t &a,
const GraphPatternDetector::subgraph_t &b) {
for (auto &item : a) {
if (item.first->name().find("bn_add_act") != std::string::npos &&
item.first->name().find("bn_reserve_space") !=
std::string::npos) {
auto it_b = b.find(item.first);
if (it_b != b.end()) {
if (item.second->Name() != it_b->second->Name()) {
return item.second->Name() < it_b->second->Name();
} else {
return false;
}
} else {
return false;
}
}
}
return false;
});
}
void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t> *subgraphs) {
std::vector<subgraph_t> result;
@ -1208,6 +1249,151 @@ PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) {
return act_out;
}
PDNode *patterns::BatchNormAddAct::operator()(
paddle::framework::ir::PDNode *bn_x_var,
std::unordered_set<std::string> act_types) {
bn_x_var->assert_is_op_input("batch_norm", "X")
->assert_var_dtype(proto::VarType::FP16);
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->assert_is_op_input("batch_norm", "Scale");
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->assert_is_op_input("batch_norm", "Bias");
auto *bn = pattern->NewNode(batch_norm_repr())
->assert_is_op("batch_norm")
->assert_is_not_op_input("MomentumTensor")
->assert_op_attr<bool>("is_test", false)
->assert_op_attr<bool>("use_global_stats", false)
->assert_op_attr<std::string>("data_layout", "NHWC");
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->assert_is_op_output("batch_norm", "MeanOut");
auto *bn_variance_out_var =
pattern->NewNode(bn_variance_out_repr())
->assert_is_op_output("batch_norm", "VarianceOut");
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->assert_is_op_output("batch_norm", "SavedVariance");
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->assert_is_op_output("batch_norm", "SavedMean");
auto *bn_reserve_space =
pattern->NewNode(bn_reserve_space_repr())
->assert_is_op_output("batch_norm", "ReserveSpace");
auto *bn_out_var = pattern->NewNode(bn_out_repr())
->assert_is_op_output("batch_norm", "Y")
->assert_var_dtype(proto::VarType::FP16);
bn_out_var->assert_is_op_input("elementwise_add");
auto *elewise_add =
pattern->NewNode(elewise_add_repr())->assert_is_op("elementwise_add");
auto *elewise_add_in_var = pattern->NewNode(elewise_add_in_repr())
->assert_is_not_ctrl_var()
->assert_is_op_input("elementwise_add")
->assert_var_dtype(proto::VarType::FP16);
auto *elewise_add_out_var =
pattern->NewNode(elewise_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_has_n_outputs(1);
elewise_add_out_var->AsIntermediate()->assert_is_ops_input(act_types);
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var})
.LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
bn_saved_mean_var, bn_reserve_space, bn_out_var});
elewise_add->LinksFrom({elewise_add_in_var, bn_out_var})
.LinksTo({elewise_add_out_var});
act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var});
return act_out_var;
}
PDNode *patterns::BatchNormAddActGrad::operator()(
paddle::framework::ir::PDNode *d_act_out_var,
std::unordered_set<std::string> act_grad_types) {
auto *act_grad =
pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types);
auto *elewise_add_grad = pattern->NewNode(elewise_add_grad_repr())
->assert_is_op("elementwise_add_grad");
auto *bn_grad = pattern->NewNode(batch_norm_grad_repr())
->assert_is_op("batch_norm_grad")
->assert_op_attr<bool>("use_global_stats", false)
->assert_op_attr<std::string>("data_layout", "NHWC");
auto *act_out_var = pattern->NewNode(act_out_repr())
->assert_is_ops_input(act_grad_types, "Out");
auto *d_act_x_var =
pattern->NewNode(d_act_x_repr())
->assert_is_ops_output(act_grad_types, GradVarName("X"))
->assert_has_n_outputs(1); // d_act_x
d_act_x_var->AsIntermediate()->assert_is_op_input("elementwise_add_grad");
auto *d_elewise_add_in_var =
pattern->NewNode(d_elewise_add_in_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("elementwise_add_grad")
->assert_var_dtype(proto::VarType::FP16); // d_add_in_1
auto *d_bn_out_var =
pattern->NewNode(d_bn_out_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("elementwise_add_grad")
->assert_var_dtype(proto::VarType::FP16); // d_add_in_2
d_bn_out_var->assert_is_op_input("batch_norm_grad", GradVarName("Y"));
auto *bn_x_var = pattern->NewNode(bn_x_repr())
->assert_is_op_input("batch_norm_grad", "X")
->assert_var_dtype(proto::VarType::FP16);
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->assert_is_op_input("batch_norm_grad", "Scale");
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->assert_is_op_input("batch_norm_grad", "Bias");
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->assert_is_op_input("batch_norm_grad", "SavedMean");
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->assert_is_op_input("batch_norm_grad", "SavedVariance");
auto *bn_reserve_space =
pattern->NewNode(bn_reserve_space_repr())
->assert_is_op_input("batch_norm_grad", "ReserveSpace");
auto *d_bn_x_var =
pattern->NewNode(d_bn_x_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("X"))
->assert_var_dtype(proto::VarType::FP16);
auto *d_bn_scale_var =
pattern->NewNode(d_bn_scale_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("Scale"));
auto *d_bn_bias_var =
pattern->NewNode(d_bn_bias_repr())
->assert_is_not_ctrl_var()
->assert_is_op_output("batch_norm_grad", GradVarName("Bias"));
act_grad->LinksFrom({d_act_out_var, act_out_var}).LinksTo({d_act_x_var});
elewise_add_grad->LinksFrom({d_act_x_var})
.LinksTo({d_elewise_add_in_var, d_bn_out_var});
bn_grad
->LinksFrom({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var,
bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
.LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
return bn_grad;
}
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {

@ -294,6 +294,12 @@ class GraphPatternDetector {
// Remove duplicate patterns.
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
// Sort subgraphs, sort subgraphs by the specified node so that
// the removed forward and backward subgraphs are corresponding
// when two subgraphs are overlapped. Note: this function is
// currently only used for bn_add_act, refer to PR28196 for details.
void SortSubgraphs(std::vector<subgraph_t>* subgraphs);
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
// The intermediate PDNodes will be removed, so can't shared by multiple
// patterns.
@ -685,6 +691,72 @@ struct BatchNormActOneDNN : public PatternBase {
PATTERN_DECL_NODE(act_out);
};
// The following pattern is used to fuse batch_norm, elewise_add, and act
// formula: act(bn(x) + z)
// op: batch_norm + elewise_add + act
struct BatchNormAddAct : public PatternBase {
BatchNormAddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bn_add_act") {}
PDNode* operator()(PDNode* x, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(batch_norm);
PATTERN_DECL_NODE(elewise_add);
PATTERN_DECL_NODE(act);
// declare variable node's name
// BN inputs
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_bias);
// BN outputs
PATTERN_DECL_NODE(bn_mean_out);
PATTERN_DECL_NODE(bn_variance_out);
PATTERN_DECL_NODE(bn_saved_variance);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_reserve_space);
PATTERN_DECL_NODE(bn_out);
// Elewise_Add input
PATTERN_DECL_NODE(elewise_add_in);
// Elewise_Add output
PATTERN_DECL_NODE(elewise_add_out);
// ACT output
PATTERN_DECL_NODE(act_out);
};
// the backward of act(bn(x) + z)
// op: batch_norm_grad + elewise_add_grad + act_grad
struct BatchNormAddActGrad : public PatternBase {
BatchNormAddActGrad(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bn_add_act_grad") {}
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// elewise_add_grad: in["Out@GRAD"], out["X@GRAD", "Y@GRAD"]
// bn_grad: in["X", "Z", "Y@GRAD", "Scale", "Bias", "SavedMean",
// "SavedVariance",
// "ReserveSpace"],
// out["X@GRAD", "Z@GRAD", "Scale@GRAD", "Bias@GRAD"]
PDNode* operator()(PDNode* x, std::unordered_set<std::string> act_grad_types);
// declare operator node's name
PATTERN_DECL_NODE(act_grad);
PATTERN_DECL_NODE(elewise_add_grad);
PATTERN_DECL_NODE(batch_norm_grad);
// declare variable node's name
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(d_act_x);
PATTERN_DECL_NODE(d_elewise_add_in);
PATTERN_DECL_NODE(d_bn_out);
PATTERN_DECL_NODE(bn_x);
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_bias);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_saved_variance);
PATTERN_DECL_NODE(bn_reserve_space);
PATTERN_DECL_NODE(d_bn_x);
PATTERN_DECL_NODE(d_bn_scale);
PATTERN_DECL_NODE(d_bn_bias);
};
// The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y))
// op: elementwise_add + act

@ -186,8 +186,6 @@ void FusedBatchNormAddActGradOp::InferShape(
// check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"FusedBatchNormAddActGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",

@ -188,7 +188,6 @@ class FusedBatchNormAddActGradKernel<platform::CUDADeviceContext, T>
std::string act_type = ctx.Attr<std::string>("act_type");
const auto *x = ctx.Input<Tensor>("X");
const auto *z = ctx.Input<Tensor>("Z");
const auto *y = ctx.Input<Tensor>("Y");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");

@ -61,7 +61,6 @@ class FusedBatchNormAddActGradOpMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Z", this->Input("Z"));
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

@ -2500,6 +2500,31 @@ All parameter, weight, gradient are variables in Paddle.
build_strategy = static.BuildStrategy()
build_strategy.fuse_bn_act_ops = True
)DOC")
.def_property(
"fuse_bn_add_act_ops",
[](const BuildStrategy &self) { return self.fuse_bn_add_act_ops_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(), true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, cannot be "
"configured again."));
self.fuse_bn_add_act_ops_ = b;
},
R"DOC((bool, optional): fuse_bn_add_act_ops indicate whether
to fuse batch_norm, elementwise_add and activation_op,
it may make the execution faster. Default is True
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fuse_bn_add_act_ops = True
)DOC")
.def_property(
"enable_auto_fusion",
[](const BuildStrategy &self) { return self.enable_auto_fusion_; },

@ -331,6 +331,7 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_api)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
list(REMOVE_ITEM TEST_OPS test_fuse_bn_add_act_pass)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while)
list(REMOVE_ITEM TEST_OPS test_conv3d_transpose_op)
@ -515,6 +516,7 @@ py_test_modules(test_parallel_executor_transformer_auto_growth MODULES test_para
py_test_modules(test_data_norm_op MODULES test_data_norm_op)
py_test_modules(test_fuse_bn_act_pass MODULES test_fuse_bn_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)
py_test_modules(test_fuse_bn_add_act_pass MODULES test_fuse_bn_add_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)
# NOTE: These unittests will appear NaN steadily in windows CI. After analysis,
# it is found that windows CI will run all the training unittests with the ON_INFER option turned on,

@ -21,6 +21,8 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid import core
paddle.enable_static()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"Paddle core is not compiled with CUDA")
@ -163,12 +165,16 @@ class TestFusedBnAddActAPI(unittest.TestCase):
iters = 5
batch_size = 16
# build_fused_program
# build_fused_program: turn on fuse_bn_add_act_ops
main_program = fluid.Program()
startup_program = fluid.Program()
x, y, loss = self.build_fused_program(main_program, startup_program,
use_cuda)
x, y, loss = self.build_origin_program(main_program, startup_program,
use_cuda)
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
build_strategy_fused = fluid.BuildStrategy()
build_strategy_fused.fuse_bn_add_act_ops = True
binary_fused = fluid.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy_fused)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
exe = fluid.Executor(place)
@ -178,17 +184,16 @@ class TestFusedBnAddActAPI(unittest.TestCase):
exe.run(startup_program)
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(main_program,
loss_v = exe.run(binary_fused,
feed=feeder.feed(data),
fetch_list=[loss])
loss_vals_fused.append(loss_v[0][0])
# build_origin_program
main_program = fluid.Program()
startup_program = fluid.Program()
x, y, loss = self.build_origin_program(main_program, startup_program,
use_cuda)
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
# build_origin_program: turn off fused_bn_act_ops
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_add_act_ops = False
binary = fluid.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
loss_vals = []
@ -197,7 +202,7 @@ class TestFusedBnAddActAPI(unittest.TestCase):
exe.run(startup_program)
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(main_program,
loss_v = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[loss])
loss_vals.append(loss_v[0][0])
@ -210,6 +215,25 @@ class TestFusedBnAddActAPI(unittest.TestCase):
place = fluid.CUDAPlace(0)
self.check(place, use_cuda=True)
def test_fuse_bn_add_act_API(self):
# build_fused_program: use fused_bn_add_act python API
main_program = fluid.Program()
startup_program = fluid.Program()
place = fluid.CUDAPlace(0)
x, y, loss = self.build_fused_program(
main_program, startup_program, use_cuda=True)
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16)
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
exe.run(startup_program)
for _ in range(5):
data = next(train_reader())
loss_v = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save