# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..core.strategy import Strategy
from ..graph import VarWrapper, OpWrapper, GraphWrapper
from ....framework import Program, program_guard, Parameter
from .... import layers
import prettytable as pt
import numpy as np
from scipy.optimize import leastsq
import copy
import re
import os
import pickle
import logging
import sys

__all__ = ['SensitivePruneStrategy', 'UniformPruneStrategy', 'PruneStrategy']

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class PruneStrategy(Strategy):
    """
    The base class of all pruning strategies.
    """

    def __init__(self,
                 pruner=None,
                 start_epoch=0,
                 end_epoch=0,
                 target_ratio=0.5,
                 metric_name=None,
                 pruned_params='conv.*_weights'):
        """
        Args:
            pruner(slim.Pruner): The pruner used to prune the parameters.
            start_epoch(int): The epoch in which the 'on_epoch_begin' function is called. default: 0
            end_epoch(int): The epoch in which the 'on_epoch_end' function is called. default: 0
            target_ratio(float): The ratio of FLOPs to be pruned from the current model.
            metric_name(str): The metric used to evaluate the model.
                              It should be one of the keys in the out_nodes of the graph wrapper.
            pruned_params(str): The pattern str used to match the names of the parameters to be pruned.
        """
        super(PruneStrategy, self).__init__(start_epoch, end_epoch)
        self.pruner = pruner
        self.target_ratio = target_ratio
        self.metric_name = metric_name
        self.pruned_params = pruned_params
        self.pruned_list = []

    def _eval_graph(self, context, sampled_rate=None, cached_id=0):
        """
        Evaluate the current model in context.
        Args:
            context(slim.core.Context): The context storing all the information used to evaluate the current model.
            sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all the data in eval_reader. default: None.
            cached_id(int): The id of the sampled dataset. Evaluations with the same cached_id use the same sampled dataset. default: 0.
        """
        results, names = context.run_eval_graph(sampled_rate, cached_id)
        metric = np.mean(results[list(names).index(self.metric_name)])
        return metric

    def _prune_filters_by_ratio(self,
                                scope,
                                params,
                                ratio,
                                place,
                                lazy=False,
                                only_graph=False,
                                param_shape_backup=None,
                                param_backup=None):
        """
        Pruning filters by the given ratio.
        Args:
            scope(fluid.core.Scope): The scope used to prune the filters.
            params(list<VarWrapper>): A list of filter parameters.
            ratio(float): The ratio to be pruned.
            place(fluid.Place): The device place of the filter parameters.
            lazy(bool): True means setting the pruned elements to zero.
                        False means cutting down the pruned elements.
            only_graph(bool): True means only modifying the graph.
                              False means modifying the graph and the variables in the scope.
            param_shape_backup(dict): A dict used to back up the original shapes of the pruned parameters. Optional.
            param_backup(dict): A dict used to back up the original values of the pruned parameters. Optional.
        """
        if params[0].name() in self.pruned_list[0]:
            return
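        # The channel indices to prune are computed once from the first
        # (master) parameter and then applied to every parameter in `params`,
        # which typically holds the filter weight plus its optimizer
        # accumulators.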
        param_t = scope.find_var(params[0].name()).get_tensor()
        pruned_idx = self.pruner.cal_pruned_idx(
            params[0].name(), np.array(param_t), ratio, axis=0)
        for param in params:
            assert isinstance(param, VarWrapper)
            param_t = scope.find_var(param.name()).get_tensor()
            if param_backup is not None and (param.name() not in param_backup):
                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
            pruned_param = self.pruner.prune_tensor(
                np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
            if not only_graph:
                param_t.set(pruned_param, place)
            ori_shape = param.shape()
            if param_shape_backup is not None and (
                    param.name() not in param_shape_backup):
                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
            new_shape = list(param.shape())
            new_shape[0] = pruned_param.shape[0]
            param.set_shape(new_shape)
            _logger.debug(
                '|----------------------------------------+----+------------------------------+------------------------------|'
            )
            _logger.debug('|{:^40}|{:^4}|{:^30}|{:^30}|'.format(
                str(param.name()),
                str(ratio), str(ori_shape), str(param.shape())))
            self.pruned_list[0].append(param.name())
        return pruned_idx

    def _prune_parameter_by_idx(self,
                                scope,
                                params,
                                pruned_idx,
                                pruned_axis,
                                place,
                                lazy=False,
                                only_graph=False,
                                param_shape_backup=None,
                                param_backup=None):
        """
        Pruning parameters along the given axis.
        Args:
            scope(fluid.core.Scope): The scope storing the parameters to be pruned.
            params(list<VarWrapper>): The parameters to be pruned.
            pruned_idx(list): The indices of the elements to be pruned.
            pruned_axis(int): The pruning axis.
            place(fluid.Place): The device place of the parameters.
            lazy(bool): True means setting the pruned elements to zero.
                        False means cutting down the pruned elements.
            only_graph(bool): True means only modifying the graph.
                              False means modifying the graph and the variables in the scope.
            param_shape_backup(dict): A dict used to back up the original shapes of the pruned parameters. Optional.
            param_backup(dict): A dict used to back up the original values of the pruned parameters. Optional.
        """
        if params[0].name() in self.pruned_list[pruned_axis]:
            return
        for param in params:
            assert isinstance(param, VarWrapper)
            param_t = scope.find_var(param.name()).get_tensor()
            if param_backup is not None and (param.name() not in param_backup):
                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
            pruned_param = self.pruner.prune_tensor(
                np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
            if not only_graph:
                param_t.set(pruned_param, place)
            ori_shape = param.shape()

            if param_shape_backup is not None and (
                    param.name() not in param_shape_backup):
                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
            new_shape = list(param.shape())
            new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
            param.set_shape(new_shape)
            _logger.debug(
                '|----------------------------------------+----+------------------------------+------------------------------|'
            )
            _logger.debug('|{:^40}|{:^4}|{:^30}|{:^30}|'.format(
                str(param.name()),
                str(pruned_axis), str(ori_shape), str(param.shape())))
            self.pruned_list[pruned_axis].append(param.name())

    def _forward_search_related_op(self, graph, param):
        """
        Forward search the operators that will be affected by the pruning of param.
        Args:
            graph(GraphWrapper): The graph to be searched.
            param(VarWrapper): The current pruned parameter.
        Returns:
            list<OpWrapper>: A list of operators.
        """
        assert isinstance(param, VarWrapper)
        visited = {}
        for op in graph.ops():
            visited[op.idx()] = False
        stack = []
        for op in graph.ops():
            if (not op.is_bwd_op()) and (param in op.all_inputs()):
                stack.append(op)
        visit_path = []
        while len(stack) > 0:
            top_op = stack[len(stack) - 1]
            if not visited[top_op.idx()]:
                visit_path.append(top_op)
                visited[top_op.idx()] = True
            next_ops = None
            if top_op.type() == "conv2d" and param not in top_op.all_inputs():
                next_ops = None
            elif top_op.type() == "mul":
                next_ops = None
            else:
                next_ops = self._get_next_unvisited_op(graph, visited, top_op)
            if next_ops is None:
                stack.pop()
            else:
                stack += next_ops
        return visit_path

    def _get_next_unvisited_op(self, graph, visited, top_op):
        """
        Get the next unvisited adjacent operators of the given operator.
        Args:
            graph(GraphWrapper): The graph used to search.
            visited(list): The ids of the operators that have been visited.
            top_op: The given operator.
        Returns:
            list<OpWrapper>: A list of operators.
        """
        assert isinstance(top_op, OpWrapper)
        next_ops = []
        for op in graph.next_ops(top_op):
            if (not visited[op.idx()]) and (not op.is_bwd_op()):
                next_ops.append(op)
        return next_ops if len(next_ops) > 0 else None

    def _get_accumulator(self, graph, param):
        """
        Get the accumulators of the given parameter. The accumulators are created by the optimizer.
        Args:
            graph(GraphWrapper): The graph used to search.
            param(VarWrapper): The given parameter.
        Returns:
            list<VarWrapper>: A list of accumulators, which are variables.
        """
        assert isinstance(param, VarWrapper)
        params = []
        for op in param.outputs():
            if op.is_opt_op():
                for out_var in op.all_outputs():
                    if graph.is_persistable(out_var) and out_var.name(
                    ) != param.name():
                        params.append(out_var)
        return params

    def _forward_pruning_ralated_params(self,
                                        graph,
                                        scope,
                                        param,
                                        place,
                                        ratio=None,
                                        pruned_idxs=None,
                                        lazy=False,
                                        only_graph=False,
                                        param_backup=None,
                                        param_shape_backup=None):
        """
        Pruning all the parameters affected by the pruning of the given parameter.
        Args:
            graph(GraphWrapper): The graph to be searched.
            scope(fluid.core.Scope): The scope storing the parameters to be pruned.
            param(VarWrapper): The given parameter.
            place(fluid.Place): The device place of the parameters.
            ratio(float): The target ratio to be pruned.
            pruned_idxs(list): The indices of the elements to be pruned.
            lazy(bool): True means setting the pruned elements to zero.
                        False means cutting down the pruned elements.
            only_graph(bool): True means only modifying the graph.
                              False means modifying the graph and the variables in the scope.
        """
        assert isinstance(
            graph,
            GraphWrapper), "graph must be instance of slim.core.GraphWrapper"
        assert isinstance(
            param, VarWrapper), "param must be instance of slim.core.VarWrapper"

        if param.name() in self.pruned_list[0]:
            return

        related_ops = self._forward_search_related_op(graph, param)

        if ratio is None:
            assert pruned_idxs is not None
            self._prune_parameter_by_idx(
                scope, [param] + self._get_accumulator(graph, param),
                pruned_idxs,
                pruned_axis=0,
                place=place,
                lazy=lazy,
                only_graph=only_graph,
                param_backup=param_backup,
                param_shape_backup=param_shape_backup)
        else:
            pruned_idxs = self._prune_filters_by_ratio(
                scope, [param] + self._get_accumulator(graph, param),
                ratio,
                place,
                lazy=lazy,
                only_graph=only_graph,
                param_backup=param_backup,
                param_shape_backup=param_shape_backup)
        corrected_idxs = pruned_idxs[:]

        for idx, op in enumerate(related_ops):
            if op.type() == "conv2d" and (param not in op.all_inputs()):
                for in_var in op.all_inputs():
                    if graph.is_parameter(in_var):
                        conv_param = in_var
                self._prune_parameter_by_idx(
                    scope, [conv_param] + self._get_accumulator(
                        graph, conv_param),
                    corrected_idxs,
                    pruned_axis=1,
                    place=place,
                    lazy=lazy,
                    only_graph=only_graph,
                    param_backup=param_backup,
                    param_shape_backup=param_shape_backup)
            if op.type() == "depthwise_conv2d":
                for in_var in op.all_inputs():
                    if graph.is_parameter(in_var):
                        conv_param = in_var
                self._prune_parameter_by_idx(
                    scope, [conv_param] + self._get_accumulator(
                        graph, conv_param),
                    corrected_idxs,
                    pruned_axis=0,
                    place=place,
                    lazy=lazy,
                    only_graph=only_graph,
                    param_backup=param_backup,
                    param_shape_backup=param_shape_backup)
elif op.type() == "elementwise_add":
|
|
# pruning bias
|
|
for in_var in op.all_inputs():
|
|
if graph.is_parameter(in_var):
|
|
bias_param = in_var
|
|
self._prune_parameter_by_idx(
|
|
scope, [bias_param] + self._get_accumulator(
|
|
graph, bias_param),
|
|
pruned_idxs,
|
|
pruned_axis=0,
|
|
place=place,
|
|
lazy=lazy,
|
|
only_graph=only_graph,
|
|
param_backup=param_backup,
|
|
param_shape_backup=param_shape_backup)
|
|
elif op.type() == "mul": # pruning fc layer
|
|
fc_input = None
|
|
fc_param = None
|
|
for in_var in op.all_inputs():
|
|
if graph.is_parameter(in_var):
|
|
fc_param = in_var
|
|
else:
|
|
fc_input = in_var
|
|
|
|
idx = []
|
|
feature_map_size = fc_input.shape()[2] * fc_input.shape()[3]
|
|
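                # The fc (mul) weight's rows are laid out channel by channel:
                # each pruned channel i of the preceding feature map maps to
                # rows [i * feature_map_size, (i + 1) * feature_map_size) of
                # the fc weight. For example, with feature_map_size = 4,
                # pruning channel 2 removes rows 8..11.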
                range_idx = np.array(range(feature_map_size))
                for i in corrected_idxs:
                    idx += list(range_idx + i * feature_map_size)
                corrected_idxs = idx
                self._prune_parameter_by_idx(
                    scope, [fc_param] + self._get_accumulator(graph, fc_param),
                    corrected_idxs,
                    pruned_axis=0,
                    place=place,
                    lazy=lazy,
                    only_graph=only_graph,
                    param_backup=param_backup,
                    param_shape_backup=param_shape_backup)

elif op.type() == "concat":
|
|
concat_inputs = op.all_inputs()
|
|
last_op = related_ops[idx - 1]
|
|
for out_var in last_op.all_outputs():
|
|
if out_var in concat_inputs:
|
|
concat_idx = concat_inputs.index(out_var)
|
|
offset = 0
|
|
for ci in range(concat_idx):
|
|
offset += concat_inputs[ci].shape()[1]
|
|
corrected_idxs = [x + offset for x in pruned_idxs]
|
|
elif op.type() == "batch_norm":
|
|
bn_inputs = op.all_inputs()
|
|
mean = bn_inputs[2]
|
|
variance = bn_inputs[3]
|
|
alpha = bn_inputs[0]
|
|
beta = bn_inputs[1]
|
|
self._prune_parameter_by_idx(
|
|
scope, [mean] + self._get_accumulator(graph, mean),
|
|
corrected_idxs,
|
|
pruned_axis=0,
|
|
place=place,
|
|
lazy=lazy,
|
|
only_graph=only_graph,
|
|
param_backup=param_backup,
|
|
param_shape_backup=param_shape_backup)
|
|
self._prune_parameter_by_idx(
|
|
scope, [variance] + self._get_accumulator(graph, variance),
|
|
corrected_idxs,
|
|
pruned_axis=0,
|
|
place=place,
|
|
lazy=lazy,
|
|
only_graph=only_graph,
|
|
param_backup=param_backup,
|
|
param_shape_backup=param_shape_backup)
|
|
self._prune_parameter_by_idx(
|
|
scope, [alpha] + self._get_accumulator(graph, alpha),
|
|
corrected_idxs,
|
|
pruned_axis=0,
|
|
place=place,
|
|
lazy=lazy,
|
|
only_graph=only_graph,
|
|
param_backup=param_backup,
|
|
param_shape_backup=param_shape_backup)
|
|
self._prune_parameter_by_idx(
|
|
scope, [beta] + self._get_accumulator(graph, beta),
|
|
corrected_idxs,
|
|
pruned_axis=0,
|
|
place=place,
|
|
lazy=lazy,
|
|
only_graph=only_graph,
|
|
param_backup=param_backup,
|
|
param_shape_backup=param_shape_backup)
|
|
|
|

    def _prune_parameters(self,
                          graph,
                          scope,
                          params,
                          ratios,
                          place,
                          lazy=False,
                          only_graph=False,
                          param_backup=None,
                          param_shape_backup=None):
        """
        Pruning the given parameters.
        Args:
            graph(GraphWrapper): The graph to be searched.
            scope(fluid.core.Scope): The scope storing the parameters to be pruned.
            params(list<str>): A list of names of the parameters to be pruned.
            ratios(list<float>): A list of ratios used to prune the parameters.
            place(fluid.Place): The device place of the parameters.
            lazy(bool): True means setting the pruned elements to zero.
                        False means cutting down the pruned elements.
            only_graph(bool): True means only modifying the graph.
                              False means modifying the graph and the variables in the scope.
        """
        _logger.debug('\n################################')
        _logger.debug('# pruning parameters #')
        _logger.debug('################################\n')
        _logger.debug(
            '|----------------------------------------+----+------------------------------+------------------------------|'
        )
        _logger.debug('|{:^40}|{:^4}|{:^30}|{:^30}|'.format('parameter', 'axis',
                                                            'from', 'to'))
        assert len(params) == len(ratios)
        self.pruned_list = [[], []]
        for param, ratio in zip(params, ratios):
            assert isinstance(param, str) or isinstance(param, unicode)
            param = graph.var(param)
            self._forward_pruning_ralated_params(
                graph,
                scope,
                param,
                place,
                ratio=ratio,
                lazy=lazy,
                only_graph=only_graph,
                param_backup=param_backup,
                param_shape_backup=param_shape_backup)
            ops = param.outputs()
            for op in ops:
                if op.type() == 'conv2d':
                    brother_ops = self._search_brother_ops(graph, op)
                    for brother in brother_ops:
                        for p in graph.get_param_by_op(brother):
                            self._forward_pruning_ralated_params(
                                graph,
                                scope,
                                p,
                                place,
                                ratio=ratio,
                                lazy=lazy,
                                only_graph=only_graph,
                                param_backup=param_backup,
                                param_shape_backup=param_shape_backup)
        _logger.debug(
            '|----------------------------------------+----+------------------------------+------------------------------|'
        )

    def _search_brother_ops(self, graph, op_node):
        """
        Search the brother operators that are affected by the pruning of the given operator.
        Args:
            graph(GraphWrapper): The graph to be searched.
            op_node(OpWrapper): The start node for searching.
        Returns:
            list<OpWrapper>: A list of operators.
        """
        visited = [op_node.idx()]
        stack = []
        brothers = []
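        # "Brother" operators are other conv2d/fc ops whose outputs are merged
        # with the output of op_node (for example through an elementwise_add
        # in a residual block); they must be pruned to the same number of
        # output channels to keep the merge shape-consistent.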
        for op in graph.next_ops(op_node):
            if (op.type() != 'conv2d') and (op.type() != 'fc') and (
                    not op._is_bwd_op()):
                stack.append(op)
                visited.append(op.idx())
        while len(stack) > 0:
            top_op = stack.pop()
            for parent in graph.pre_ops(top_op):
                if parent.idx() not in visited and (not parent._is_bwd_op()):
                    if ((parent.type() == 'conv2d') or
                        (parent.type() == 'fc')):
                        brothers.append(parent)
                    else:
                        stack.append(parent)
                    visited.append(parent.idx())

            for child in graph.next_ops(top_op):
                if (child.type() != 'conv2d') and (child.type() != 'fc') and (
                        child.idx() not in visited) and (
                            not child._is_bwd_op()):
                    stack.append(child)
                    visited.append(child.idx())
        return brothers

    def _prune_graph(self, graph, target_graph):
        """
        Prune the parameters of the graph according to the target graph.
        Args:
            graph(GraphWrapper): The graph to be pruned.
            target_graph(GraphWrapper): The reference graph.
        Return: None
        """
        count = 1
        _logger.debug(
            '|----+----------------------------------------+------------------------------+------------------------------|'
        )
        _logger.debug('|{:^4}|{:^40}|{:^30}|{:^30}|'.format('id', 'parameter',
                                                            'from', 'to'))
        for param in target_graph.all_parameters():
            var = graph.var(param.name())
            ori_shape = var.shape()
            var.set_shape(param.shape())
            _logger.debug(
                '|----+----------------------------------------+------------------------------+------------------------------|'
            )
            _logger.debug('|{:^4}|{:^40}|{:^30}|{:^30}|'.format(
                str(count),
                str(param.name()), str(ori_shape), str(param.shape())))
            count += 1
        _logger.debug(
            '|----+----------------------------------------+------------------------------+------------------------------|'
        )


class UniformPruneStrategy(PruneStrategy):
    """
    The uniform pruning strategy. All the parameters are pruned with the same ratio.
    """
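
    # A minimal construction sketch. The surrounding compression pipeline
    # (the object owning the graphs, scope and place that drives
    # on_epoch_begin) and the concrete `pruner` instance are assumptions and
    # not shown here:
    #
    #     strategy = UniformPruneStrategy(
    #         pruner=pruner,
    #         start_epoch=0,
    #         end_epoch=0,
    #         target_ratio=0.5,
    #         metric_name='top1_acc',
    #         pruned_params='conv.*_weights')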

    def __init__(self,
                 pruner=None,
                 start_epoch=0,
                 end_epoch=0,
                 target_ratio=0.5,
                 metric_name=None,
                 pruned_params='conv.*_weights'):
        """
        Args:
            pruner(slim.Pruner): The pruner used to prune the parameters.
            start_epoch(int): The epoch in which the 'on_epoch_begin' function is called. default: 0
            end_epoch(int): The epoch in which the 'on_epoch_end' function is called. default: 0
            target_ratio(float): The ratio of FLOPs to be pruned from the current model.
            metric_name(str): The metric used to evaluate the model.
                              It should be one of the keys in the out_nodes of the graph wrapper.
            pruned_params(str): The pattern str used to match the names of the parameters to be pruned.
        """
        super(UniformPruneStrategy, self).__init__(pruner, start_epoch,
                                                   end_epoch, target_ratio,
                                                   metric_name, pruned_params)

    def _get_best_ratios(self, context):
        """
        Search a group of ratios for pruning the target FLOPs.
        """
        _logger.info('_get_best_ratios')
        pruned_params = []
        for param in context.eval_graph.all_parameters():
            if re.match(self.pruned_params, param.name()):
                pruned_params.append(param.name())

        min_ratio = 0.
        max_ratio = 1.

        flops = context.eval_graph.flops()
        model_size = context.eval_graph.numel_params()
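
        # Binary search on a single uniform ratio: prune only the graph shapes
        # (only_graph=True), measure the resulting FLOPs reduction, and narrow
        # [min_ratio, max_ratio] until it is within 1e-2 of target_ratio.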
        while min_ratio < max_ratio:
            ratio = (max_ratio + min_ratio) / 2
            _logger.debug(
                '-----------Try pruning ratio: {:.2f}-----------'.format(ratio))
            ratios = [ratio] * len(pruned_params)
            param_shape_backup = {}
            self._prune_parameters(
                context.eval_graph,
                context.scope,
                pruned_params,
                ratios,
                context.place,
                only_graph=True,
                param_shape_backup=param_shape_backup)

            pruned_flops = 1 - (float(context.eval_graph.flops()) / flops)
            pruned_size = 1 - (float(context.eval_graph.numel_params()) /
                               model_size)
            _logger.debug('Pruned flops: {:.2f}'.format(pruned_flops))
            _logger.debug('Pruned model size: {:.2f}'.format(pruned_size))
            for param in param_shape_backup.keys():
                context.eval_graph.var(param).set_shape(param_shape_backup[
                    param])

            if abs(pruned_flops - self.target_ratio) < 1e-2:
                break
            if pruned_flops > self.target_ratio:
                max_ratio = ratio
            else:
                min_ratio = ratio
        _logger.info('Get ratios: {}'.format([round(r, 2) for r in ratios]))
        return pruned_params, ratios

    def on_epoch_begin(self, context):
        if context.epoch_id == self.start_epoch:
            params, ratios = self._get_best_ratios(context)

            self._prune_parameters(context.optimize_graph, context.scope,
                                   params, ratios, context.place)

            model_size = context.eval_graph.numel_params()
            flops = context.eval_graph.flops()
            _logger.debug('\n################################')
            _logger.debug('# pruning eval graph #')
            _logger.debug('################################\n')
            self._prune_graph(context.eval_graph, context.optimize_graph)
            context.optimize_graph.update_groups_of_conv()
            context.eval_graph.update_groups_of_conv()

            _logger.info(
                '------------------finish pruning--------------------------------'
            )
            _logger.info('Pruned size: {:.2f}'.format(1 - (float(
                context.eval_graph.numel_params()) / model_size)))
            _logger.info('Pruned flops: {:.2f}'.format(1 - (float(
                context.eval_graph.flops()) / flops)))
            # metric = self._eval_graph(context)
            # _logger.info('Metric after pruning: {:.2f}'.format(metric))
            _logger.info(
                '------------------UniformPruneStrategy.on_epoch_begin finish--------------------------------'
            )


class SensitivePruneStrategy(PruneStrategy):
    """
    Sensitive pruning strategy. A different pruning ratio is applied to each layer
    according to its sensitivity.
    """
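
    # A minimal construction sketch. The surrounding compression pipeline and
    # the concrete `pruner` instance are assumptions and not shown here:
    #
    #     strategy = SensitivePruneStrategy(
    #         pruner=pruner,
    #         start_epoch=0,
    #         end_epoch=10,
    #         delta_rate=0.2,
    #         target_ratio=0.5,
    #         metric_name='top1_acc',
    #         pruned_params='conv.*_weights',
    #         sensitivities_file='./sensitivities.data',
    #         num_steps=1)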

    def __init__(self,
                 pruner=None,
                 start_epoch=0,
                 end_epoch=0,
                 delta_rate=0.20,
                 target_ratio=0.5,
                 metric_name='top1_acc',
                 pruned_params='conv.*_weights',
                 sensitivities_file='./sensitivities.data',
                 sensitivities={},
                 num_steps=1,
                 eval_rate=None):
        """
        Args:
            pruner(slim.Pruner): The pruner used to prune the parameters.
            start_epoch(int): The epoch in which the 'on_epoch_begin' function is called. default: 0.
            end_epoch(int): The epoch in which the 'on_epoch_end' function is called. default: 0.
            delta_rate(float): The delta used to generate ratios when calculating sensitivities. default: 0.2
            target_ratio(float): The ratio of FLOPs to be pruned from the current model. default: 0.5
            metric_name(str): The metric used to evaluate the model.
                              It should be one of the keys in the out_nodes of the graph wrapper. default: 'top1_acc'
            pruned_params(str): The pattern str used to match the names of the parameters to be pruned. default: 'conv.*_weights'.
            sensitivities_file(str): The file used to save the sensitivities. default: './sensitivities.data'
            sensitivities(dict): The user-defined sensitivities. default: {}.
            num_steps(int): The number of pruning steps. default: 1.
            eval_rate(float): The rate of sampled data used to calculate sensitivities.
                              None means using all the data. default: None.
        """
        super(SensitivePruneStrategy, self).__init__(pruner, start_epoch,
                                                     end_epoch, target_ratio,
                                                     metric_name, pruned_params)
        self.delta_rate = delta_rate
        self.pruned_list = []
        self.sensitivities = sensitivities
        self.sensitivities_file = sensitivities_file
        self.num_steps = num_steps
        self.eval_rate = eval_rate
        self.pruning_step = 1 - pow((1 - target_ratio), 1.0 / self.num_steps)
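
        # The per-step ratio is chosen so that
        # (1 - pruning_step) ** num_steps == 1 - target_ratio after all the
        # steps. For example, target_ratio=0.5 with num_steps=2 gives a
        # pruning_step of about 0.293.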

    def _save_sensitivities(self, sensitivities, sensitivities_file):
        """
        Save sensitivities into file.
        """
        with open(sensitivities_file, 'wb') as f:
            pickle.dump(sensitivities, f)

    def _load_sensitivities(self, sensitivities_file):
        """
        Load sensitivities from file.
        """
        sensitivities = {}
        if sensitivities_file and os.path.exists(sensitivities_file):
            with open(sensitivities_file, 'rb') as f:
                if sys.version_info < (3, 0):
                    sensitivities = pickle.load(f)
                else:
                    sensitivities = pickle.load(f, encoding='bytes')

        for param in sensitivities:
            sensitivities[param]['pruned_percent'] = [
                round(p, 2) for p in sensitivities[param]['pruned_percent']
            ]
        self._format_sensitivities(sensitivities)
        return sensitivities

    def _format_sensitivities(self, sensitivities):
        """
        Print the formatted sensitivities at the debug log level.
        """
        tb = pt.PrettyTable()
        tb.field_names = ["parameter", "size"] + [
            str(round(i, 2))
            for i in np.arange(self.delta_rate, 1, self.delta_rate)
        ]
        for param in sensitivities:
            if len(sensitivities[param]['loss']) == (len(tb.field_names) - 2):
                tb.add_row([param, sensitivities[param]['size']] + [
                    round(loss, 2) for loss in sensitivities[param]['loss']
                ])
        _logger.debug('\n################################')
        _logger.debug('# sensitivities table #')
        _logger.debug('################################\n')
        _logger.debug(tb)

    def _compute_sensitivities(self, context):
        """
        Compute the sensitivities of all the parameters.
        """
        _logger.info("calling _compute_sensitivities.")
        cached_id = np.random.randint(1000)
        if self.start_epoch == context.epoch_id:
            sensitivities_file = self.sensitivities_file
        else:
            sensitivities_file = self.sensitivities_file + ".epoch" + str(
                context.epoch_id)
        sensitivities = self._load_sensitivities(sensitivities_file)

        for param in context.eval_graph.all_parameters():
            if not re.match(self.pruned_params, param.name()):
                continue
            if param.name() not in sensitivities:
                sensitivities[param.name()] = {
                    'pruned_percent': [],
                    'loss': [],
                    'size': param.shape()[0]
                }

        metric = None
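
        # For each parameter and each candidate ratio, prune lazily (the
        # pruned filters are zeroed rather than removed), re-evaluate the
        # model, and record the accuracy loss relative to the unpruned
        # baseline; the parameters are restored from param_backup afterwards.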
        for param in list(sensitivities.keys()):
            ratio = self.delta_rate
            while ratio < 1:
                ratio = round(ratio, 2)
                if ratio in sensitivities[param]['pruned_percent']:
                    _logger.debug('{}, {} has been computed.'.format(param,
                                                                     ratio))
                    ratio += self.delta_rate
                    continue
                if metric is None:
                    metric = self._eval_graph(context, self.eval_rate,
                                              cached_id)

                param_backup = {}
                # prune parameter by ratio
                self._prune_parameters(
                    context.eval_graph,
                    context.scope, [param], [ratio],
                    context.place,
                    lazy=True,
                    param_backup=param_backup)
                # get accuracy after pruning and update self.sensitivities
                pruned_metric = self._eval_graph(context, self.eval_rate,
                                                 cached_id)
                loss = metric - pruned_metric
                _logger.info("pruned param: {}; {}; loss={}".format(
                    param, ratio, loss))
                for brother in self.pruned_list[0]:
                    if re.match(self.pruned_params, brother):
                        if brother not in sensitivities:
                            sensitivities[brother] = {
                                'pruned_percent': [],
                                'loss': []
                            }
                        sensitivities[brother]['pruned_percent'].append(ratio)
                        sensitivities[brother]['loss'].append(loss)

                self._save_sensitivities(sensitivities, sensitivities_file)

                # restore pruned parameters
                for param_name in param_backup.keys():
                    param_t = context.scope.find_var(param_name).get_tensor()
                    param_t.set(param_backup[param_name], context.place)

                # pruned_metric = self._eval_graph(context)

                ratio += self.delta_rate
        return sensitivities

    def _get_best_ratios(self, context, sensitivities, target_ratio):
        """
        Search a group of ratios for pruning the target FLOPs.
        """
        _logger.info('_get_best_ratios for pruning ratio: {}'.format(
            target_ratio))

        def func(params, x):
            a, b, c, d = params
            return a * x * x * x + b * x * x + c * x + d

        def error(params, x, y):
            return func(params, x) - y

        def solve_coefficient(x, y):
            init_coefficient = [10, 10, 10, 10]
            coefficient, loss = leastsq(error, init_coefficient, args=(x, y))
            return coefficient
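
        # The sensitivity curve of each layer is modeled as a cubic polynomial
        # loss(ratio) = a*ratio^3 + b*ratio^2 + c*ratio + d, fitted with
        # scipy's leastsq. Below, for a candidate accuracy loss, the fitted
        # polynomial is solved and a real root in (0, 1) is taken as that
        # layer's pruning ratio.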

        min_loss = 0.
        max_loss = 0.

        # step 1: fit the curves from the sensitivities
        coefficients = {}
        for param in sensitivities:
            losses = np.array([0] * 5 + sensitivities[param]['loss'])
            percents = np.array([0] * 5 + sensitivities[param][
                'pruned_percent'])
            coefficients[param] = solve_coefficient(percents, losses)
            loss = np.max(losses)
            max_loss = np.max([max_loss, loss])

        # step 2: Find a group of ratios by binary searching.
        flops = context.eval_graph.flops()
        model_size = context.eval_graph.numel_params()
        ratios = []
        while min_loss < max_loss:
            loss = (max_loss + min_loss) / 2
            _logger.info(
                '-----------Try pruned ratios while acc loss={:.4f}-----------'.
                format(loss))
            ratios = []
            # step 2.1: Get ratios according to the current loss
            for param in sensitivities:
                coefficient = copy.deepcopy(coefficients[param])
                coefficient[-1] = coefficient[-1] - loss
                roots = np.roots(coefficient)
                for root in roots:
                    min_root = 1
                    if np.isreal(root) and root > 0 and root < 1:
                        selected_root = min(root.real, min_root)
                ratios.append(selected_root)
            _logger.info('Pruned ratios={}'.format(
                [round(ratio, 3) for ratio in ratios]))
            # step 2.2: Pruning by the current ratios
            param_shape_backup = {}
            self._prune_parameters(
                context.eval_graph,
                context.scope,
                sensitivities.keys(),
                ratios,
                context.place,
                only_graph=True,
                param_shape_backup=param_shape_backup)

            pruned_flops = 1 - (float(context.eval_graph.flops()) / flops)
            pruned_size = 1 - (float(context.eval_graph.numel_params()) /
                               model_size)
            _logger.info('Pruned flops: {:.4f}'.format(pruned_flops))
            _logger.info('Pruned model size: {:.4f}'.format(pruned_size))
            for param in param_shape_backup.keys():
                context.eval_graph.var(param).set_shape(param_shape_backup[
                    param])

            # step 2.3: Check whether the current ratios are enough
            if abs(pruned_flops - target_ratio) < 0.015:
                break
            if pruned_flops > target_ratio:
                max_loss = loss
            else:
                min_loss = loss
        return sensitivities.keys(), ratios

    def _current_pruning_target(self, context):
        '''
        Get the target pruning rate for the current epoch.
        '''
        _logger.info('Left number of pruning steps: {}'.format(self.num_steps))
        if self.num_steps <= 0:
            return None
        if (self.start_epoch == context.epoch_id) or context.eval_converged(
                self.metric_name, 0.005):
            self.num_steps -= 1
            return self.pruning_step

    def on_epoch_begin(self, context):
        current_ratio = self._current_pruning_target(context)
        if current_ratio is not None:
            sensitivities = self._compute_sensitivities(context)
            params, ratios = self._get_best_ratios(context, sensitivities,
                                                   current_ratio)
            self._prune_parameters(context.optimize_graph, context.scope,
                                   params, ratios, context.place)

            model_size = context.eval_graph.numel_params()
            flops = context.eval_graph.flops()
            _logger.debug('################################')
            _logger.debug('# pruning eval graph #')
            _logger.debug('################################')
            self._prune_graph(context.eval_graph, context.optimize_graph)
            context.optimize_graph.update_groups_of_conv()
            context.eval_graph.update_groups_of_conv()
            context.optimize_graph.compile()  # to update the compiled program
            context.eval_graph.compile(
                for_parallel=False,
                for_test=True)  # to update the compiled program
            _logger.info(
                '------------------finish pruning--------------------------------'
            )
            _logger.info('Pruned size: {:.3f}'.format(1 - (float(
                context.eval_graph.numel_params()) / model_size)))
            _logger.info('Pruned flops: {:.3f}'.format(1 - (float(
                context.eval_graph.flops()) / flops)))
            metric = self._eval_graph(context)
            _logger.info('Metric after pruning: {:.2f}'.format(metric))
            _logger.info(
                '------------------SensitivePruneStrategy.on_epoch_begin finish--------------------------------'
            )