@@ -39,7 +39,9 @@ class AutoPruneStrategy(PruneStrategy):
                  max_ratio=0.7,
                  metric_name='top1_acc',
                  pruned_params='conv.*_weights',
-                 retrain_epoch=0):
+                 retrain_epoch=0,
+                 uniform_range=None,
+                 init_tokens=None):
         """
         Args:
             pruner(slim.Pruner): The pruner used to prune the parameters. Default: None.
@@ -52,6 +54,8 @@ class AutoPruneStrategy(PruneStrategy):
                          It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
             pruned_params(str): The pattern str to match the parameter names to be pruned. Default: 'conv.*_weights'
             retrain_epoch(int): The training epochs in each seaching step. Default: 0
+            uniform_range(int): The token range in each position of tokens generated by controller. None means getting the range automatically. Default: None.
+            init_tokens(list<int>): The initial tokens. None means getting the initial tokens automatically. Default: None.
         """
         super(AutoPruneStrategy, self).__init__(pruner, start_epoch, end_epoch,
                                                 0.0, metric_name, pruned_params)
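For orientation (an editorial sketch, not part of the patch): the two new constructor arguments are passed alongside the existing ones. The `pruner` and `controller` objects, the token values, and the assumption that nine parameters match the pattern are all illustrative; the surrounding compression setup is omitted.

# Hypothetical configuration only; pruner/controller construction is assumed to happen elsewhere.
strategy = AutoPruneStrategy(
    pruner=pruner,                    # assumed pruner instance
    controller=controller,            # assumed search controller instance
    start_epoch=0,
    end_epoch=10,
    max_ratio=0.7,
    metric_name='top1_acc',
    pruned_params='conv.*_weights',
    retrain_epoch=0,
    uniform_range=0.8,                # read by the new code path as a per-position ratio bound
    init_tokens=[20] * 9)             # one token per matched parameter; values are illustrative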
@@ -60,8 +64,9 @@ class AutoPruneStrategy(PruneStrategy):
         self._controller = controller
         self._metric_name = metric_name
         self._pruned_param_names = []
-        self._retrain_epoch = 0
+        self._retrain_epoch = retrain_epoch
+        self._uniform_range = uniform_range
+        self._init_tokens = init_tokens
         self._current_tokens = None
 
     def on_compression_begin(self, context):
@@ -75,9 +80,18 @@ class AutoPruneStrategy(PruneStrategy):
             if re.match(self.pruned_params, param.name()):
                 self._pruned_param_names.append(param.name())
 
-        self._current_tokens = self._get_init_tokens(context)
-        self._range_table = copy.deepcopy(self._current_tokens)
+        if self._init_tokens is not None:
+            self._current_tokens = self._init_tokens
+        else:
+            self._current_tokens = self._get_init_tokens(context)
+
+        if self._uniform_range is not None:
+            self._range_table = [round(self._uniform_range, 2) / 0.01] * len(
+                self._pruned_param_names)
+        else:
+            self._range_table = copy.deepcopy(self._current_tokens)
+        _logger.info('init tokens: {}'.format(self._current_tokens))
+        _logger.info("range_table: {}".format(self._range_table))
         constrain_func = functools.partial(
             self._constrain_func, context=context)
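A note on the new `uniform_range` branch above (an editorial reading, with illustrative numbers): the same bound is repeated for every token position, and `round(uniform_range, 2) / 0.01` treats the value as a pruned ratio quantized in steps of 0.01. A self-contained sketch of that arithmetic:

# Illustrative only: three matched parameters and a uniform_range of 0.8.
uniform_range = 0.8
pruned_param_names = ['conv1_weights', 'conv2_weights', 'conv3_weights']
range_table = [round(uniform_range, 2) / 0.01] * len(pruned_param_names)
print(range_table)  # roughly [80.0, 80.0, 80.0]: each token position spans about 80 values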
@@ -104,14 +118,20 @@ class AutoPruneStrategy(PruneStrategy):
             context.eval_graph.var(param).set_shape(param_shape_backup[param])
         flops_ratio = (1 - float(flops) / ori_flops)
         if flops_ratio >= self._min_ratio and flops_ratio <= self._max_ratio:
+            _logger.info("Success try [{}]; flops: -{}".format(tokens,
+                                                               flops_ratio))
             return True
         else:
+            _logger.info("Failed try [{}]; flops: -{}".format(tokens,
+                                                              flops_ratio))
             return False
 
     def _get_init_tokens(self, context):
         """Get initial tokens.
         """
         ratios = self._get_uniform_ratios(context)
+        _logger.info('Get init ratios: {}'.format(
+            [round(r, 2) for r in ratios]))
         return self._ratios_to_tokens(ratios)
 
     def _ratios_to_tokens(self, ratios):
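The constraint being logged above accepts a token proposal only when the fraction of FLOPS removed falls between the strategy's minimum and maximum ratios. A small, self-contained sketch of that check with made-up numbers (the `min_ratio`/`max_ratio` values are illustrative):

# Made-up FLOPS figures; mirrors the accept/reject rule in _constrain_func.
ori_flops = 1.0e9                             # FLOPS of the unpruned eval graph
flops = 0.4e9                                 # FLOPS after applying a candidate token set
min_ratio, max_ratio = 0.5, 0.7
flops_ratio = 1 - float(flops) / ori_flops    # 0.6 -> 60% of FLOPS pruned away
print(min_ratio <= flops_ratio <= max_ratio)  # True, so this would be a "Success try"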
@@ -171,7 +191,7 @@ class AutoPruneStrategy(PruneStrategy):
         if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
                 self._retrain_epoch == 0 or
                 (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
-            self._current_tokens = self._controller.next_tokens()
+            _logger.info("on_epoch_begin")
             params = self._pruned_param_names
             ratios = self._tokens_to_ratios(self._current_tokens)
@@ -189,7 +209,7 @@ class AutoPruneStrategy(PruneStrategy):
             context.optimize_graph.update_groups_of_conv()
             context.eval_graph.update_groups_of_conv()
             context.optimize_graph.compile(
-                mem_opt=True)  # to update the compiled program
+                mem_opt=False)  # to update the compiled program
             context.skip_training = (self._retrain_epoch == 0)
 
     def on_epoch_end(self, context):
@@ -199,10 +219,13 @@ class AutoPruneStrategy(PruneStrategy):
         """
         if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and (
                 self._retrain_epoch == 0 or
-                (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
+            (context.epoch_id - self.start_epoch + 1
+             ) % self._retrain_epoch == 0):
+            _logger.info("on_epoch_end")
             reward = context.eval_results[self._metric_name][-1]
             self._controller.update(self._current_tokens, reward)
 
+            self._current_tokens = self._controller.next_tokens()
             # restore pruned parameters
             for param_name in self._param_backup.keys():
                 param_t = context.scope.find_var(param_name).get_tensor()
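With this hunk the controller is rewarded and then advanced at the end of each search step, so a step behaves like one iteration of a propose/evaluate/update loop. A self-contained toy version of that loop; `RandomController` is only a stand-in with the same `update()`/`next_tokens()` surface used above, and the rewards are random numbers rather than real evaluation results.

import random

class RandomController(object):
    """Toy stand-in; a real controller would bias proposals toward high rewards."""
    def __init__(self, range_table):
        self._range_table = range_table
    def update(self, tokens, reward):
        pass                                   # ignore the reward in this toy version
    def next_tokens(self):
        return [random.randint(0, int(r)) for r in self._range_table]

controller = RandomController(range_table=[80] * 3)
tokens = [20, 20, 20]                          # stands in for the strategy's init_tokens
for step in range(5):                          # one step per epoch when retrain_epoch == 0
    reward = random.random()                   # stands in for eval_results['top1_acc'][-1]
    controller.update(tokens, reward)          # reward the controller, as in on_epoch_end
    tokens = controller.next_tokens()          # draw the proposal for the next step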
@@ -218,7 +241,7 @@ class AutoPruneStrategy(PruneStrategy):
             context.optimize_graph.update_groups_of_conv()
             context.eval_graph.update_groups_of_conv()
             context.optimize_graph.compile(
-                mem_opt=True)  # to update the compiled program
+                mem_opt=False)  # to update the compiled program
 
         elif context.epoch_id == self.end_epoch:  # restore graph for final training
             # restore pruned parameters