Merge pull request #15455 from wzzju/graph_quantization

Graph quantization pass. TODO(Add public API comments.)
inference-pre-release-gpu
Zhen Wang 6 years ago committed by GitHub
commit 58727e8e6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
@ -243,3 +244,4 @@ USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(graph_to_program_pass);

@ -28,10 +28,14 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
attr);
}
auto* native_graph = graph.get();
auto applied_graph = ApplyImpl(std::move(graph));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
PADDLE_ENFORCE(applied_graph.get() == native_graph,
"Pass::Apply() cannot delete the passed graph and shouldn't "
"return a new graph.(For the need of pybind11)");
applied_ = true;
return applied_graph;
}

@ -15,7 +15,9 @@
#include "paddle/fluid/pybind/ir.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
@ -24,6 +26,7 @@
namespace py = pybind11;
using paddle::framework::ir::Graph;
using paddle::framework::ir::Node;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::VarDesc;
@ -32,6 +35,7 @@ using pybind11::return_value_policy;
namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
py::class_<Graph, std::shared_ptr<Graph>>(
*m, "Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
@ -42,6 +46,8 @@ void BindGraph(py::module *m) {
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_program", &Graph::Get<ProgramDesc>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); })
.def("set",
@ -57,6 +63,17 @@ void BindGraph(py::module *m) {
[](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const std::unordered_set<const Node *> &attr) {
return self.Set(attr_name,
new std::unordered_set<const Node *>(attr));
})
.def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference)
.def("create_var_node",
@ -85,12 +102,52 @@ void BindNode(py::module *m) {
py::class_<Node> node(*m, "Node");
node.def("name", &Node::Name)
.def("node_type", &Node::NodeType)
.def("var", &Node::Var)
.def("op", &Node::Op)
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("inputs_remove",
[](Node &self, int node_id) {
for (auto it = self.inputs.begin(); it != self.inputs.end();
it++) {
if ((*it)->id() == node_id) {
self.inputs.erase(it);
}
}
})
.def("inputs_remove",
[](Node &self, Node &node) {
for (auto it = self.inputs.begin(); it != self.inputs.end();
it++) {
if (*it == &node) {
self.inputs.erase(it);
}
}
})
.def("inputs_append",
[](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("outputs_remove",
[](Node &self, int node_id) {
for (auto it = self.outputs.begin(); it != self.outputs.end();
it++) {
if ((*it)->id() == node_id) {
self.outputs.erase(it);
}
}
})
.def("outputs_remove",
[](Node &self, Node &node) {
for (auto it = self.outputs.begin(); it != self.outputs.end();
it++) {
if (*it == &node) {
self.outputs.erase(it);
}
}
})
.def("outputs_append",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);

@ -228,7 +228,7 @@ void BindBlockDesc(pybind11::module *m) {
void BindVarDsec(pybind11::module *m) {
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
var_desc
var_desc.def(pybind11::init<const std::string &>())
.def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
.def("set_name", &pd::VarDesc::SetName)
.def("set_shape", &pd::VarDesc::SetShape)

@ -788,21 +788,33 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
m.def("get_pass", [](const py::bytes &binary_str) {
std::string pass_type(binary_str);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
});
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
.def("set",
[](ir::Pass &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def(
"set_str",
"set",
[](ir::Pass &self, const std::string &name, const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
})
.def("set_int", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("get_program", &ir::Pass::Get<ProgramDesc>)
.def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
std::unique_ptr<ir::Graph> origin_graph(graph.get());
auto optim_graph = self.Apply(std::move(origin_graph));
graph.reset(optim_graph.release());
optim_graph.release();
});
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(

@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import subprocess
from ....framework import Program
from ....framework import Block
from .... import core
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']

@ -0,0 +1,20 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import quantization_pass
from .quantization_pass import *
__all__ = quantization_pass.__all__

@ -0,0 +1,175 @@
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
from paddle.fluid.framework import Program
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid import core
def linear_fc(num):
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
hidden = fluid.layers.fc(hidden, size=128, act='relu')
loss = fluid.layers.cross_entropy(input=hidden, label=label)
loss = fluid.layers.mean(loss)
return loss
def residual_block(num):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestQuantizationTransformPass(unittest.TestCase):
def setUp(self):
self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'],
'mul': ['X', 'Y']
}
self.quantizable_grad_op_inputs = {
'conv2d_grad': ['Input', 'Filter'],
'depthwise_conv2d_grad': ['Input', 'Filter'],
'mul_grad': ['X', 'Y']
}
def check_program(self, transform_pass, program):
quantized_ops = set()
for block in program.blocks:
for op in block.ops:
# check forward
if op.type in self.quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
self.assertTrue(
arg_name.endswith('.quantized.dequantized'))
quantized_ops.add(arg_name)
for op in block.ops:
# check backward
if op.type in self.quantizable_grad_op_inputs:
for pname in self.quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
self.assertTrue(
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.linear_fc_quant('abs_max')
def test_linear_fc_quant_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.linear_fc_quant('range_abs_max')
def residual_block_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
def test_residual_block_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.residual_block_quant('abs_max')
def test_residual_block_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.residual_block_quant('range_abs_max')
if __name__ == '__main__':
unittest.main()

@ -23,6 +23,7 @@ import traceback
import six
import numpy as np
import subprocess
from .. import compat as cpt
from .proto import framework_pb2
@ -1512,6 +1513,154 @@ class Block(object):
return ret_var
class IrGraph(object):
"""
IrGraph uses core.Graph as the delegation to accomplish the manipulation.
"""
def __init__(self, graph, for_test=False):
"""
Construct the IrGraph using core.Graph.
Args:
graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph.
"""
assert isinstance(
graph, core.Graph), 'graph must be the instance of core.Graph.'
self.graph = graph
self._for_test = for_test
def is_test(self):
return self._for_test
def all_parameters(self):
param_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
param_nodes.add(node)
return param_nodes
def all_vars(self):
return {node for node in self.graph.nodes() if node.is_var()}
def all_ops(self):
return {node for node in self.graph.nodes() if node.is_op()}
def create_param_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
var_desc.set_persistable(True)
return self.graph.create_var_node(var_desc)
def create_var_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
return self.graph.create_var_node(var_desc)
def create_var_node_from_desc(self, var_desc):
return self.graph.create_var_node(var_desc)
def create_op_node(self, op_type, attrs, inputs, outputs):
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for attr, value in attrs.iteritems():
self._update_desc_attr(op_desc, attr, value)
for input_name, var_nodes in inputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_input(input_name,
[var_node.name() for var_node in var_nodes])
for output_name, var_nodes in outputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_output(output_name,
[var_node.name() for var_node in var_nodes])
return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc):
return self.graph.create_op_node(op_desc)
def update_input_link(self, old_input_node, new_input_node, op_node):
assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.'
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
op_node.inputs_append(new_input_node)
op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out):
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
'Th two arguments must be in the graph nodes.'
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
def safe_remove_nodes(self, remove_nodes):
if not isinstance(remove_nodes, set):
remove_nodes = set(remove_nodes)
core.graph_safe_remove_nodes(self.graph, remove_nodes)
def draw(self, save_path, name, marked_nodes=None):
def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
+ ' -o ' + pdf_save_path, shell=True)
if exited_code != 0:
print('The dot command is needed for creating pdf files.')
print('The {} is saved as the dot filetype.'.format(
dot_file_path))
remove_ctr_vars = set()
ops_num = 0
for node in self.graph.nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
elif node.is_op():
ops_num += 1
print('Total ops num = {}.'.format(ops_num))
self.safe_remove_nodes(remove_ctr_vars)
if marked_nodes is not None:
if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes)
marked_nodes = marked_nodes - remove_ctr_vars
if self.graph.has('__graphviz__marked_node__'):
self.graph.erase('__graphviz__marked_node__')
self.graph.set('__graphviz__marked_node__', marked_nodes)
viz_dot_path = os.path.join(save_path, name) + '.dot'
viz_pass = core.get_pass('graph_viz_pass')
viz_pass.set('graph_viz_path', viz_dot_path)
viz_pass.apply(self.graph)
_convert_to_pdf(viz_dot_path)
def to_program(self):
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc)
convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program._construct_from_desc(desc)
return program
def _update_desc_attr(self, desc, name, val):
"""
Update the value of desc's attribute by attribute's name.
"""
if isinstance(val, Block):
desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and all(
isinstance(v, Block) for v in val):
desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
desc.set_serialized_attr(name, val.serialize_to_string())
else:
desc._set_attr(name, val)
class Program(object):
"""
Python Program. Beneath it is a ProgramDesc, which is used for
@ -1936,6 +2085,23 @@ class Program(object):
p._sync_with_cpp()
return p
@staticmethod
def _construct_from_desc(desc):
"""
Construct a program from program desc.
Args:
desc(core.ProgramDesc): The program desc for constructing.
Returns:
Program: A program.
"""
p = Program()
p.desc = desc
p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())]
p._sync_with_cpp()
return p
@property
def random_seed(self):
"""

@ -123,7 +123,7 @@ class TestDistRunnerBase(object):
pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
mypass.set("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2":
build_stra.num_trainers = len(args.endpoints.split(","))

@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase):
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass")
viz_pass.set("graph_viz_path", "/tmp/test_viz_pass")
self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(),

@ -113,6 +113,7 @@ packages=['paddle',
'paddle.fluid.contrib.slim.core',
'paddle.fluid.contrib.slim.graph',
'paddle.fluid.contrib.slim.prune',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.utils',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']

Loading…
Cancel
Save