Init paddle slim (#14834)
* Init slim. * Remove distillation demo. * Fix import errors. test=develop * Fix some issues. test=develop * Fix configs. test=develop * Modify API.spec. test=develop * Fix format. test=develop * Fix format. test=develop * Add some comments.revert-15207-remove_op_handle_lock_and_fix_var
parent
00dadb0720
commit
938705745e
@ -0,0 +1,25 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .core import *
|
||||
from .graph import *
|
||||
from .prune import *
|
||||
__all__ = [
|
||||
'build_compressor',
|
||||
'CompressPass',
|
||||
'ImitationGraph',
|
||||
'SensitivePruneStrategy',
|
||||
'MagnitudePruner',
|
||||
'RatioPruner',
|
||||
]
|
@ -0,0 +1,24 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import config
|
||||
from .config import *
|
||||
from . import compress_pass
|
||||
from .compress_pass import *
|
||||
from . import strategy
|
||||
from .strategy import *
|
||||
from . import pass_builder
|
||||
from .pass_builder import *
|
||||
|
||||
__all__ = config.__all__ + compress_pass.__all__ + strategy.__all__ + pass_builder.__all__
|
@ -0,0 +1,129 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ....core import CPUPlace
|
||||
from ..graph import get_executor
|
||||
|
||||
__all__ = ['Context', 'CompressPass']
|
||||
|
||||
|
||||
class Context(object):
|
||||
"""
|
||||
The context in the process of compression.
|
||||
Args:
|
||||
exe: The executor used to execute graph.
|
||||
graph: The graph to be compressed.
|
||||
scope: The scope used to execute graph.
|
||||
program_exe: The program_exe is used to execute the program
|
||||
created for modifying the variables in scope.
|
||||
"""
|
||||
|
||||
def __init__(self, exe, graph, scope, program_exe=None):
|
||||
# The total number of epoches to be trained.
|
||||
self.epoch = 0
|
||||
# Current epoch
|
||||
self.epoch_id = 0
|
||||
# Current batch
|
||||
self.batch_id = 0
|
||||
self.exe = exe
|
||||
self.graph = graph
|
||||
self.scope = scope
|
||||
self.program_exe = program_exe
|
||||
|
||||
|
||||
class CompressPass(object):
|
||||
"""
|
||||
The pass used to compress model.
|
||||
Args:
|
||||
place: The device used in compression.
|
||||
data_reader: The data_reader used to run graph.
|
||||
data_feeder: The data_feeder used to run graph.
|
||||
scope: The scope used to run graph.
|
||||
metrics: The metrics for evaluating model.
|
||||
epoch: The total epoches of trainning in compression.
|
||||
program_exe: The program_exe is used to execute the program
|
||||
created for modifying the variables in scope.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
place=None,
|
||||
data_reader=None,
|
||||
data_feeder=None,
|
||||
scope=None,
|
||||
metrics=None,
|
||||
epoch=None,
|
||||
program_exe=None):
|
||||
self.strategies = []
|
||||
self.place = CPUPlace() if place is None else place
|
||||
self.data_reader = data_reader
|
||||
self.data_feeder = data_feeder
|
||||
self.scope = scope
|
||||
self.metrics = metrics
|
||||
self.epoch = epoch
|
||||
self.program_exe = program_exe
|
||||
|
||||
def add_strategy(self, strategy):
|
||||
"""
|
||||
Add a strategy to current compress pass.
|
||||
Args:
|
||||
strategy: The strategy to be added into current compress pass.
|
||||
"""
|
||||
self.strategies.append(strategy)
|
||||
self.epoch = max(strategy.end_epoch, self.epoch)
|
||||
|
||||
def apply(self, graph):
|
||||
"""
|
||||
Compress a model.
|
||||
Args:
|
||||
graph: The target graph to be compressed.
|
||||
"""
|
||||
self.executor = get_executor(graph, self.place)
|
||||
context = Context(
|
||||
self.executor, graph, self.scope, program_exe=self.program_exe)
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_compress_begin(context)
|
||||
|
||||
for epoch in range(self.epoch):
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_epoch_begin(context)
|
||||
|
||||
for data in self.data_reader():
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_batch_begin(context)
|
||||
fetches = None
|
||||
if self.metrics:
|
||||
fetches = self.metrics.values()
|
||||
feed = None
|
||||
if self.data_feeder:
|
||||
feed = self.data_feeder.feed(data)
|
||||
results = self.executor.run(graph,
|
||||
fetches=fetches,
|
||||
scope=self.scope,
|
||||
feed=feed)
|
||||
if results:
|
||||
print("results: {}".format(
|
||||
zip(self.metrics.keys(), results)))
|
||||
for strategy in self.strategies:
|
||||
strategy.on_batch_end(context)
|
||||
context.batch_id += 1
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_epoch_end(context)
|
||||
context.epoch_id += 1
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_compress_end(context)
|
@ -0,0 +1,111 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import funcsigs
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
from ..prune import *
|
||||
from .compress_pass import *
|
||||
from .strategy import *
|
||||
|
||||
__all__ = ['ConfigFactory']
|
||||
"""This factory is used to create instances by loading and parsing configure file with yaml format.
|
||||
"""
|
||||
|
||||
|
||||
class ConfigFactory(object):
|
||||
def __init__(self, config):
|
||||
"""Init a factory from configure file."""
|
||||
self.instances = {}
|
||||
self.version = None
|
||||
self._parse_config(config)
|
||||
|
||||
def get_compress_pass(self):
|
||||
"""
|
||||
Get compress pass from factory.
|
||||
"""
|
||||
return self.instance('compress_pass')
|
||||
|
||||
def instance(self, name):
|
||||
"""
|
||||
Get instance from factory.
|
||||
"""
|
||||
if name in self.instances:
|
||||
return self.instances[name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _new_instance(self, name, attrs):
|
||||
if name not in self.instances:
|
||||
class_ = globals()[attrs['class']]
|
||||
sig = funcsigs.signature(class_.__init__)
|
||||
keys = [
|
||||
param.name for param in sig.parameters.values()
|
||||
if (param.kind == param.POSITIONAL_OR_KEYWORD)
|
||||
][1:]
|
||||
keys = set(attrs.keys()).intersection(set(keys))
|
||||
args = {}
|
||||
for key in keys:
|
||||
value = attrs[key]
|
||||
if isinstance(value, str) and value in self.instances:
|
||||
value = self.instances[value]
|
||||
args[key] = value
|
||||
self.instances[name] = class_(**args)
|
||||
return self.instances.get(name)
|
||||
|
||||
def _parse_config(self, config):
|
||||
assert config
|
||||
with open(config, 'r') as config_file:
|
||||
key_values = self._ordered_load(config_file)
|
||||
for key in key_values:
|
||||
# parse version
|
||||
if key == 'version' and self.version is None:
|
||||
self.version = int(key_values['version'])
|
||||
assert self.version == int(key_values['version'])
|
||||
|
||||
# parse pruners
|
||||
if key == 'pruners' or key == 'strategies':
|
||||
instances = key_values[key]
|
||||
for name in instances:
|
||||
self._new_instance(name, instances[name])
|
||||
|
||||
if key == 'compress_pass':
|
||||
compress_pass = self._new_instance(key, key_values[key])
|
||||
for name in key_values[key]['strategies']:
|
||||
strategy = self.instance(name)
|
||||
compress_pass.add_strategy(strategy)
|
||||
|
||||
if key == 'include':
|
||||
for config_file in key_values[key]:
|
||||
self._parse_config(config_file.strip())
|
||||
|
||||
def _ordered_load(self,
|
||||
stream,
|
||||
Loader=yaml.Loader,
|
||||
object_pairs_hook=OrderedDict):
|
||||
"""
|
||||
See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts
|
||||
"""
|
||||
|
||||
class OrderedLoader(Loader):
|
||||
pass
|
||||
|
||||
def construct_mapping(loader, node):
|
||||
loader.flatten_mapping(node)
|
||||
return object_pairs_hook(loader.construct_pairs(node))
|
||||
|
||||
OrderedLoader.add_constructor(
|
||||
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
|
||||
return yaml.load(stream, OrderedLoader)
|
@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .compress_pass import CompressPass
|
||||
from .config import ConfigFactory
|
||||
|
||||
__all__ = ['build_compressor']
|
||||
|
||||
|
||||
def build_compressor(place=None,
|
||||
data_reader=None,
|
||||
data_feeder=None,
|
||||
scope=None,
|
||||
metrics=None,
|
||||
epoch=None,
|
||||
config=None):
|
||||
if config is not None:
|
||||
factory = ConfigFactory(config)
|
||||
comp_pass = factory.get_compress_pass()
|
||||
else:
|
||||
comp_pass = CompressPass()
|
||||
comp_pass.place = place
|
||||
comp_pass.data_reader = data_reader
|
||||
comp_pass.data_feeder = data_feeder
|
||||
comp_pass.scope = scope
|
||||
comp_pass.metrics = metrics
|
||||
comp_pass.epoch = epoch
|
||||
return comp_pass
|
@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['Strategy']
|
||||
|
||||
|
||||
class Strategy(object):
|
||||
"""
|
||||
Base class for all strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, start_epoch=0, end_epoch=10):
|
||||
"""
|
||||
Args:
|
||||
start_epoch: The first epoch to apply the strategy.
|
||||
end_epoch: The last epoch to apply the strategy.
|
||||
"""
|
||||
self.start_epoch = start_epoch
|
||||
self.end_epoch = end_epoch
|
||||
|
||||
def on_compress_begin(self, context):
|
||||
pass
|
||||
|
||||
def on_epoch_begin(self, context):
|
||||
pass
|
||||
|
||||
def on_epoch_end(self, context):
|
||||
pass
|
||||
|
||||
def on_batch_begin(self, context):
|
||||
pass
|
||||
|
||||
def on_batch_end(self, context):
|
||||
pass
|
||||
|
||||
def on_compress_end(self, context):
|
||||
pass
|
@ -0,0 +1,28 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.3
|
||||
'conv1_2.w': 0.4
|
||||
'*': 0.9
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
||||
strategies:
|
||||
strategy_1:
|
||||
class: 'SensitivePruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
start_epoch: 0
|
||||
end_epoch: 10
|
||||
delta_rate: 0.20
|
||||
acc_loss_threshold: 0.2
|
||||
sensitivities:
|
||||
'conv1_1.w': 0.4
|
||||
|
||||
compress_pass:
|
||||
class: 'CompressPass'
|
||||
epoch: 100
|
||||
strategies:
|
||||
- strategy_1
|
@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle
|
||||
import os
|
||||
import sys
|
||||
from paddle.fluid.contrib.slim import CompressPass
|
||||
from paddle.fluid.contrib.slim import build_compressor
|
||||
from paddle.fluid.contrib.slim import ImitationGraph
|
||||
|
||||
|
||||
class LinearModel(object):
|
||||
def __init__(slef):
|
||||
pass
|
||||
|
||||
def train(self):
|
||||
train_program = fluid.Program()
|
||||
startup_program = fluid.Program()
|
||||
startup_program.random_seed = 10
|
||||
with fluid.program_guard(train_program, startup_program):
|
||||
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
|
||||
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
|
||||
predict = fluid.layers.fc(input=x, size=1, act=None)
|
||||
cost = fluid.layers.square_error_cost(input=predict, label=y)
|
||||
avg_cost = fluid.layers.mean(cost)
|
||||
eval_program = train_program.clone()
|
||||
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
|
||||
sgd_optimizer.minimize(avg_cost)
|
||||
|
||||
train_reader = paddle.batch(
|
||||
paddle.dataset.uci_housing.train(), batch_size=1)
|
||||
eval_reader = paddle.batch(
|
||||
paddle.dataset.uci_housing.test(), batch_size=1)
|
||||
place = fluid.CPUPlace()
|
||||
train_feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
|
||||
eval_feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_program)
|
||||
train_metrics = {"loss": avg_cost.name}
|
||||
eval_metrics = {"loss": avg_cost.name}
|
||||
|
||||
graph = ImitationGraph(train_program)
|
||||
config = './config.yaml'
|
||||
comp_pass = build_compressor(
|
||||
place,
|
||||
data_reader=train_reader,
|
||||
data_feeder=train_feeder,
|
||||
scope=fluid.global_scope(),
|
||||
metrics=train_metrics,
|
||||
epoch=1,
|
||||
config=config)
|
||||
comp_pass.apply(graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = LinearModel()
|
||||
model.train()
|
@ -0,0 +1,23 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import executor
|
||||
from .executor import *
|
||||
from . import graph
|
||||
from .graph import *
|
||||
from . import graph_pass
|
||||
from .graph_pass import *
|
||||
__all__ = executor.__all__
|
||||
__all__ += graph.__all__
|
||||
__all__ += graph_pass.__all__
|
@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from abc import abstractmethod
|
||||
from .... import executor
|
||||
from .graph import IRGraph, ImitationGraph
|
||||
|
||||
__all__ = ['get_executor']
|
||||
|
||||
|
||||
class GraphExecutor(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, place):
|
||||
self.place = place
|
||||
|
||||
@abstractmethod
|
||||
def run(self, graph, feches=None, feed=None):
|
||||
pass
|
||||
|
||||
|
||||
class IRGraphExecutor(GraphExecutor):
|
||||
def run(self, grah, fetches, feed=None):
|
||||
pass
|
||||
|
||||
|
||||
class ImitationGraphExecutor(GraphExecutor):
|
||||
def __init__(self, place):
|
||||
super(ImitationGraphExecutor, self).__init__(place)
|
||||
self.exe = executor.Executor(place)
|
||||
|
||||
def run(self, graph, scope=None, fetches=None, feed=None):
|
||||
assert isinstance(graph, ImitationGraph)
|
||||
fetch_list = None
|
||||
if fetches:
|
||||
fetch_list = [
|
||||
graph.program.global_block().var(name) for name in fetches
|
||||
]
|
||||
results = self.exe.run(graph.program,
|
||||
scope=scope,
|
||||
fetch_list=fetch_list,
|
||||
feed=feed)
|
||||
return results
|
||||
|
||||
|
||||
def get_executor(graph, place):
|
||||
if isinstance(graph, ImitationGraph):
|
||||
return ImitationGraphExecutor(place)
|
||||
if isinstance(graph, IRGraph):
|
||||
return IRGraphExecutor(place)
|
@ -0,0 +1,45 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ....framework import Program
|
||||
|
||||
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
|
||||
|
||||
|
||||
class Graph(object):
|
||||
"""
|
||||
Base class for all graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def all_parameters(self):
|
||||
"""
|
||||
Return all the parameters in current graph.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ImitationGraph(Graph):
|
||||
def __init__(self, program=None):
|
||||
super(ImitationGraph, self).__init__()
|
||||
self.program = Program() if program is None else program
|
||||
|
||||
def all_parameters(self):
|
||||
return self.program.global_block().all_parameters()
|
||||
|
||||
|
||||
class IRGraph(Graph):
|
||||
pass
|
@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['GraphPass', 'PruneParameterPass']
|
||||
|
||||
|
||||
class GraphPass(object):
|
||||
"""
|
||||
Base class for all graph pass.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def apply(self, graph):
|
||||
pass
|
||||
|
||||
|
||||
class PruneParameterPass(GraphPass):
|
||||
"""
|
||||
Generate a graph for pruning parameters from target graph.
|
||||
"""
|
||||
|
||||
def __init__(self, pruned_params, thresholds):
|
||||
super(PruneParameterPass, self).__init__()
|
||||
self.pruned_params = pruned_params
|
||||
self.thresholds = thresholds
|
||||
self.default_threshold = thresholds['*']
|
||||
|
||||
def apply(self, graph):
|
||||
pass
|
@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import pruner
|
||||
from .pruner import *
|
||||
from . import prune_strategy
|
||||
from .prune_strategy import *
|
||||
|
||||
__all__ = pruner.__all__
|
||||
__all__ += prune_strategy.__all__
|
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..core.strategy import Strategy
|
||||
from ....framework import Program, program_guard
|
||||
from .... import layers
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['SensitivePruneStrategy', 'PruneStrategy']
|
||||
|
||||
|
||||
class SensitivePruneStrategy(Strategy):
|
||||
def __init__(self,
|
||||
pruner=None,
|
||||
start_epoch=0,
|
||||
end_epoch=10,
|
||||
delta_rate=0.20,
|
||||
acc_loss_threshold=0.2,
|
||||
sensitivities=None):
|
||||
super(SensitivePruneStrategy, self).__init__(start_epoch, end_epoch)
|
||||
self.pruner = pruner
|
||||
self.delta_rate = delta_rate
|
||||
self.acc_loss_threshold = acc_loss_threshold
|
||||
self.sensitivities = sensitivities
|
||||
|
||||
|
||||
class PruneStrategy(Strategy):
|
||||
"""
|
||||
The strategy that pruning weights by threshold or ratio iteratively.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pruner,
|
||||
mini_batch_pruning_frequency=1,
|
||||
start_epoch=0,
|
||||
end_epoch=10):
|
||||
super(PruneStrategy, self).__init__(start_epoch, end_epoch)
|
||||
self.pruner = pruner
|
||||
self.mini_batch_pruning_frequency = mini_batch_pruning_frequency
|
||||
|
||||
def _triger(self, context):
|
||||
return (context.batch_id % self.mini_batch_pruning_frequency == 0 and
|
||||
self.start_epoch <= context.epoch_id < self.end_epoch)
|
||||
|
||||
def on_batch_end(self, context):
|
||||
if self._triger(context):
|
||||
prune_program = Program()
|
||||
with program_guard(prune_program):
|
||||
for param in context.graph.all_parameters():
|
||||
prune_program.global_block().clone_variable(param)
|
||||
p = prune_program.global_block().var(param.name)
|
||||
zeros_mask = self.pruner.prune(p)
|
||||
pruned_param = p * zeros_mask
|
||||
layers.assign(input=pruned_param, output=param)
|
||||
context.program_exe.run(prune_program, scope=context.scope)
|
@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from .... import layers
|
||||
|
||||
__all__ = ['Pruner', 'MagnitudePruner', 'RatioPruner']
|
||||
|
||||
|
||||
class Pruner(object):
|
||||
"""
|
||||
Base class of all pruners.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def prune(self, param):
|
||||
pass
|
||||
|
||||
|
||||
class MagnitudePruner(Pruner):
|
||||
"""
|
||||
Pruner used to pruning a parameter by threshold.
|
||||
"""
|
||||
|
||||
def __init__(self, threshold):
|
||||
self.threshold = threshold
|
||||
|
||||
def prune(self, param, threshold=None):
|
||||
if threshold is None:
|
||||
thres = layers.fill_constant(
|
||||
shape=[1], dtype='float32', value=self.threshold)
|
||||
else:
|
||||
thres = threshold
|
||||
zeros_mask = layers.less_than(x=param, y=thres)
|
||||
return zeros_mask
|
||||
|
||||
|
||||
class RatioPruner(Pruner):
|
||||
"""
|
||||
Pruner used to pruning a parameter by ratio.
|
||||
"""
|
||||
|
||||
def __init__(self, ratios=None):
|
||||
"""
|
||||
Args:
|
||||
ratios: dict with pair (paramer_name, pruned_ratio).
|
||||
"""
|
||||
self.ratios = ratios
|
||||
|
||||
def prune(self, param, ratio=None):
|
||||
"""
|
||||
Args:
|
||||
ratio: `ratio=40%` means pruning (1 - 40%) weights to zero.
|
||||
"""
|
||||
if ratio is None:
|
||||
rat = self.ratios[
|
||||
param.name] if param.name in self.ratios else self.ratios['*']
|
||||
else:
|
||||
rat = ratio
|
||||
if rat < 1.0:
|
||||
k = max(int(rat * np.prod(param.shape)), 1)
|
||||
param_vec = layers.reshape(x=param, shape=[1, -1])
|
||||
param_topk, _ = layers.topk(param_vec, k=k)
|
||||
threshold = layers.slice(
|
||||
param_topk, axes=[1], starts=[-1], ends=[k])
|
||||
threshold = layers.reshape(x=threshold, shape=[1])
|
||||
zeros_mask = layers.less_than(x=param, y=threshold)
|
||||
else:
|
||||
zeros_mask = layers.ones(param.shape)
|
||||
return zeros_mask
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,29 @@
|
||||
version: 1.0
|
||||
include: ["./unitest/configs/pruners.yaml", "./unitest/configs/pruners_0.yaml"]
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.3
|
||||
'conv1_2.w': 0.4
|
||||
'*': 0.9
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
||||
strategies:
|
||||
strategy_1:
|
||||
class: 'SensitivePruneStrategy'
|
||||
pruner: 'pruner_2'
|
||||
start_epoch: 0
|
||||
end_epoch: 10
|
||||
delta_rate: 0.20
|
||||
acc_loss_threshold: 0.2
|
||||
sensitivities:
|
||||
'conv1_1.w': 0.4
|
||||
|
||||
compress_pass:
|
||||
class: 'CompressPass'
|
||||
epoch: 100
|
||||
strategies:
|
||||
- strategy_1
|
@ -0,0 +1,12 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_2:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.5
|
||||
'conv1_2.w': 0.2
|
||||
'*': 0.7
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
@ -0,0 +1,12 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_3:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.5
|
||||
'conv1_2.w': 0.2
|
||||
'*': 0.7
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle.fluid.contrib.slim import ConfigFactory
|
||||
import unittest
|
||||
|
||||
|
||||
class TestFactory(unittest.TestCase):
|
||||
def test_parse(self):
|
||||
factory = ConfigFactory('./unitest/configs/config.yaml')
|
||||
|
||||
pruner = factory.instance('pruner_1')
|
||||
self.assertEquals(pruner.ratios['conv1_1.w'], 0.3)
|
||||
|
||||
pruner = factory.instance('pruner_2')
|
||||
self.assertEquals(pruner.ratios['*'], 0.7)
|
||||
|
||||
strategy = factory.instance('strategy_1')
|
||||
pruner = strategy.pruner
|
||||
self.assertEquals(pruner.ratios['*'], 0.7)
|
||||
|
||||
compress_pass = factory.get_compress_pass()
|
||||
self.assertEquals(compress_pass.epoch, 100)
|
||||
|
||||
strategy = compress_pass.strategies[0]
|
||||
self.assertEquals(strategy.delta_rate, 0.2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue