Remove slim from paddle framework (#25666)

* Remove slim from paddle framework
test=develop

Co-authored-by: wanghaoshuang <wanghaoshuang@baidu.com>
fix_copy_if_different
Bai Yifan 5 years ago committed by GitHub
parent bca303165a
commit 2131559d08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,7 +25,6 @@ from .quantize import *
from . import reader
from .reader import *
from . import slim
from .slim import *
from . import utils
from .utils import *
from . import extend_optimizer
@ -43,7 +42,6 @@ __all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__
__all__ += quantize.__all__
__all__ += reader.__all__
__all__ += slim.__all__
__all__ += utils.__all__
__all__ += extend_optimizer.__all__
__all__ += ['mixed_precision']

@ -11,6 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .core import *
__all__ = ['Compressor', ]

@ -1,22 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import config
from .config import *
from . import compressor
from .compressor import *
from . import strategy
from .strategy import *
__all__ = config.__all__ + compressor.__all__ + strategy.__all__

File diff suppressed because it is too large Load Diff

@ -1,130 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import funcsigs
import yaml
from collections import OrderedDict
from ..prune import *
from ..quantization import *
from .strategy import *
from ..distillation import *
from ..searcher import *
from ..nas import *
__all__ = ['ConfigFactory']
"""This factory is used to create instances by loading and parsing configure file with yaml format.
"""
PLUGINS = ['pruners', 'quantizers', 'distillers', 'strategies', 'controllers']
class ConfigFactory(object):
def __init__(self, config):
"""Init a factory from configure file."""
self.instances = {}
self.compressor = {}
self.version = None
self._parse_config(config)
def instance(self, name):
"""
Get instance from factory.
"""
if name in self.instances:
return self.instances[name]
else:
return None
def _new_instance(self, name, attrs):
if name not in self.instances:
class_ = globals()[attrs['class']]
sig = funcsigs.signature(class_.__init__)
keys = [
param.name for param in sig.parameters.values()
if (param.kind == param.POSITIONAL_OR_KEYWORD)
][1:]
keys = set(attrs.keys()).intersection(set(keys))
args = {}
for key in keys:
value = attrs[key]
if isinstance(value, str) and value.lower() == 'none':
value = None
if isinstance(value, str) and value in self.instances:
value = self.instances[value]
if isinstance(value, list):
for i in range(len(value)):
if isinstance(value[i],
str) and value[i] in self.instances:
value[i] = self.instances[value[i]]
args[key] = value
self.instances[name] = class_(**args)
return self.instances.get(name)
def _parse_config(self, config):
assert config
with open(config, 'r') as config_file:
key_values = self._ordered_load(config_file)
for key in key_values:
# parse version
if key == 'version' and self.version is None:
self.version = int(key_values['version'])
assert self.version == int(key_values['version'])
# parse pruners
if key in PLUGINS:
instances = key_values[key]
for name in instances:
self._new_instance(name, instances[name])
if key == 'compressor':
self.compressor['strategies'] = []
self.compressor['epoch'] = key_values[key]['epoch']
if 'init_model' in key_values[key]:
self.compressor['init_model'] = key_values[key][
'init_model']
if 'checkpoint_path' in key_values[key]:
self.compressor['checkpoint_path'] = key_values[key][
'checkpoint_path']
if 'eval_epoch' in key_values[key]:
self.compressor['eval_epoch'] = key_values[key][
'eval_epoch']
if 'strategies' in key_values[key]:
for name in key_values[key]['strategies']:
strategy = self.instance(name)
self.compressor['strategies'].append(strategy)
if key == 'include':
for config_file in key_values[key]:
self._parse_config(config_file.strip())
def _ordered_load(self,
stream,
Loader=yaml.Loader,
object_pairs_hook=OrderedDict):
"""
See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts
"""
class OrderedLoader(Loader):
pass
def construct_mapping(loader, node):
loader.flatten_mapping(node)
return object_pairs_hook(loader.construct_pairs(node))
OrderedLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
return yaml.load(stream, OrderedLoader)

@ -1,58 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['Strategy']
class Strategy(object):
"""
Base class for all strategies.
"""
def __init__(self, start_epoch=0, end_epoch=0):
"""
Args:
start_epoch: The first epoch to apply the strategy.
end_epoch: The last epoch to apply the strategy.
"""
self.start_epoch = start_epoch
self.end_epoch = end_epoch
def __getstate__(self):
d = {}
for key in self.__dict__:
if key not in ["start_epoch", "end_epoch"]:
d[key] = self.__dict__[key]
return d
def on_compression_begin(self, context):
pass
def on_epoch_begin(self, context):
pass
def on_epoch_end(self, context):
pass
def on_batch_begin(self, context):
pass
def on_batch_end(self, context):
pass
def on_compression_end(self, context):
pass
def restore_from_checkpoint(self, context):
pass

@ -1,21 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import distiller
from .distiller import *
from . import distillation_strategy
from .distillation_strategy import *
__all__ = distiller.__all__
__all__ += distillation_strategy.__all__

@ -1,104 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core.strategy import Strategy
from ....framework import Program, Variable, program_guard
from ....log_helper import get_logger
from .... import Executor
import logging
__all__ = ['DistillationStrategy']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class DistillationStrategy(Strategy):
def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
"""
Args:
distillers(list): A list of distiller used to combine student graph and teacher graph
by adding some loss.
start_epoch(int): The epoch when to merge student graph and teacher graph for
distillation training. default: 0
end_epoch(int): The epoch when to finish distillation training. default: 0
"""
super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
self.distillers = distillers
def restore_from_checkpoint(self, context):
# load from checkpoint
if context.epoch_id > 0:
if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch:
_logger.info('Restore DistillationStrategy')
self._create_distillation_graph(context)
_logger.info('Restore DistillationStrategy finish.')
def on_epoch_begin(self, context):
if self.start_epoch == context.epoch_id:
_logger.info('DistillationStrategy::on_epoch_begin.')
self._create_distillation_graph(context)
_logger.info('DistillationStrategy set optimize_graph.')
def _create_distillation_graph(self, context):
"""
step 1: Merge student graph and teacher graph into distillation graph.
step 2: Add loss into distillation graph by distillers.
step 3: Append backward ops and optimize ops into distillation graph for training.
"""
# step 1
teacher = context.teacher_graphs[0]
for var in teacher.program.list_vars():
var.stop_gradient = True
graph = context.train_graph.clone()
graph.merge(teacher)
if 'loss' in graph.out_nodes:
graph.out_nodes['student_loss'] = graph.out_nodes['loss']
# step 2
for distiller in self.distillers:
graph = distiller.distiller_loss(graph)
# step 3
startup_program = Program()
with program_guard(graph.program, startup_program):
context.distiller_optimizer._name = 'distillation_optimizer'
# The learning rate variable may be created in other program.
# Update information in optimizer to make
# learning rate variable being accessible in current program.
optimizer = context.distiller_optimizer
if isinstance(optimizer._learning_rate, Variable):
optimizer._learning_rate_map[
graph.program] = optimizer._learning_rate
optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)
exe = Executor(context.place)
exe.run(startup_program, scope=context.scope)
# backup graph for fine-tune after distillation
context.put('distillation_backup_optimize_graph',
context.optimize_graph)
context.optimize_graph = graph
def on_epoch_end(self, context):
if context.epoch_id == (self.end_epoch - 1):
_logger.info('DistillationStrategy::on_epoch_end.')
# restore optimize_graph for fine-tune or other strategy in next stage.
context.optimize_graph = context.get(
'distillation_backup_optimize_graph')
_logger.info(
'DistillationStrategy set context.optimize_graph to None.')

File diff suppressed because it is too large Load Diff

@ -1,20 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import executor
from .executor import *
from . import graph_wrapper
from .graph_wrapper import *
__all__ = executor.__all__
__all__ += graph_wrapper.__all__

@ -1,61 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ....compiler import CompiledProgram
from ....data_feeder import DataFeeder
from .... import executor
from .graph_wrapper import GraphWrapper
__all__ = ['SlimGraphExecutor']
class SlimGraphExecutor(object):
"""
Wrapper of executor used to run GraphWrapper.
"""
def __init__(self, place):
self.exe = executor.Executor(place)
self.place = place
def run(self, graph, scope, data=None):
"""
Runing a graph with a batch of data.
Args:
graph(GraphWrapper): The graph to be executed.
scope(fluid.core.Scope): The scope to be used.
data(list<tuple>): A batch of data. Each tuple in this list is a sample.
It will feed the items of tuple to the in_nodes of graph.
Returns:
results(list): A list of result with the same order indicated by graph.out_nodes.
"""
assert isinstance(graph, GraphWrapper)
feed = None
if data is not None and isinstance(data[0], dict):
# return list = False
feed = data
elif data is not None:
feeder = DataFeeder(
feed_list=list(graph.in_nodes.values()),
place=self.place,
program=graph.program)
feed = feeder.feed(data)
fetch_list = list(graph.out_nodes.values())
program = graph.compiled_graph if graph.compiled_graph else graph.program
results = self.exe.run(program,
scope=scope,
fetch_list=fetch_list,
feed=feed)
return results

