Feature/op_fuse_pass (#12440)
* Add Preface * Add demo code * Save file * Refine code * seems can work * use elementwise strategy * Use ElementwiseComputeEx * Add comments * extract functions from operator * Refine code * Follow comment * code refine * add op_fuse pass * add backward * code refine * use TopologySortOperations * follow comments * refine IsFusible * code enhance * fix op_fusion_pass * refine code * refine fuse_elemwise_act_op * adjust the input and output * refine logic * add intermediate_edge * disable inplace * follow comments * refine logic * follow comments * Remove the removable IntermediateOut * change strategy * code refine * enable fuse backward * code refine * code refine * rename unit test * follow commentsfix-readme
parent
5acdbbb42f
commit
d402234ba8
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,75 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
* Fuse the ElewiseAdd and activation
|
||||
*/
|
||||
class FuseElewiseAddActPass : public FusePassBase {
|
||||
public:
|
||||
virtual ~FuseElewiseAddActPass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
|
||||
std::unique_ptr<ir::Graph> FuseElewiseAddAct(
|
||||
std::unique_ptr<ir::Graph> graph,
|
||||
const std::unordered_set<std::string> &act_types) const;
|
||||
|
||||
std::unique_ptr<ir::Graph> FuseActElewiseAdd(
|
||||
std::unique_ptr<ir::Graph> graph,
|
||||
const std::unordered_set<std::string> &act_types) const;
|
||||
|
||||
std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
|
||||
std::unique_ptr<ir::Graph> graph,
|
||||
const std::unordered_set<std::string> &act_types) const;
|
||||
|
||||
/**
|
||||
* Remove the removable intermediate_out.
|
||||
* - If the intermediate_out is only used by the backward op, but the
|
||||
* backward op doesn't use intermediate_out.
|
||||
* - If the intermediate_out_grad is not used by any op.
|
||||
*/
|
||||
void RemoveIntermediateOut(Graph *graph) const;
|
||||
|
||||
std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
|
||||
const std::vector<Node *> &nodes) const;
|
||||
|
||||
std::vector<Node *> RemoveNode(Node *trg_node,
|
||||
const std::vector<Node *> &nodes) const;
|
||||
|
||||
void ReLinkNodes(Graph *graph, const Node *intermediate_out, Node *op_1,
|
||||
Node *op_2, Node *fused_op) const;
|
||||
Node *CreateFuseElewiseAddActNode(Graph *g, const Node *op_1,
|
||||
const Node *op_2,
|
||||
const std::string &ele_x_n,
|
||||
const std::string &ele_y_n,
|
||||
const std::string &ele_out_n,
|
||||
const std::string &act_out_n) const;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from parallel_executor_test_base import TestParallelExecutorBase
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.dataset.mnist as mnist
|
||||
import unittest
|
||||
import os
|
||||
|
||||
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
|
||||
|
||||
|
||||
def simple_fc_net(use_feed):
|
||||
if use_feed:
|
||||
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
|
||||
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
||||
else:
|
||||
reader = fluid.layers.open_files(
|
||||
filenames=[MNIST_RECORDIO_FILE],
|
||||
shapes=[[-1, 784], [-1, 1]],
|
||||
lod_levels=[0, 0],
|
||||
dtypes=['float32', 'int64'])
|
||||
reader = fluid.layers.io.double_buffer(reader)
|
||||
img, label = fluid.layers.read_file(reader)
|
||||
hidden = img
|
||||
for _ in range(4):
|
||||
hidden = fluid.layers.fc(
|
||||
hidden,
|
||||
size=200,
|
||||
act='relu',
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=1.0)))
|
||||
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
|
||||
loss = fluid.layers.cross_entropy(input=prediction, label=label)
|
||||
loss = fluid.layers.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
def fc_with_batchnorm(use_feed):
|
||||
if use_feed:
|
||||
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
|
||||
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
||||
else:
|
||||
reader = fluid.layers.open_files(
|
||||
filenames=[MNIST_RECORDIO_FILE],
|
||||
shapes=[[-1, 784], [-1, 1]],
|
||||
lod_levels=[0, 0],
|
||||
dtypes=['float32', 'int64'])
|
||||
reader = fluid.layers.io.double_buffer(reader)
|
||||
img, label = fluid.layers.read_file(reader)
|
||||
|
||||
hidden = img
|
||||
for _ in range(2):
|
||||
hidden = fluid.layers.fc(
|
||||
hidden,
|
||||
size=200,
|
||||
act='relu',
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=1.0)))
|
||||
|
||||
hidden = fluid.layers.batch_norm(input=hidden)
|
||||
|
||||
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
|
||||
loss = fluid.layers.cross_entropy(input=prediction, label=label)
|
||||
loss = fluid.layers.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class TestMNIST(TestParallelExecutorBase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
os.environ['CPU_NUM'] = str(4)
|
||||
# Convert mnist to recordio file
|
||||
with fluid.program_guard(fluid.Program(), fluid.Program()):
|
||||
reader = paddle.batch(mnist.train(), batch_size=4)
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[ # order is image and label
|
||||
fluid.layers.data(
|
||||
name='image', shape=[784]),
|
||||
fluid.layers.data(
|
||||
name='label', shape=[1], dtype='int64'),
|
||||
],
|
||||
place=fluid.CPUPlace())
|
||||
fluid.recordio_writer.convert_reader_to_recordio_file(
|
||||
MNIST_RECORDIO_FILE, reader, feeder)
|
||||
|
||||
def _init_data(self, random=True):
|
||||
np.random.seed(5)
|
||||
if random:
|
||||
img = np.random.random(size=[32, 784]).astype(np.float32)
|
||||
else:
|
||||
img = np.ones(shape=[32, 784], dtype='float32')
|
||||
label = np.ones(shape=[32, 1], dtype='int64')
|
||||
return img, label
|
||||
|
||||
def _compare_fuse_elewise_add_act_ops(self,
|
||||
model,
|
||||
use_cuda,
|
||||
random_data=True):
|
||||
if use_cuda and not core.is_compiled_with_cuda():
|
||||
return
|
||||
img, label = self._init_data(random_data)
|
||||
|
||||
def _optimizer(learning_rate=1e-6):
|
||||
optimizer = fluid.optimizer.SGD(
|
||||
learning_rate=learning_rate,
|
||||
regularization=fluid.regularizer.L2Decay(1e-6))
|
||||
return optimizer
|
||||
|
||||
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
|
||||
model,
|
||||
feed_dict={"image": img,
|
||||
"label": label},
|
||||
use_cuda=use_cuda,
|
||||
fuse_elewise_add_act_ops=False,
|
||||
memory_opt=False,
|
||||
optimizer=_optimizer)
|
||||
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
|
||||
model,
|
||||
feed_dict={"image": img,
|
||||
"label": label},
|
||||
use_cuda=use_cuda,
|
||||
fuse_elewise_add_act_ops=True,
|
||||
memory_opt=False,
|
||||
optimizer=_optimizer)
|
||||
|
||||
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
|
||||
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
|
||||
for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
|
||||
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
|
||||
|
||||
def test_simple_fc_with_fuse_op(self):
|
||||
self._compare_fuse_elewise_add_act_ops(simple_fc_net, True)
|
||||
self._compare_fuse_elewise_add_act_ops(simple_fc_net, False)
|
||||
|
||||
def test_batchnorm_fc_with_fuse_op(self):
|
||||
self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, True)
|
||||
self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue