Add bn and relu fuse pass (#22048)
* add bn and relu fuse pass
* add op attr assert and dtype assert
* fix some input/output bugs for the fused op and pattern
* add the unittest for fuse_bn_act_pass. test=develop
* use normative enforce statements. test=develop
* add the cpu test. test=develop
* add support for batch_size=1 to the bn with relu op. test=develop
* add the error type for paddle throws. test=develop
* add fused_batch_norm_act and fused_batch_norm_act_grad to op_has_unsed_vars_white_list. test=develop
parent 0d82baf837
commit 46189b166d
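For context: the fusion is opt-in and is switched on through BuildStrategy when compiling the program, as the new unit test below exercises. A minimal sketch (assuming `main_program` and `loss` come from a network like the one the test builds):

import paddle.fluid as fluid

# Assumed: main_program contains conv2d -> batch_norm(act='relu') subgraphs
# and loss is its scalar loss variable, as in the unit test below.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_act_ops = True  # apply the new fusion at build time
binary = fluid.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)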
File diff suppressed because it is too large
@@ -0,0 +1,64 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Fuse the BatchNorm and activation.
 */
class FuseBatchNormActPass : public FusePassBase {
 public:
  virtual ~FuseBatchNormActPass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;

  ir::Graph *FuseBatchNormAct(
      ir::Graph *graph,
      const std::unordered_set<std::string> &act_types) const;

  ir::Graph *FuseBatchNormActGrad(
      ir::Graph *graph,
      const std::unordered_set<std::string> &act_grad_types) const;

  std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
                                  const std::vector<Node *> &nodes) const;

  void ReLinkNodes(Graph *graph, const Node *intermediate_out, Node *op_1,
                   Node *op_2, Node *fused_op) const;
  Node *CreateFusedBatchNormActNode(
      Graph *g, const Node *act, const Node *bn, const std::string &bn_x_n,
      const std::string &bn_scale_n, const std::string &bn_bias_n,
      const std::string &bn_variance_n, const std::string &bn_mean_n,
      const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
      const std::string &bn_saved_variance_n,
      const std::string &bn_saved_mean_n, const std::string &bn_reserve_space_n,
      const std::string &act_out_n) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
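The subgraph this pass looks for is a batch_norm whose output feeds a supported activation (ReLU here). A hedged sketch of a fluid program that produces that pattern, mirroring the unit test below:

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
conv = fluid.layers.conv2d(
    input=x, filter_size=3, num_filters=32, act=None,
    bias_attr=False, data_format='NHWC')
# batch_norm followed by ReLU is the subgraph that FuseBatchNormAct is
# expected to rewrite into a single fused_batch_norm_act op.
out = fluid.layers.batch_norm(input=conv, act='relu', data_layout='NHWC')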
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,109 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

class FusedBatchNormActOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
};

class FusedBatchNormActGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusedBatchNormActOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

template <typename T>
class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    std::unique_ptr<T> op(new T());
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Y", this->Output("Y"));
    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

    op->SetInput("Scale", this->Input("Scale"));
    op->SetInput("Bias", this->Input("Bias"));
    op->SetInput("SavedMean", this->Output("SavedMean"));
    op->SetInput("SavedVariance", this->Output("SavedVariance"));
    op->SetInput("ReserveSpace", this->Output("ReserveSpace"));

    op->SetAttrMap(this->Attrs());

    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));

    return op;
  }
};

class FusedBatchNormActOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
  }
};

template <typename DeviceContext, typename T>
class FusedBatchNormActKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override;
};

template <typename DeviceContext, typename T>
class FusedBatchNormActGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override;
};

}  // namespace operators
}  // namespace paddle
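For reference, the forward semantics the fused kernel should match are training-mode batch normalization followed by the activation. A NumPy sketch under stated assumptions (per-channel statistics over the N/H/W axes of an NHWC tensor; running-statistics and ReserveSpace bookkeeping omitted):

import numpy as np

def fused_bn_relu_reference(x, scale, bias, epsilon=1e-5):
    # x: NHWC input; scale/bias: per-channel (C,) parameters.
    mean = x.mean(axis=(0, 1, 2))   # batch statistics (cf. SavedMean)
    var = x.var(axis=(0, 1, 2))
    y = (x - mean) / np.sqrt(var + epsilon)
    return np.maximum(scale * y + bias, 0.0)  # ReLU activation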
@@ -0,0 +1,121 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
import unittest


class TestFuseBatchNormActPass(unittest.TestCase):
    def build_program(self, main_program, startup_program, use_cuda, seed=1):
        main_program.random_seed = seed
        startup_program.random_seed = seed
        with fluid.program_guard(main_program, startup_program):
            x = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
            y = fluid.layers.data(name="y", shape=[1], dtype='int64')
            hidden1 = fluid.layers.conv2d(
                input=x,
                filter_size=3,
                num_filters=32,
                stride=1,
                padding=1,
                act=None,
                bias_attr=False,
                data_format='NHWC')
            param_attr = fluid.ParamAttr(
                name='batch_norm_w',
                initializer=fluid.initializer.Constant(value=1.0))
            bias_attr = fluid.ParamAttr(
                name='batch_norm_b',
                initializer=fluid.initializer.Constant(value=0.0))
            hidden2 = fluid.layers.batch_norm(
                input=hidden1,
                param_attr=param_attr,
                bias_attr=bias_attr,
                act='relu',
                data_layout='NHWC')
            hidden3 = fluid.layers.fc(input=hidden2, size=128, act='relu')
            hidden4 = fluid.layers.batch_norm(
                input=hidden3, act='relu', data_layout='NHWC')
            prediction = fluid.layers.fc(input=hidden4, size=10, act='softmax')
            loss = fluid.layers.cross_entropy(input=prediction, label=y)
            loss = fluid.layers.mean(loss)
            sgd = fluid.optimizer.SGD(learning_rate=0.001)
            if use_cuda:
                sgd = fluid.contrib.mixed_precision.decorate(
                    sgd, use_dynamic_loss_scaling=True,
                    init_loss_scaling=128.0)
            sgd.minimize(loss)
        return x, y, loss

    def check(self, place, use_cuda):
        main_program = fluid.Program()
        startup_program = fluid.Program()
        x, y, loss = self.build_program(main_program, startup_program,
                                        use_cuda)
        exe = fluid.Executor(place)
        iters = 10
        batch_size = 16
        feeder = fluid.DataFeeder(feed_list=[x, y], place=place)

        # disable fused_bn_act_ops
        build_strategy = fluid.BuildStrategy()
        build_strategy.fuse_bn_act_ops = False
        binary = fluid.CompiledProgram(main_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size)
        loss_vals = []
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])
                loss_vals.append(loss_v[0][0])

        # enable fused_bn_act_ops
        build_strategy_fused = fluid.BuildStrategy()
        build_strategy_fused.fuse_bn_act_ops = True
        binary_fused = fluid.CompiledProgram(main_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy_fused)
        train_reader_fused = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size)
        loss_vals_fused = []
        scope_fused = fluid.Scope()
        with fluid.scope_guard(scope_fused):
            exe.run(startup_program)
            for _ in range(iters):
                data = next(train_reader_fused())
                loss_v = exe.run(binary_fused,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])
                loss_vals_fused.append(loss_v[0][0])

        # check that fused and unfused losses match
        for i in range(iters):
            self.assertAlmostEqual(loss_vals[i], loss_vals_fused[i],
                                   delta=1e-5)

    def test_fuse_bn_act_pass_cpu(self):
        place = fluid.CPUPlace()
        self.check(place, use_cuda=False)

    def test_fuse_bn_act_pass_cuda(self):
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
            self.check(place, use_cuda=True)


if __name__ == '__main__':
    unittest.main()