File diff suppressed because it is too large Load Diff

@ -1,29 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import light_nas_strategy
from .light_nas_strategy import *
from . import controller_server
from .controller_server import *
from . import search_agent
from .search_agent import *
from . import search_space
from .search_space import *
from . import lock
from .lock import *
__all__ = light_nas_strategy.__all__
__all__ += controller_server.__all__
__all__ += search_agent.__all__
__all__ += search_space.__all__

@ -1,107 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from threading import Thread
from ....log_helper import get_logger
__all__ = ['ControllerServer']
_logger = get_logger(
__name__,
logging.INFO,
fmt='ControllerServer-%(asctime)s-%(levelname)s: %(message)s')
class ControllerServer(object):
"""
The controller wrapper with a socket server to handle the request of search agent.
"""
def __init__(self,
controller=None,
address=('', 0),
max_client_num=100,
search_steps=None,
key=None):
"""
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int): The total steps of searching. None means never stopping. Default: None
"""
self._controller = controller
self._address = address
self._max_client_num = max_client_num
self._search_steps = search_steps
self._closed = False
self._port = address[1]
self._ip = address[0]
self._key = key
def start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0]
_logger.info("listen on: [{}:{}]".format(self._ip, self._port))
thread = Thread(target=self.run)
thread.start()
return str(thread)
def close(self):
"""Close the server."""
self._closed = True
def port(self):
"""Get the port."""
return self._port
def ip(self):
"""Get the ip."""
return self._ip
def run(self):
_logger.info("Controller Server run...")
while ((self._search_steps is None) or
(self._controller._iter <
(self._search_steps))) and not self._closed:
conn, addr = self._socket_server.accept()
message = conn.recv(1024).decode()
if message.strip("\n") == "next_tokens":
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
else:
_logger.info("recv message from {}: [{}]".format(addr, message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
_logger.info("recv noise from {}: [{}]".format(addr,
message))
continue
tokens = messages[1]
reward = messages[2]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
_logger.info("send message to {}: [{}]".format(addr, tokens))
conn.close()
self._socket_server.close()
_logger.info("server closed!")

@ -1,196 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core.strategy import Strategy
from ..graph import GraphWrapper
from .controller_server import ControllerServer
from .search_agent import SearchAgent
from ....executor import Executor
from ....log_helper import get_logger
import re
import logging
import functools
import socket
from .lock import lock, unlock
__all__ = ['LightNASStrategy']
_logger = get_logger(
__name__,
logging.INFO,
fmt='LightNASStrategy-%(asctime)s-%(levelname)s: %(message)s')
class LightNASStrategy(Strategy):
"""
Light-NAS search strategy.
"""
def __init__(self,
controller=None,
end_epoch=1000,
target_flops=629145600,
target_latency=0,
retrain_epoch=1,
metric_name='top1_acc',
server_ip=None,
server_port=0,
is_server=False,
max_client_num=100,
search_steps=None,
key="light-nas"):
"""
Args:
controller(searcher.Controller): The searching controller. Default: None.
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 0
target_flops(int): The constraint of FLOPS.
target_latency(float): The constraint of latency.
retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1.
metric_name(str): The metric used to evaluate the model.
It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
is_server(bool): Whether current host is controller server. Default: False.
max_client_num(int): The maximum number of clients that connect to controller server concurrently. Default: 100.
search_steps(int): The total steps of searching. Default: None.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
"""
self.start_epoch = 0
self.end_epoch = end_epoch
self._max_flops = target_flops
self._max_latency = target_latency
self._metric_name = metric_name
self._controller = controller
self._retrain_epoch = 0
self._server_ip = server_ip
self._server_port = server_port
self._is_server = is_server
self._retrain_epoch = retrain_epoch
self._search_steps = search_steps
self._max_client_num = max_client_num
self._max_try_times = 100
self._key = key
if self._server_ip is None:
self._server_ip = self._get_host_ip()
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def on_compression_begin(self, context):
self._current_tokens = context.search_space.init_tokens()
self._controller.reset(context.search_space.range_table(),
self._current_tokens, None)
# create controller server
if self._is_server:
open("./slim_LightNASStrategy_controller_server.socket",
'a').close()
socket_file = open(
"./slim_LightNASStrategy_controller_server.socket", 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
self._server = ControllerServer(
controller=self._controller,
address=(self._server_ip, self._server_port),
max_client_num=self._max_client_num,
search_steps=self._search_steps,
key=self._key)
tid = self._server.start()
self._server_port = self._server.port()
socket_file.write(tid)
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
_logger.info("self._server_ip: {}; self._server_port: {}".format(
self._server_ip, self._server_port))
# create client
self._search_agent = SearchAgent(
self._server_ip, self._server_port, key=self._key)
def __getstate__(self):
"""Socket can't be pickled."""
d = {}
for key in self.__dict__:
if key not in ["_search_agent", "_server"]:
d[key] = self.__dict__[key]
return d
def on_epoch_begin(self, context):
if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
self._retrain_epoch == 0 or
(context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
_logger.info("light nas strategy on_epoch_begin")
min_flops = -1
for _ in range(self._max_try_times):
startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
self._current_tokens)
context.eval_graph.program = test_p
flops = context.eval_graph.flops()
if min_flops == -1:
min_flops = flops
min_tokens = self._current_tokens[:]
else:
if flops < min_flops:
min_tokens = self._current_tokens[:]
if self._max_latency > 0:
latency = context.search_space.get_model_latency(test_p)
_logger.info("try [{}] with latency {} flops {}".format(
self._current_tokens, latency, flops))
else:
_logger.info("try [{}] with flops {}".format(
self._current_tokens, flops))
if flops > self._max_flops or (self._max_latency > 0 and
latency > self._max_latency):
self._current_tokens = self._controller.next_tokens(
min_tokens)
else:
break
context.train_reader = train_reader
context.eval_reader = test_reader
exe = Executor(context.place)
exe.run(startup_p)
context.optimize_graph.program = train_p
context.optimize_graph.compile()
context.skip_training = (self._retrain_epoch == 0)
def on_epoch_end(self, context):
if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and (
self._retrain_epoch == 0 or
(context.epoch_id - self.start_epoch + 1
) % self._retrain_epoch == 0):
self._current_reward = context.eval_results[self._metric_name][-1]
flops = context.eval_graph.flops()
if flops > self._max_flops:
self._current_reward = 0.0
if self._max_latency > 0:
test_p = context.search_space.create_net(self._current_tokens)[
2]
latency = context.search_space.get_model_latency(test_p)
if latency > self._max_latency:
self._current_reward = 0.0
_logger.info("reward: {}; latency: {}; flops: {}; tokens: {}".
format(self._current_reward, latency, flops,
self._current_tokens))
else:
_logger.info("reward: {}; flops: {}; tokens: {}".format(
self._current_reward, flops, self._current_tokens))
self._current_tokens = self._search_agent.update(
self._current_tokens, self._current_reward)

@ -1,36 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__All__ = ['lock', 'unlock']
if os.name == 'nt':
def lock(file):
raise NotImplementedError('Windows is not supported.')
def unlock(file):
raise NotImplementedError('Windows is not supported.')
elif os.name == 'posix':
from fcntl import flock, LOCK_EX, LOCK_UN
def lock(file):
"""Lock the file in local file system."""
flock(file.fileno(), LOCK_EX)
def unlock(file):
"""Unlock the file in local file system."""
flock(file.fileno(), LOCK_UN)
else:
raise RuntimeError("File Locker only support NT and Posix platforms!")

@ -1,67 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from ....log_helper import get_logger
__all__ = ['SearchAgent']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class SearchAgent(object):
"""
Search agent.
"""
def __init__(self, server_ip=None, server_port=None, key=None):
"""
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
"""
self.server_ip = server_ip
self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
def next_tokens(self):
"""
Get next tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
socket_client.send("next_tokens".encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens

@ -1,52 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The search space used to search neural architecture"""
__all__ = ['SearchSpace']
class SearchSpace(object):
"""Controller for Neural Architecture Search.
"""
def __init__(self, *args, **kwargs):
pass
def init_tokens(self):
"""Get init tokens in search space.
"""
raise NotImplementedError('Abstract method.')
def range_table(self):
"""Get range table of current search space.
"""
raise NotImplementedError('Abstract method.')
def create_net(self, tokens):
"""Create networks for training and evaluation according to tokens.
Args:
tokens(list<int>): The tokens which represent a network.
Return:
(tuple): startup_program, train_program, evaluation_program, train_metrics, test_metrics
"""
raise NotImplementedError('Abstract method.')
def get_model_latency(self, program):
"""Get model latency according to program.
Args:
program(Program): The program to get latency.
Return:
(float): model latency.
"""
raise NotImplementedError('Abstract method.')

@ -1,24 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import pruner
from .pruner import *
from . import prune_strategy
from .prune_strategy import *
from . import auto_prune_strategy
from .auto_prune_strategy import *
__all__ = pruner.__all__
__all__ += prune_strategy.__all__
__all__ += auto_prune_strategy.__all__

File diff suppressed because it is too large Load Diff

@ -1,107 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import collections
from .... import layers
__all__ = ['Pruner', 'StructurePruner']
class Pruner(object):
"""
Base class of all pruners.
"""
def __init__(self):
pass
def prune(self, param):
pass
class StructurePruner(Pruner):
"""
Pruner used to pruning parameters by groups.
"""
def __init__(self, pruning_axis, criterions):
"""
Args:
pruning_axis(dict): The key is the name of parameter to be pruned,
'*' means all the parameters.
The value is the axis to be used. Given a parameter
with shape [3, 4], the result of pruning 50% on axis 1
is a parameter with shape [3, 2].
criterions(dict): The key is the name of parameter to be pruned,
'*' means all the parameters.
The value is the criterion used to sort groups for pruning.
It only supports 'l1_norm' currently.
"""
self.pruning_axis = pruning_axis
self.criterions = criterions
def cal_pruned_idx(self, name, param, ratio, axis=None):
"""
Calculate the index to be pruned on axis by given pruning ratio.
Args:
name(str): The name of parameter to be pruned.
param(np.array): The data of parameter to be pruned.
ratio(float): The ratio to be pruned.
axis(int): The axis to be used for pruning given parameter.
If it is None, the value in self.pruning_axis will be used.
default: None.
Returns:
list<int>: The indexes to be pruned on axis.
"""
criterion = self.criterions[
name] if name in self.criterions else self.criterions['*']
if axis is None:
assert self.pruning_axis is not None, "pruning_axis should set if axis is None."
axis = self.pruning_axis[
name] if name in self.pruning_axis else self.pruning_axis['*']
prune_num = int(round(param.shape[axis] * ratio))
reduce_dims = [i for i in range(len(param.shape)) if i != axis]
if criterion == 'l1_norm':
criterions = np.sum(np.abs(param), axis=tuple(reduce_dims))
pruned_idx = criterions.argsort()[:prune_num]
return pruned_idx
def prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
"""
Pruning a array by indexes on given axis.
Args:
tensor(numpy.array): The target array to be pruned.
pruned_idx(list<int>): The indexes to be pruned.
pruned_axis(int): The axis of given array to be pruned on.
lazy(bool): True means setting the pruned elements to zero.
False means remove the pruned elements from memory.
default: False.
Returns:
numpy.array: The pruned array.
"""
mask = np.zeros(tensor.shape[pruned_axis], dtype=bool)
mask[pruned_idx] = True
def func(data):
return data[~mask]
def lazy_func(data):
data[mask] = 0
return data
if lazy:
return np.apply_along_axis(lazy_func, pruned_axis, tensor)
else:
return np.apply_along_axis(func, pruned_axis, tensor)

@ -16,10 +16,6 @@ from __future__ import print_function
from . import quantization_pass
from .quantization_pass import *
from . import quantization_strategy
from .quantization_strategy import *
from . import mkldnn_post_training_strategy
from .mkldnn_post_training_strategy import *
from . import quant_int8_mkldnn_pass
from .quant_int8_mkldnn_pass import *
from . import quant2_int8_mkldnn_pass
@ -29,8 +25,7 @@ from .post_training_quantization import *
from . import imperative
from .imperative import *
__all__ = quantization_pass.__all__ + quantization_strategy.__all__
__all__ += mkldnn_post_training_strategy.__all__
__all__ = quantization_pass.__all__
__all__ += quant_int8_mkldnn_pass.__all__
__all__ += quant2_int8_mkldnn_pass.__all__
__all__ += post_training_quantization.__all__

@ -1,113 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import six
import numpy as np
from .... import core
from ..core.strategy import Strategy
from ....log_helper import get_logger
__all__ = ['MKLDNNPostTrainingQuantStrategy']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class MKLDNNPostTrainingQuantStrategy(Strategy):
"""
The strategy for MKL-DNN Post Training quantization strategy.
"""
def __init__(self,
int8_model_save_path=None,
fp32_model_path=None,
cpu_math_library_num_threads=1):
"""
Args:
int8_model_save_path(str): int8_model_save_path is used to save an int8 ProgramDesc
with fp32 weights which is used for MKL-DNN int8 inference. For post training quantization,
MKLDNNPostTrainingQuantStrategy only supports converting a fp32 ProgramDesc
with fp32 weights to an int8 ProgramDesc with fp32 weights now. The saved
int8 ProgramDesc with fp32 weights only can be executed with MKL-DNN enabled.
None means it doesn't save int8 ProgramDesc with fp32 weights. default: None.
fp32_model_path(str): fp32_model_path is used to load an original fp32 ProgramDesc with fp32 weights.
None means it doesn't have a fp32 ProgramDesc with fp32 weights. default: None.
cpu_math_library_num_threads(int): The number of cpu math library threads which is used on
MKLDNNPostTrainingQuantStrategy. 1 means it only uses one cpu math library
thread. default: 1
"""
super(MKLDNNPostTrainingQuantStrategy, self).__init__(0, 0)
self.int8_model_save_path = int8_model_save_path
if fp32_model_path is None:
raise Exception("fp32_model_path is None")
self.fp32_model_path = fp32_model_path
self.cpu_math_library_num_threads = cpu_math_library_num_threads
def on_compression_begin(self, context):
"""
Prepare the data and quantify the model
"""
super(MKLDNNPostTrainingQuantStrategy,
self).on_compression_begin(context)
_logger.info('InferQuantStrategy::on_compression_begin')
# Prepare the Analysis Config
infer_config = core.AnalysisConfig("AnalysisConfig")
infer_config.switch_ir_optim(True)
infer_config.disable_gpu()
infer_config.set_model(self.fp32_model_path)
infer_config.enable_mkldnn()
infer_config.set_cpu_math_library_num_threads(
self.cpu_math_library_num_threads)
# Prepare the data for calculating the quantization scales
warmup_reader = context.eval_reader()
if six.PY2:
data = warmup_reader.next()
if six.PY3:
data = warmup_reader.__next__()
num_images = len(data)
image_data = [img.tolist() for (img, _) in data]
image_data = np.array(image_data).astype("float32").reshape(
[num_images, ] + list(data[0][0].shape))
image_data = image_data.ravel()
images = core.PaddleTensor(image_data, "x")
images.shape = [num_images, ] + list(data[0][0].shape)
label_data = [label for (_, label) in data]
labels = core.PaddleTensor(
np.array(label_data).astype("int64").reshape([num_images, 1]), "y")
warmup_data = [images, labels]
# Enable the INT8 Quantization
infer_config.enable_quantizer()
infer_config.quantizer_config().set_quant_data(warmup_data)
infer_config.quantizer_config().set_quant_batch_size(num_images)
# Run INT8 MKL-DNN Quantization
predictor = core.create_paddle_predictor(infer_config)
if self.int8_model_save_path:
if not os.path.exists(self.int8_model_save_path):
os.makedirs(self.int8_model_save_path)
predictor.SaveOptimModel(self.int8_model_save_path)
_logger.info(
'Finish MKLDNNPostTrainingQuantStrategy::on_compresseion_begin')

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save