Memory optimization of depthwise conv op and group norm op (#15313)
* mem opt
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* refine code test=develop
* refine code test=develop
* refine code test=develop
* refine code test=develop
* refine with cub test=develop
* fix mkldnn test && remove comments && test=develop
* polish code && test=develop
* add only_forward test && test=develop
parent 9252aa41f5
commit 9f8f0fc2d3
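Before the diffs, a quick usage sketch: the fusion is driven by a build-strategy switch, which the test at the end of this PR toggles through `check_network_convergence(..., fuse_relu_depthwise_conv=...)`. The snippet below is a minimal, illustrative sketch assuming the `fuse_relu_depthwise_conv` flag on `fluid.BuildStrategy` that this PR wires up; the tiny network exists only to contain the relu -> depthwise_conv2d pattern and is not part of the PR.

```python
import paddle.fluid as fluid

# Illustrative network containing the pattern the pass matches:
# relu followed by depthwise_conv2d. fluid.layers.conv2d lowers to the
# depthwise_conv2d op when groups == input channels and use_cudnn=False,
# which is also why the test below builds its convs that way.
img = fluid.layers.data(name='image', shape=[4, 28, 28], dtype='float32')
x = fluid.layers.relu(img)
conv = fluid.layers.conv2d(
    x, num_filters=4, filter_size=3, groups=4,
    use_cudnn=False, bias_attr=False)
loss = fluid.layers.mean(conv)

build_strategy = fluid.BuildStrategy()
# Assumption: this flag is what triggers fuse_relu_depthwise_conv_pass
# during graph compilation.
build_strategy.fuse_relu_depthwise_conv = True

pe = fluid.ParallelExecutor(
    use_cuda=True, loss_name=loss.name, build_strategy=build_strategy)
```

With the flag on, each matched relu + depthwise_conv2d pair is rewritten into a single depthwise_conv2d carrying `fuse_relu_before_depthwise_conv=true`, as the pass below shows.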
paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.cc
@@ -0,0 +1,159 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
#include <algorithm>
#include <memory>         // for std::unique_ptr
#include <string>
#include <unordered_set>  // for need_removed_nodes below
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  // Run twice: first fuse forward-only occurrences, then occurrences that
  // also have a matching backward subgraph.
  graph = FuseReluDepthwiseConv(std::move(graph), true);
  graph = FuseReluDepthwiseConv(std::move(graph), false);
  return graph;
}

std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
    std::unique_ptr<ir::Graph> graph, bool only_forward) const {
  PADDLE_ENFORCE(graph.get());
  if (only_forward)
    FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
  else
    FusePassBase::Init("relu_depthwise_conv", graph.get());
  /*
           x ---act--> y ---layer-> z
            +----------+
            ↓          ↓
    x' <--act'--- y' <-layer'--- z'

    fuse to:

           x ---act-layer-> z
           |
           ↓
    x' <--act-layer'--- z'
  */

  GraphPatternDetector gpd;
  auto *pattern = gpd.mutable_pattern();
  std::string act_type = "relu";
  std::string layer_type = "depthwise_conv2d";
  auto *x = pattern->NewNode("x")->AsInput();
  auto *y = pattern->NewNode("y")->AsIntermediate();
  auto *z = pattern->NewNode("z")->AsOutput();
  PDNode *xg = nullptr;
  PDNode *yg = nullptr;
  PDNode *zg = nullptr;
  if (!only_forward) {
    xg = pattern->NewNode("xg")->AsOutput();
    yg = pattern->NewNode("yg")->AsIntermediate();
    zg = pattern->NewNode("zg")->AsInput();
  }

  PDNode *act_g = nullptr;
  PDNode *layer_g = nullptr;
  auto *act = pattern->NewNode("act")->assert_is_op(act_type);
  auto *layer = pattern->NewNode("layer")->assert_is_op(layer_type);
  if (!only_forward) {
    act_g = pattern->NewNode("act_g")->assert_is_op(act_type + "_grad");
    layer_g = pattern->NewNode("layer_g")->assert_is_op(layer_type + "_grad");
  }

  act->LinksFrom({x}).LinksTo({y});
  layer->LinksFrom({y}).LinksTo({z});
  if (!only_forward) {
    layer_g->LinksFrom({y, zg}).LinksTo({yg});
    act_g->LinksFrom({y, yg}).LinksTo({xg});
  }

  int count = 0;
  std::unordered_set<const Node *> need_removed_nodes;

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "handle FuseReluDepthwiseConv fuse";
    // 1. turn on fuse option
    auto *layer_op = subgraph.at(layer)->Op();
    layer_op->SetAttr("use_cudnn", false);
    layer_op->SetAttr("fuse_relu_before_depthwise_conv", true);

    OpDesc *layer_g_op = nullptr;
    if (!only_forward) {
      layer_g_op = subgraph.at(layer_g)->Op();
      layer_g_op->SetAttr("use_cudnn", false);
      layer_g_op->SetAttr("fuse_relu_before_depthwise_conv", true);
    }
    // 2. connect x to layer and layer_g, layer_g to xg
    auto *y_var = subgraph.at(y)->Var();
    auto *x_var = subgraph.at(x)->Var();
    VarDesc *yg_var = nullptr;
    VarDesc *xg_var = nullptr;
    if (!only_forward) {
      yg_var = subgraph.at(yg)->Var();
      xg_var = subgraph.at(xg)->Var();
    }

    PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1);
    PADDLE_ENFORCE_EQ(layer_op->Input("Input")[0], y_var->Name());
    layer_op->SetInput("Input", {x_var->Name()});
    subgraph.at(layer)->inputs.push_back(subgraph.at(x));
    subgraph.at(x)->outputs.push_back(subgraph.at(layer));
    VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name();

    if (!only_forward) {
      PADDLE_ENFORCE_EQ(layer_g_op->Input("Input").size(), 1);
      PADDLE_ENFORCE_EQ(layer_g_op->Input("Input")[0], y_var->Name());
      layer_g_op->SetInput("Input", {x_var->Name()});
      subgraph.at(layer_g)->inputs.push_back(subgraph.at(x));
      subgraph.at(x)->outputs.push_back(subgraph.at(layer_g));

      PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input")).size(), 1);
      PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input"))[0],
                        yg_var->Name());
      layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()});
      subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg));
      subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g));
      VLOG(4) << "replace " << yg_var->Name() << " -> " << xg_var->Name();
    }

    // 3. delete y, yg, act, act_g
    if (only_forward) {
      need_removed_nodes.insert({subgraph.at(y), subgraph.at(act)});
    } else {
      need_removed_nodes.insert({subgraph.at(y), subgraph.at(yg),
                                 subgraph.at(act), subgraph.at(act_g)});
    }
    count++;
  };
  gpd(graph.get(), handler);
  GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
  AddStatis(count);

  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fuse_relu_depthwise_conv_pass,
              paddle::framework::ir::FuseReluDepthwiseConvPass);
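Why the rewrite above is legal, and where the memory saving comes from: relu is y = max(x, 0), so the fused kernel can apply the clamp while reading x and never store y; in the backward pass relu's gradient mask 1[x > 0] is recomputable from x alone, so yg can be folded away as well. Below is a toy NumPy sketch of that identity, using a 1-D stand-in for depthwise_conv2d; the helper names and shapes are illustrative, not Paddle APIs.

```python
import numpy as np

def depthwise_conv1d(x, w):
    # Toy per-channel 1-D convolution (valid padding): channel c of the
    # output depends only on channel c of x, like depthwise_conv2d.
    C, L = x.shape
    K = w.shape[1]
    out = np.empty((C, L - K + 1))
    for c in range(C):
        for i in range(L - K + 1):
            out[c, i] = np.dot(x[c, i:i + K], w[c])
    return out

def fused_relu_depthwise_conv1d(x, w):
    # Same loop, but max(x, 0) is applied on the fly inside the kernel,
    # so the relu output y never exists as a separate tensor.
    C, L = x.shape
    K = w.shape[1]
    out = np.empty((C, L - K + 1))
    for c in range(C):
        for i in range(L - K + 1):
            out[c, i] = np.dot(np.maximum(x[c, i:i + K], 0), w[c])
    return out

rng = np.random.RandomState(0)
x = rng.randn(3, 16)   # the only tensor the fused graph keeps alive
w = rng.randn(3, 5)    # one small filter per channel

y = np.maximum(x, 0)   # unfused path: y is materialized
assert np.allclose(depthwise_conv1d(y, w),
                   fused_relu_depthwise_conv1d(x, w))

# Backward: relu' = 1[x > 0] is recomputable from x, so the conv grad can
# emit x's gradient directly instead of y's (the xg <- yg edge rewiring
# performed by the pass above).
yg = rng.randn(*x.shape)   # stand-in gradient w.r.t. the conv input
xg = yg * (x > 0)          # act_grad folded into the fused backward kernel
```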
paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h
@@ -0,0 +1,42 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>  // for std::unique_ptr
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Fuse the relu and depthwise conv.
 */
class FuseReluDepthwiseConvPass : public FusePassBase {
 public:
  virtual ~FuseReluDepthwiseConvPass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
  std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
      std::unique_ptr<ir::Graph> graph, bool only_forward) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
(Two further file diffs are suppressed because they are too large to display.)
@@ -0,0 +1,149 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os

MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"


def norm(*args, **kwargs):
    return fluid.layers.batch_norm(*args, **kwargs)


def sep_conv(input, channel, stride, filter, dilation=1, act=None):
    # depthwise stage: groups == input channels plus use_cudnn=False make
    # fluid.layers.conv2d lower to the depthwise_conv2d op
    input = fluid.layers.conv2d(
        input,
        input.shape[1],
        filter,
        stride,
        groups=input.shape[1],
        padding=(filter // 2) * dilation,
        dilation=dilation,
        use_cudnn=False,
        bias_attr=False)
    input = norm(input)
    if act:
        input = act(input)
    # pointwise stage: 1x1 convolution across channels
    input = fluid.layers.conv2d(
        input, channel, 1, 1, groups=1, padding=0, bias_attr=False)
    input = norm(input)
    if act:
        input = act(input)
    return input


def simple_depthwise_net(use_feed):
    if use_feed:
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    else:
        reader = fluid.layers.open_files(
            filenames=[MNIST_RECORDIO_FILE],
            shapes=[[-1, 784], [-1, 1]],
            lod_levels=[0, 0],
            dtypes=['float32', 'int64'])
        reader = fluid.layers.io.double_buffer(reader)
        img, label = fluid.layers.read_file(reader)
    hidden = fluid.layers.reshape(img, (-1, 1, 28, 28))
    for _ in range(4):
        # each iteration ends with a relu feeding the next iteration's
        # depthwise conv: the relu -> depthwise_conv2d pattern the pass fuses
        hidden = sep_conv(hidden, channel=200, stride=2, filter=5)
        hidden = fluid.layers.relu(hidden)
    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    loss = fluid.layers.mean(loss)
    return loss


class TestMNIST(TestParallelExecutorBase):
    @classmethod
    def setUpClass(cls):
        os.environ['CPU_NUM'] = str(4)
        # Convert mnist to recordio file
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            reader = paddle.batch(mnist.train(), batch_size=4)
            feeder = fluid.DataFeeder(
                feed_list=[  # order is image and label
                    fluid.layers.data(
                        name='image', shape=[784]),
                    fluid.layers.data(
                        name='label', shape=[1], dtype='int64'),
                ],
                place=fluid.CPUPlace())
            fluid.recordio_writer.convert_reader_to_recordio_file(
                MNIST_RECORDIO_FILE, reader, feeder)

    def _init_data(self, random=True):
        np.random.seed(5)
        if random:
            img = np.random.random(size=[32, 784]).astype(np.float32)
        else:
            img = np.ones(shape=[32, 784], dtype='float32')
        label = np.ones(shape=[32, 1], dtype='int64')
        return img, label

    def _compare(self, model, use_cuda, random_data=True, only_forward=False):
        if use_cuda and not core.is_compiled_with_cuda():
            return
        img, label = self._init_data(random_data)

        def _optimizer(learning_rate=1e-6):
            optimizer = fluid.optimizer.SGD(
                learning_rate=learning_rate,
                regularization=fluid.regularizer.L2Decay(1e-6))
            return optimizer

        if only_forward:
            _optimizer = None

        # Run the same model with and without the fuse pass; the losses must
        # match, since the fusion is a pure graph rewrite.
        fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
            model,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            fuse_relu_depthwise_conv=True,
            use_ir_memory_optimize=True,
            memory_opt=False,
            optimizer=_optimizer)
        not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
            model,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            fuse_relu_depthwise_conv=False,
            memory_opt=False,
            optimizer=_optimizer)

        for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
        for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)

    def test_simple_depthwise_with_fuse_op(self):
        self._compare(simple_depthwise_net, True)
        self._compare(simple_depthwise_net, False)

    def test_simple_depthwise_with_fuse_op_only_forward(self):
        self._compare(simple_depthwise_net, True, only_forward=True)
        self._compare(simple_depthwise_net, False, only_forward=True)


if __name__ == '__main__':
    unittest.main()