[PaddleSlim] Enhance the Compressor API in PaddleSlim (#19894)
1. Support a customized eval function instead of an eval program. 2. Fix loading a checkpoint in the quantization strategy. 3. Support saving the eval model when saving a checkpoint. 4. Fix the decoder used when loading the context in PaddleSlim. 5. Fix restoring from a checkpoint of the uniform prune strategy. 6. Support saving the eval model and the infer model during training. 7. Add unit tests for saving the eval model, saving the infer model, and restoring uniform pruning from a checkpoint. 8. Fix pruning of the depthwise_conv_grad op by updating its groups.
expand_as_op_1
parent
cedc04775c
commit
bdb3e376d0
@ -0,0 +1,4 @@
|
||||
version: 1.0
|
||||
compressor:
|
||||
epoch: 1
|
||||
checkpoint_path: './checkpoints/'
|
@ -0,0 +1,21 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'StructurePruner'
|
||||
pruning_axis:
|
||||
'*': 0
|
||||
criterions:
|
||||
'*': 'l1_norm'
|
||||
strategies:
|
||||
uniform_pruning_strategy:
|
||||
class: 'UniformPruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
start_epoch: 0
|
||||
target_ratio: 0.5
|
||||
pruned_params: 'conv.*'
|
||||
metric_name: 'acc_top1'
|
||||
compressor:
|
||||
epoch: 2
|
||||
checkpoint_path: './checkpoints_uniform_restore_tmp/'
|
||||
strategies:
|
||||
- uniform_pruning_strategy
|
@ -0,0 +1,21 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'StructurePruner'
|
||||
pruning_axis:
|
||||
'*': 0
|
||||
criterions:
|
||||
'*': 'l1_norm'
|
||||
strategies:
|
||||
uniform_pruning_strategy:
|
||||
class: 'UniformPruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
start_epoch: 0
|
||||
target_ratio: 0.5
|
||||
pruned_params: 'conv.*'
|
||||
metric_name: 'acc_top1'
|
||||
compressor:
|
||||
epoch: 1
|
||||
checkpoint_path: './checkpoints_uniform_restore/'
|
||||
strategies:
|
||||
- uniform_pruning_strategy
|
@ -0,0 +1,21 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'StructurePruner'
|
||||
pruning_axis:
|
||||
'*': 0
|
||||
criterions:
|
||||
'*': 'l1_norm'
|
||||
strategies:
|
||||
uniform_pruning_strategy:
|
||||
class: 'UniformPruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
start_epoch: 0
|
||||
target_ratio: 0.5
|
||||
pruned_params: 'conv.*'
|
||||
metric_name: 'acc_top1'
|
||||
compressor:
|
||||
epoch: 2
|
||||
checkpoint_path: './checkpoints_uniform_restore/'
|
||||
strategies:
|
||||
- uniform_pruning_strategy
|
@ -0,0 +1,50 @@
|
||||
#start_epoch(int): The epoch to insert quantization operators. default: 0
|
||||
#
|
||||
#end_epoch(int): The epoch to save inference model. default: 0
|
||||
#
|
||||
#float_model_save_path(str): The path to save model with float weights.
|
||||
# None means it doesn't save float model. default: None.
|
||||
#
|
||||
#mobile_model_save_path(str): The path to save model for paddle-mobile execution.
|
||||
# None means it doesn't save mobile model. default: None.
|
||||
#
|
||||
#int8_model_save_path(str): The path to save model with int8_t weight.
|
||||
# None means it doesn't save int8 model. default: None.
|
||||
#
|
||||
#activation_bits(int): quantization bit number for activation. default: 8.
|
||||
#
|
||||
#weight_bits(int): quantization bit number for weights. The bias is not quantized.
|
||||
# default: 8.
|
||||
#
|
||||
#activation_quantize_type(str): quantization type for activation,
|
||||
# currently supports 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
|
||||
# If use 'abs_max' mode, the quantization scale will be calculated
|
||||
# dynamically each step in both training and testing period. If use
|
||||
# 'range_abs_max', a static quantization scale will be calculated
|
||||
# during training and used in inference.
|
||||
#
|
||||
#save_in_nodes(list<str>): A list of variable names used to prune graph
|
||||
# for saving inference model.
|
||||
#
|
||||
#save_out_nodes(list<str>): A list of variable names used to prune graph
|
||||
# for saving inference model.
|
||||
version: 1.0
|
||||
strategies:
|
||||
quantization_strategy:
|
||||
class: 'QuantizationStrategy'
|
||||
start_epoch: 0
|
||||
end_epoch: 0
|
||||
float_model_save_path: './output/float'
|
||||
mobile_model_save_path: './output/mobile'
|
||||
int8_model_save_path: './output/int8'
|
||||
weight_bits: 8
|
||||
activation_bits: 8
|
||||
weight_quantize_type: 'abs_max'
|
||||
activation_quantize_type: 'abs_max'
|
||||
save_in_nodes: ['image']
|
||||
save_out_nodes: ['quan.tmp_2']
|
||||
compressor:
|
||||
epoch: 2
|
||||
checkpoint_path: './checkpoints_quan/'
|
||||
strategies:
|
||||
- quantization_strategy
|
@ -0,0 +1,99 @@
|
||||
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
|
||||
#
|
||||
# licensed under the apache license, version 2.0 (the "license");
|
||||
# you may not use this file except in compliance with the license.
|
||||
# you may obtain a copy of the license at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the license is distributed on an "as is" basis,
|
||||
# without warranties or conditions of any kind, either express or implied.
|
||||
# see the license for the specific language governing permissions and
|
||||
# limitations under the license.
|
||||
|
||||
import paddle
|
||||
import unittest
|
||||
import os
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.contrib.slim.core import Compressor
|
||||
from paddle.fluid.contrib.slim.graph import GraphWrapper
|
||||
|
||||
|
||||
class TestCompressor(unittest.TestCase):
    """Exercises Compressor with a user-supplied evaluation function."""

    def test_eval_func(self):
        """Run Compressor with a custom ``eval_func`` and check its outputs.

        Builds a single fully-connected softmax classifier on MNIST, runs the
        Compressor configured by ``./configs/compress.yaml``, and then checks
        that:

        * the custom metric ``'score'`` appears in the context's eval results,
        * its first recorded value is above 0.9, and
        * the eval model files were written under ``./checkpoints/0/``.
        """
        class_dim = 10
        image_shape = [1, 28, 28]
        image = fluid.layers.data(
            name='image', shape=image_shape, dtype='float32')
        image.stop_gradient = False
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        out = fluid.layers.fc(input=image, size=class_dim)
        out = fluid.layers.softmax(out)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        # Snapshot of the main program taken before the loss layers below
        # are appended to it.
        val_program = fluid.default_main_program().clone(for_test=False)

        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=128)
        train_feed_list = [('img', image.name), ('label', label.name)]
        train_fetch_list = [('loss', avg_cost.name)]
        eval_feed_list = [('img', image.name), ('label', label.name)]
        eval_fetch_list = [('acc_top1', acc_top1.name)]

        def eval_func(program, scope):
            """Return the mean top-1 accuracy of ``program`` over val_reader.

            Args:
                program: The program to execute for evaluation.
                scope: The scope passed to the executor for each batch.

            Returns:
                The mean (via ``np.mean``) of the per-batch ``acc_top1``
                fetch results.
            """
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            feeder = fluid.DataFeeder(
                feed_list=[image.name, label.name],
                place=place,
                program=program)
            results = []
            for data in val_reader():
                result = exe.run(program=program,
                                 scope=scope,
                                 fetch_list=[acc_top1.name],
                                 feed=feeder.feed(data))
                results.append(np.array(result))
            result = np.mean(results)
            return result

        com_pass = Compressor(
            place,
            fluid.global_scope(),
            fluid.default_main_program(),
            train_reader=train_reader,
            train_feed_list=train_feed_list,
            train_fetch_list=train_fetch_list,
            eval_program=val_program,
            eval_feed_list=eval_feed_list,
            eval_fetch_list=eval_fetch_list,
            eval_func={"score": eval_func},
            prune_infer_model=[[image.name], [out.name]],
            train_optimizer=optimizer)
        com_pass.config('./configs/compress.yaml')
        com_pass.run()

        # Dedicated assertion methods report the offending values on failure
        # instead of a bare "False is not true".
        eval_results = com_pass.context.eval_results
        self.assertIn('score', eval_results)
        self.assertGreater(float(eval_results['score'][0]), 0.9)

        # NOTE(review): these paths presume the config sets checkpoint_path
        # to './checkpoints/' -- confirm against configs/compress.yaml.
        expected_files = (
            "./checkpoints/0/eval_model/__model__",
            "./checkpoints/0/eval_model/__model__.infer",
            "./checkpoints/0/eval_model/__params__", )
        for path in expected_files:
            self.assertTrue(
                os.path.exists(path), "missing checkpoint file: " + path)
||||
|
||||
|
||||
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue