Add Ligth-NAS for PaddleSlim (#17679)
* Add auto pruning strategy. 1. Fix compressor. 2. Enhence graph executor. 3. Add SAController 4. Add auto pruning strategy. 5. Add unitest for auto pruning strategy. test=develop * Init light-nas * Add light nas. * Some fix. test=develop * Fix sa controller. test=develop * Fix unitest of light nas. test=develop * Fix setup.py.in and API.spec. test=develop * Fix unitest. 1. Fix unitest on windows. 2. Fix package importing in tests directory. * 1. Remove unused comments. 2. Expose eval_epoch option. 3. Remove unused function in search_agent. 4. Expose max_client_num to yaml file. 5. Move flops constraint to on_epoch_begin function test=develop * Fix light nas strategy. test=develop * Make controller server stable. test=develop * 1. Add try exception to compressor. 2. Remove unitest of light-nas for windows. test=develop * Add comments Enhence controller test=develop * Fix comments. test=developdependabot/pip/python/requests-2.20.0
parent
3925bd81e8
commit
5df65e506d
@ -0,0 +1,29 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import light_nas_strategy
|
||||
from .light_nas_strategy import *
|
||||
from . import controller_server
|
||||
from .controller_server import *
|
||||
from . import search_agent
|
||||
from .search_agent import *
|
||||
from . import search_space
|
||||
from .search_space import *
|
||||
from . import lock
|
||||
from .lock import *
|
||||
|
||||
# Public API of the package: the union of the submodules' export lists.
# Start from a fresh list. Binding ``__all__`` directly to
# ``light_nas_strategy.__all__`` and then using ``+=`` would extend that
# submodule's own list in place, polluting its exports.
__all__ = []
__all__ += light_nas_strategy.__all__
__all__ += controller_server.__all__
__all__ += search_agent.__all__
__all__ += search_space.__all__
|
@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import socket
|
||||
from threading import Thread
|
||||
|
||||
__all__ = ['ControllerServer']
|
||||
|
||||
logging.basicConfig(
|
||||
format='ControllerServer-%(asctime)s-%(levelname)s: %(message)s')
|
||||
_logger = logging.getLogger(__name__)
|
||||
_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class ControllerServer(object):
    """
    The controller wrapper with a socket server to handle the requests of
    search agents.

    Each accepted connection carries exactly one request. The request is
    either the literal text ``next_tokens`` or an update message of the form
    ``<key>\\t<tokens joined by ','>\\t<reward>``. The reply is the next
    tokens joined by ','.
    """

    def __init__(self,
                 controller=None,
                 address=('', 0),
                 max_client_num=100,
                 search_steps=None,
                 key=None):
        """
        Args:
            controller(slim.searcher.Controller): The controller used to generate tokens.
            address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
                            which means setting ip automatically
            max_client_num(int): The maximum number of clients connecting to current server simultaneously.
                                 It is passed to ``listen()``. Default: 100.
            search_steps(int): The total steps of searching. None means never stopping. Default: None
            key(str): The first field an update message must carry to be accepted.
                      Update messages with another key are logged as noise and ignored. Default: None.
        """
        self._controller = controller
        self._address = address
        self._max_client_num = max_client_num
        self._search_steps = search_steps
        self._closed = False
        self._port = address[1]
        self._ip = address[0]
        self._key = key

    def start(self):
        """Bind the listening socket and serve requests in a new thread.

        Returns:
            str: The string representation of the serving thread.
        """
        self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket_server.bind(self._address)
        self._socket_server.listen(self._max_client_num)
        # Read the address back: with port 0 the OS picks the actual port.
        self._port = self._socket_server.getsockname()[1]
        self._ip = self._socket_server.getsockname()[0]
        _logger.info("listen on: [{}:{}]".format(self._ip, self._port))
        thread = Thread(target=self.run)
        thread.start()
        return str(thread)

    def close(self):
        """Close the server.

        Only a flag is set; the serving loop checks it before the next
        ``accept()``.
        """
        self._closed = True

    def port(self):
        """Get the port."""
        return self._port

    def ip(self):
        """Get the ip."""
        return self._ip

    def run(self):
        """Serve requests until the search is finished or the server is closed.

        The loop ends when ``close()`` has been called or, if ``search_steps``
        is set, when the controller's ``_iter`` reaches it.
        """
        _logger.info("Controller Server run...")
        try:
            while ((self._search_steps is None) or
                   (self._controller._iter <
                    (self._search_steps))) and not self._closed:
                conn, addr = self._socket_server.accept()
                try:
                    self._handle_request(conn, addr)
                finally:
                    # Always release the connection, including for noise
                    # messages and when handling raises.
                    conn.close()
        finally:
            self._socket_server.close()
        _logger.info("server closed!")

    def _handle_request(self, conn, addr):
        """Read one request from ``conn`` and send the reply, if any."""
        message = conn.recv(1024).decode()
        if message.strip("\n") == "next_tokens":
            tokens = self._controller.next_tokens()
        else:
            _logger.info("recv message from {}: [{}]".format(addr, message))
            messages = message.strip('\n').split("\t")
            if (len(messages) < 3) or (messages[0] != self._key):
                _logger.info("recv noise from {}: [{}]".format(addr,
                                                               message))
                return
            tokens = [int(token) for token in messages[1].split(",")]
            self._controller.update(tokens, float(messages[2]))
            tokens = self._controller.next_tokens()
        reply = ",".join([str(token) for token in tokens])
        # sendall() keeps writing until the whole reply has been sent.
        conn.sendall(reply.encode())
        _logger.info("send message to {}: [{}]".format(addr, reply))
|
@ -0,0 +1,178 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..core.strategy import Strategy
|
||||
from ..graph import GraphWrapper
|
||||
from .controller_server import ControllerServer
|
||||
from .search_agent import SearchAgent
|
||||
from ....executor import Executor
|
||||
import re
|
||||
import logging
|
||||
import functools
|
||||
import socket
|
||||
from .lock import lock, unlock
|
||||
|
||||
__all__ = ['LightNASStrategy']
|
||||
|
||||
logging.basicConfig(
|
||||
format='LightNASStrategy-%(asctime)s-%(levelname)s: %(message)s')
|
||||
_logger = logging.getLogger(__name__)
|
||||
_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class LightNASStrategy(Strategy):
    """
    Light-NAS search strategy.

    At the beginning of a search step it builds the network described by the
    current tokens. At the end of the step it reports the evaluation metric
    to the controller server as the reward and receives the next tokens.
    """

    def __init__(self,
                 controller=None,
                 end_epoch=1000,
                 target_flops=629145600,
                 retrain_epoch=1,
                 metric_name='top1_acc',
                 server_ip=None,
                 server_port=0,
                 is_server=False,
                 max_client_num=100,
                 search_steps=None,
                 key="light-nas"):
        """
        Args:
            controller(searcher.Controller): The searching controller. Default: None.
            end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 1000
            target_flops(int): The constraint of FLOPS. Default: 629145600
            retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1.
            metric_name(str): The metric used to evaluate the model.
                         It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
            server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
            server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
            is_server(bool): Whether current host is controller server. Default: False.
            max_client_num(int): The maximum number of clients that connect to controller server concurrently. Default: 100.
            search_steps(int): The total steps of searching. Default: None.
            key(str): The key used to identify legal agent for controller server. Default: "light-nas"
        """
        self.start_epoch = 0
        self.end_epoch = end_epoch
        self._max_flops = target_flops
        self._metric_name = metric_name
        self._controller = controller
        self._server_ip = server_ip
        self._server_port = server_port
        self._is_server = is_server
        self._retrain_epoch = retrain_epoch
        self._search_steps = search_steps
        self._max_client_num = max_client_num
        # Maximum number of candidate networks tried per search step while
        # looking for one within the FLOPS budget.
        self._max_try_times = 100
        self._key = key

        if self._server_ip is None:
            self._server_ip = self._get_host_ip()

    def _get_host_ip(self):
        """Return the local address used to reach a public address.

        A UDP socket is connected to 8.8.8.8:80 and its local address is
        read back.
        """
        # Create the socket outside the try block. If creation fails there is
        # nothing to close, and the original error is not masked.
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            s.connect(('8.8.8.8', 80))
            return s.getsockname()[0]
        finally:
            s.close()

    def on_compression_begin(self, context):
        """Reset the controller, start the server if needed and create the agent.

        A lock file in the working directory makes sure that only the first
        process that gets the lock starts the controller server.
        """
        self._current_tokens = context.search_space.init_tokens()
        # No constraint function is given to the controller; the FLOPS
        # constraint is checked in on_epoch_begin.
        self._controller.reset(context.search_space.range_table(),
                               self._current_tokens, None)

        # create controller server
        if self._is_server:
            socket_path = "./slim_LightNASStrategy_controller_server.socket"
            # Make sure the file exists before opening it with 'r+'.
            open(socket_path, 'a').close()
            with open(socket_path, 'r+') as socket_file:
                lock(socket_file)
                try:
                    tid = socket_file.readline()
                    if tid == '':
                        _logger.info("start controller server...")
                        self._server = ControllerServer(
                            controller=self._controller,
                            address=(self._server_ip, self._server_port),
                            max_client_num=self._max_client_num,
                            search_steps=self._search_steps,
                            key=self._key)
                        tid = self._server.start()
                        self._server_port = self._server.port()
                        socket_file.write(tid)
                        # Flush before unlocking so that the next process
                        # that takes the lock sees the written id.
                        socket_file.flush()
                        _logger.info("started controller server...")
                finally:
                    unlock(socket_file)

        _logger.info("self._server_ip: {}; self._server_port: {}".format(
            self._server_ip, self._server_port))
        # create client
        self._search_agent = SearchAgent(
            self._server_ip, self._server_port, key=self._key)

    def _constrain_func(self, tokens, context=None):
        """Check whether the tokens meet constraint."""
        _, _, test_prog, _, _, _, _ = context.search_space.create_net(tokens)
        flops = GraphWrapper(test_prog).flops()
        return flops <= self._max_flops

    def on_epoch_begin(self, context):
        """Build the network of the current tokens at the start of a search step.

        New tokens are requested from the search agent until the network
        meets the FLOPS constraint or the maximum number of tries is reached.
        """
        if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
            _logger.info("light nas strategy on_epoch_begin")
            for _try in range(self._max_try_times):
                startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
                    self._current_tokens)
                _logger.info("try [{}]".format(self._current_tokens))
                context.eval_graph.program = test_p
                flops = context.eval_graph.flops()
                if flops <= self._max_flops:
                    break
                else:
                    self._current_tokens = self._search_agent.next_tokens()

            context.train_reader = train_reader
            context.eval_reader = test_reader

            exe = Executor(context.place)
            exe.run(startup_p)

            context.optimize_graph.program = train_p
            context.optimize_graph.compile()

            context.skip_training = (self._retrain_epoch == 0)

    def on_epoch_end(self, context):
        """Report the reward of the current tokens and get the next tokens.

        The reward is the latest value of the metric. It is set to 0.0 when
        the network exceeds the FLOPS constraint.
        """
        if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch + 1
             ) % self._retrain_epoch == 0):

            self._current_reward = context.eval_results[self._metric_name][-1]
            flops = context.eval_graph.flops()
            if flops > self._max_flops:
                self._current_reward = 0.0
            _logger.info("reward: {}; flops: {}; tokens: {}".format(
                self._current_reward, flops, self._current_tokens))
            self._current_tokens = self._search_agent.update(
                self._current_tokens, self._current_reward)
|
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
# The export list must be spelled ``__all__``; a differently-cased name is
# ignored by ``from module import *``.
__all__ = ['lock', 'unlock']

if os.name == 'nt':

    def lock(file):
        """File locking is not implemented on Windows; always raises."""
        raise NotImplementedError('Windows is not supported.')

    def unlock(file):
        """File unlocking is not implemented on Windows; always raises."""
        raise NotImplementedError('Windows is not supported.')

elif os.name == 'posix':
    from fcntl import flock, LOCK_EX, LOCK_UN

    def lock(file):
        """Lock the file in local file system.

        Args:
            file: An open file object; an exclusive lock is taken on its
                  file descriptor.
        """
        flock(file.fileno(), LOCK_EX)

    def unlock(file):
        """Unlock the file in local file system.

        Args:
            file: The open file object that was passed to ``lock``.
        """
        flock(file.fileno(), LOCK_UN)

else:
    raise RuntimeError("File Locker only support NT and Posix platforms!")
|
@ -0,0 +1,67 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import socket
|
||||
|
||||
__all__ = ['SearchAgent']
|
||||
|
||||
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
|
||||
_logger = logging.getLogger(__name__)
|
||||
_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class SearchAgent(object):
    """
    Search agent.

    The client of the controller server. Every request opens a new
    connection, sends one message and reads one reply.
    """

    def __init__(self, server_ip=None, server_port=None, key=None):
        """
        Args:
            server_ip(str): The ip that controller server listens on. Default: None.
            server_port(int): The port that controller server listens on. Default: None.
            key(str): The key sent as the first field of update messages. Default: None.
        """
        self.server_ip = server_ip
        self.server_port = server_port
        # Not used by the methods of this class, which open one socket per
        # request. Kept because it is a public attribute.
        self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._key = key

    def _request(self, message):
        """Send one message to the server and parse the reply into tokens.

        Args:
            message(str): The text to send.
        Returns:
            list<int>: The tokens in the reply.
        """
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client.connect((self.server_ip, self.server_port))
            # sendall() keeps writing until the whole message has been sent.
            client.sendall(message.encode())
            reply = client.recv(1024).decode()
        finally:
            # Close the per-request socket even if the exchange fails.
            client.close()
        return [int(token) for token in reply.strip("\n").split(",")]

    def update(self, tokens, reward):
        """
        Update the controller according to latest tokens and reward.
        Args:
            tokens(list<int>): The tokens generated in last step.
            reward(float): The reward of tokens.
        Returns:
            list<int>: The next tokens replied by the server.
        """
        tokens = ",".join([str(token) for token in tokens])
        return self._request("{}\t{}\t{}".format(self._key, tokens, reward))

    def next_tokens(self):
        """
        Get next tokens.
        Returns:
            list<int>: The next tokens replied by the server.
        """
        return self._request("next_tokens")
|
@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The search space used to search neural architecture"""
|
||||
|
||||
__all__ = ['SearchSpace']
|
||||
|
||||
|
||||
class SearchSpace(object):
    """Abstract description of a search space for Neural Architecture Search.

    Subclasses provide the initial tokens, the range of every token and the
    way to build networks from tokens.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments."""

    def init_tokens(self):
        """Return the tokens the search starts from.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError('Abstract method.')

    def range_table(self):
        """Return the range table of this search space.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError('Abstract method.')

    def create_net(self, tokens):
        """Build the networks for training and evaluation described by tokens.

        Args:
            tokens(list<int>): The tokens which represent a network.
        Return:
            (tuple): startup_program, train_program, evaluation_program, train_metrics, test_metrics.
                     NOTE(review): LightNASStrategy unpacks seven values, the
                     last two being train_reader and test_reader; confirm
                     the expected tuple with implementers.
        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError('Abstract method.')
|
@ -0,0 +1,249 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .prune_strategy import PruneStrategy
|
||||
import re
|
||||
import logging
|
||||
import functools
|
||||
import copy
|
||||
|
||||
__all__ = ['AutoPruneStrategy']
|
||||
|
||||
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
|
||||
_logger = logging.getLogger(__name__)
|
||||
_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class AutoPruneStrategy(PruneStrategy):
    """
    Automatic pruning strategy.

    A controller proposes tokens; each token is a pruned ratio in percent
    for one matched parameter. In every search step the model is pruned by
    the proposed ratios, evaluated, and restored. At end_epoch the model is
    pruned by the best tokens found by the controller.
    """

    def __init__(self,
                 pruner=None,
                 controller=None,
                 start_epoch=0,
                 end_epoch=10,
                 min_ratio=0.5,
                 max_ratio=0.7,
                 metric_name='top1_acc',
                 pruned_params='conv.*_weights',
                 retrain_epoch=0):
        """
        Args:
            pruner(slim.Pruner): The pruner used to prune the parameters. Default: None.
            controller(searcher.Controller): The searching controller. Default: None.
            start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. Default: 0
            end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 10
            min_ratio(float): The minimum pruned ratio. Default: 0.5
            max_ratio(float): The maximum pruned ratio. Default: 0.7
            metric_name(str): The metric used to evaluate the model.
                         It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
            pruned_params(str): The pattern str to match the parameter names to be pruned. Default: 'conv.*_weights'
            retrain_epoch(int): The training epochs in each searching step. Default: 0
        """
        super(AutoPruneStrategy, self).__init__(pruner, start_epoch, end_epoch,
                                                0.0, metric_name, pruned_params)
        self._max_ratio = max_ratio
        self._min_ratio = min_ratio
        self._controller = controller
        self._metric_name = metric_name
        self._pruned_param_names = []
        # Honor the argument; it used to be overwritten with 0.
        self._retrain_epoch = retrain_epoch

        self._current_tokens = None
        # Backups filled in on_epoch_begin and consumed in on_epoch_end.
        self._param_backup = {}
        self._param_shape_backup = {}

    def on_compression_begin(self, context):
        """
        Prepare some information for searching strategy.
        step 1: Find all the parameters to be pruned.
        step 2: Get initial tokens and setup controller.
        """
        for param in context.eval_graph.all_parameters():
            if re.match(self.pruned_params, param.name()):
                self._pruned_param_names.append(param.name())

        self._current_tokens = self._get_init_tokens(context)
        self._range_table = copy.deepcopy(self._current_tokens)

        constrain_func = functools.partial(
            self._constrain_func, context=context)

        self._controller.reset(self._range_table, self._current_tokens,
                               constrain_func)

    def _constrain_func(self, tokens, context=None):
        """Check whether the tokens meet constraint.

        The eval graph is pruned on the graph only, the ratio of removed
        FLOPS is measured, and the parameter shapes are restored.

        Returns:
            bool: True if the ratio of removed FLOPS is within
                  [min_ratio, max_ratio].
        """
        ori_flops = context.eval_graph.flops()
        ratios = self._tokens_to_ratios(tokens)
        param_shape_backup = {}
        self._prune_parameters(
            context.eval_graph,
            context.scope,
            self._pruned_param_names,
            ratios,
            context.place,
            only_graph=True,
            param_shape_backup=param_shape_backup)
        context.eval_graph.update_groups_of_conv()
        flops = context.eval_graph.flops()
        for param in param_shape_backup.keys():
            context.eval_graph.var(param).set_shape(param_shape_backup[param])
        flops_ratio = (1 - float(flops) / ori_flops)
        return self._min_ratio <= flops_ratio <= self._max_ratio

    def _get_init_tokens(self, context):
        """Get initial tokens.
        """
        ratios = self._get_uniform_ratios(context)
        return self._ratios_to_tokens(ratios)

    def _ratios_to_tokens(self, ratios):
        """Convert pruned ratios to tokens.
        """
        return [int(ratio / 0.01) for ratio in ratios]

    def _tokens_to_ratios(self, tokens):
        """Convert tokens to pruned ratios.
        """
        return [token * 0.01 for token in tokens]

    def _get_uniform_ratios(self, context):
        """
        Search a group of uniform ratios.

        Bisects a single ratio shared by all pruned parameters until the
        ratio of removed FLOPS is within 1e-2 of the middle of
        [min_ratio, max_ratio].
        """
        min_ratio = 0.
        max_ratio = 1.
        target = (self._min_ratio + self._max_ratio) / 2
        flops = context.eval_graph.flops()
        ratios = None
        while min_ratio < max_ratio:
            ratio = (max_ratio + min_ratio) / 2
            ratios = [ratio] * len(self._pruned_param_names)
            param_shape_backup = {}
            self._prune_parameters(
                context.eval_graph,
                context.scope,
                self._pruned_param_names,
                ratios,
                context.place,
                only_graph=True,
                param_shape_backup=param_shape_backup)

            pruned_flops = 1 - (float(context.eval_graph.flops()) / flops)
            for param in param_shape_backup.keys():
                context.eval_graph.var(param).set_shape(param_shape_backup[
                    param])

            if abs(pruned_flops - target) < 1e-2:
                break
            if ratio == min_ratio or ratio == max_ratio:
                # The interval cannot shrink any further; stop instead of
                # looping forever.
                break
            if pruned_flops > target:
                max_ratio = ratio
            else:
                min_ratio = ratio
        _logger.info('Get ratios: {}'.format([round(r, 2) for r in ratios]))
        return ratios

    def on_epoch_begin(self, context):
        """
        step 1: Get a new tokens from controller.
        step 2: Pruning eval_graph and optimize_program by tokens
        """
        if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
            self._current_tokens = self._controller.next_tokens()
            params = self._pruned_param_names
            ratios = self._tokens_to_ratios(self._current_tokens)

            self._param_shape_backup = {}
            self._param_backup = {}
            self._prune_parameters(
                context.optimize_graph,
                context.scope,
                params,
                ratios,
                context.place,
                param_backup=self._param_backup,
                param_shape_backup=self._param_shape_backup)
            self._prune_graph(context.eval_graph, context.optimize_graph)
            context.optimize_graph.update_groups_of_conv()
            context.eval_graph.update_groups_of_conv()
            context.optimize_graph.compile(
                mem_opt=True)  # to update the compiled program
            context.skip_training = (self._retrain_epoch == 0)

    def on_epoch_end(self, context):
        """
        step 1: Get reward of current tokens and update controller.
        step 2: Restore eval_graph and optimize_graph
        """
        # "+ 1" makes the step end after retrain_epoch epochs of training,
        # the same condition LightNASStrategy.on_epoch_end uses.
        if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch + 1
             ) % self._retrain_epoch == 0):
            reward = context.eval_results[self._metric_name][-1]
            self._controller.update(self._current_tokens, reward)

            # restore pruned parameters
            for param_name in self._param_backup.keys():
                param_t = context.scope.find_var(param_name).get_tensor()
                param_t.set(self._param_backup[param_name], context.place)
            self._param_backup = {}
            # restore shape of parameters
            for param in self._param_shape_backup.keys():
                context.optimize_graph.var(param).set_shape(
                    self._param_shape_backup[param])
            self._param_shape_backup = {}
            self._prune_graph(context.eval_graph, context.optimize_graph)

            context.optimize_graph.update_groups_of_conv()
            context.eval_graph.update_groups_of_conv()
            context.optimize_graph.compile(
                mem_opt=True)  # to update the compiled program

        elif context.epoch_id == self.end_epoch:  # restore graph for final training
            # restore pruned parameters
            for param_name in self._param_backup.keys():
                param_t = context.scope.find_var(param_name).get_tensor()
                param_t.set(self._param_backup[param_name], context.place)
            self._param_backup = {}
            # restore shape of parameters
            for param in self._param_shape_backup.keys():
                context.eval_graph.var(param).set_shape(
                    self._param_shape_backup[param])
                context.optimize_graph.var(param).set_shape(
                    self._param_shape_backup[param])
            self._param_shape_backup = {}

            context.optimize_graph.update_groups_of_conv()
            context.eval_graph.update_groups_of_conv()

            # Prune by the best tokens, converting them the same way
            # on_epoch_begin converts the current tokens.
            # NOTE(review): this replaces a call to self._get_prune_ratios,
            # which this class does not define; confirm against the base class.
            params = self._pruned_param_names
            ratios = self._tokens_to_ratios(self._controller._best_tokens)
            self._prune_parameters(context.optimize_graph, context.scope,
                                   params, ratios, context.place)

            self._prune_graph(context.eval_graph, context.optimize_graph)
            context.optimize_graph.update_groups_of_conv()
            context.eval_graph.update_groups_of_conv()
            context.optimize_graph.compile(
                mem_opt=True)  # to update the compiled program

            context.skip_training = False
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import controller
|
||||
from .controller import *
|
||||
|
||||
__all__ = controller.__all__
|
@ -0,0 +1,147 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The controller used to search hyperparameters or neural architecture"""
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
import math
|
||||
import logging
|
||||
|
||||
__all__ = ['EvolutionaryController', 'SAController']
|
||||
|
||||
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
|
||||
_logger = logging.getLogger(__name__)
|
||||
_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class EvolutionaryController(object):
    """Base class of the controllers for evolutionary searching methods.

    All methods except the constructor are abstract.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments."""

    def update(self, tokens, reward):
        """Feed the reward of a solution back to the controller.

        Args:
            tokens(list<int>): A solution of searching task.
            reward(float): The reward of tokens.
        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError('Abstract method.')

    def reset(self, range_table, constrain_func=None):
        """Reset the controller.

        Args:
            range_table(list<int>): It is used to define the searching space of controller.
                                    The tokens[i] generated by controller should be in [0, range_table[i]).
            constrain_func(function): It is used to check whether tokens meet the constraint.
                                      None means there is no constraint. Default: None.
        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError('Abstract method.')

    def next_tokens(self):
        """Generate new tokens.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError('Abstract method.')
|
||||
|
||||
|
||||
class SAController(EvolutionaryController):
    """Simulated annealing controller.

    The controller keeps a current solution (``tokens``) together with its
    reward, and separately tracks the best solution seen so far. New
    candidates are produced by changing a single position of the current
    solution; a candidate replaces the current solution when its reward is
    better, or otherwise with a probability that shrinks as the temperature
    decays.
    """

    def __init__(self,
                 range_table=None,
                 reduce_rate=0.85,
                 init_temperature=1024,
                 max_iter_number=300):
        """Initialize.

        Args:
            range_table(list<int>): Range table. tokens[i] is expected to be
                in [0, range_table[i]).
            reduce_rate(float): The decay rate of temperature.
            init_temperature(float): Init temperature.
            max_iter_number(int): Maximum number of re-sampling attempts made
                by next_tokens() when a constraint function is set.
        """
        super(SAController, self).__init__()
        self._range_table = range_table
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_iter_number = max_iter_number
        self._reward = -1
        self._tokens = None
        self._max_reward = -1
        self._best_tokens = None
        self._iter = 0
        # Defined up front so next_tokens() does not fail with AttributeError
        # when it is reached before reset() has stored a callback.
        self._constrain_func = None

    def __getstate__(self):
        """Return the instance state without ``_constrain_func``.

        The callback is left out of the pickled state (presumably because it
        may not be picklable -- confirm with the callers that pickle the
        controller).
        """
        return {
            key: value
            for key, value in self.__dict__.items()
            if key != "_constrain_func"
        }

    def __setstate__(self, state):
        """Restore the pickled state.

        ``__getstate__`` drops ``_constrain_func``, so it is re-created as
        None (no constraint) here; reset() can install a new callback.
        """
        self.__dict__.update(state)
        self._constrain_func = None

    def reset(self, range_table, init_tokens, constrain_func=None):
        """
        Reset the status of current controller.

        Only the range table, the constraint callback, the current tokens and
        the iteration counter are replaced; the recorded rewards and best
        tokens are left untouched.

        Args:
            range_table(list<int>): The range of value in each position of
                tokens generated by current controller. The range of
                tokens[i] is [0, range_table[i]).
            init_tokens(list<int>): The initial tokens.
            constrain_func(function): The callback function used to check
                whether the tokens meet constraint. None means there is no
                constraint. Default: None.
        """
        self._range_table = range_table
        self._constrain_func = constrain_func
        self._tokens = init_tokens
        self._iter = 0

    def update(self, tokens, reward):
        """
        Update the controller according to latest tokens and reward.

        The temperature of the current iteration is
        ``init_temperature * reduce_rate ** iter``. ``tokens`` become the
        current solution when ``reward`` beats the current reward, or
        otherwise with probability
        ``exp((reward - current_reward) / temperature)``.

        Args:
            tokens(list<int>): The tokens generated in last step.
            reward(float): The reward of tokens.
        """
        self._iter += 1
        temperature = self._init_temperature * self._reduce_rate**self._iter
        delta = reward - self._reward
        if delta > 0:
            accept = True
        elif temperature > 0:
            accept = np.random.random() <= math.exp(delta / temperature)
        else:
            # The temperature has decayed (or underflowed) to zero. Dividing
            # by it would raise ZeroDivisionError, so use the limit of the
            # acceptance rule: only an equal reward is still accepted.
            accept = delta == 0
        if accept:
            self._reward = reward
            self._tokens = tokens
        if reward > self._max_reward:
            self._max_reward = reward
            self._best_tokens = tokens
        _logger.info("iter: {}; max_reward: {}; best_tokens: {}".format(
            self._iter, self._max_reward, self._best_tokens))
        _logger.info("current_reward: {}; current tokens: {}".format(
            self._reward, self._tokens))

    def next_tokens(self):
        """
        Get next tokens.

        One position of the current tokens is moved to a different value
        inside its range. When a constraint function is set and rejects the
        candidate, up to ``max_iter_number`` new candidates are drawn by
        re-sampling one random position of the current tokens. If none of
        the checked candidates passes, the last candidate is returned
        without having been checked.

        reset() must have been called with initial tokens before this method
        is used.

        Returns:
            list<int>: A new list; the current tokens are not modified.
        """
        tokens = self._tokens
        new_tokens = tokens[:]
        # Positions whose range is 1 only admit the value 0 and cannot be
        # changed; np.random.randint(0) would raise ValueError for them.
        mutable = [
            pos for pos, size in enumerate(self._range_table) if size > 1
        ]
        if not mutable:
            # The search space holds a single point; nothing can be mutated.
            return new_tokens
        # When every position is mutable this picks the same index as
        # sampling directly over the whole range table.
        index = mutable[int(len(mutable) * np.random.random())]
        size = self._range_table[index]
        # An offset in [1, size - 1] modulo size always yields a new value.
        new_tokens[index] = (
            new_tokens[index] + np.random.randint(size - 1) + 1) % size
        _logger.info("change index[{}] from {} to {}".format(index, tokens[
            index], new_tokens[index]))
        if self._constrain_func is None:
            return new_tokens
        for _ in range(self._max_iter_number):
            if self._constrain_func(new_tokens):
                break
            index = int(len(self._range_table) * np.random.random())
            new_tokens = tokens[:]
            new_tokens[index] = np.random.randint(self._range_table[index])
        else:
            _logger.warning(
                "no tokens were confirmed to meet the constraint in {} "
                "attempts; returning an unchecked candidate".format(
                    self._max_iter_number))
        return new_tokens
|
@ -0,0 +1,30 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'StructurePruner'
|
||||
pruning_axis:
|
||||
'*': 0
|
||||
criterions:
|
||||
'*': 'l1_norm'
|
||||
controllers:
|
||||
sa_controller:
|
||||
class: 'SAController'
|
||||
reduce_rate: 0.9
|
||||
init_temperature: 1024
|
||||
max_iter_number: 300
|
||||
strategies:
|
||||
auto_pruning_strategy:
|
||||
class: 'AutoPruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
controller: 'sa_controller'
|
||||
start_epoch: 0
|
||||
end_epoch: 2
|
||||
max_ratio: 0.7
|
||||
min_ratio: 0.5
|
||||
pruned_params: '.*_sep_weights'
|
||||
metric_name: 'acc_top5'
|
||||
compressor:
|
||||
epoch: 2
|
||||
checkpoint_path: './checkpoints_auto_pruning/'
|
||||
strategies:
|
||||
- auto_pruning_strategy
|
@ -0,0 +1,22 @@
|
||||
version: 1.0
|
||||
controllers:
|
||||
sa_controller:
|
||||
class: 'SAController'
|
||||
reduce_rate: 0.9
|
||||
init_temperature: 1024
|
||||
max_iter_number: 300
|
||||
strategies:
|
||||
light_nas_strategy:
|
||||
class: 'LightNASStrategy'
|
||||
controller: 'sa_controller'
|
||||
target_flops: 629145600
|
||||
end_epoch: 2
|
||||
retrain_epoch: 1
|
||||
metric_name: 'acc_top1'
|
||||
is_server: 1
|
||||
max_client_num: 100
|
||||
search_steps: 2
|
||||
compressor:
|
||||
epoch: 2
|
||||
strategies:
|
||||
- light_nas_strategy
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,86 @@
|
||||
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
|
||||
#
|
||||
# licensed under the apache license, version 2.0 (the "license");
|
||||
# you may not use this file except in compliance with the license.
|
||||
# you may obtain a copy of the license at
|
||||
#
|
||||
# http://www.apache.org/licenses/license-2.0
|
||||
#
|
||||
# unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the license is distributed on an "as is" basis,
|
||||
# without warranties or conditions of any kind, either express or implied.
|
||||
# see the license for the specific language governing permissions and
|
||||
# limitations under the license.
|
||||
|
||||
import paddle
|
||||
import unittest
|
||||
import paddle.fluid as fluid
|
||||
from mobilenet import MobileNet
|
||||
from paddle.fluid.contrib.slim.core import Compressor
|
||||
from paddle.fluid.contrib.slim.graph import GraphWrapper
|
||||
|
||||
|
||||
class TestFilterPruning(unittest.TestCase):
    def test_compression(self):
        """Smoke-test the auto pruning strategy.

        Model: mobilenet_v1
        data: mnist
        The compressor is configured from ./auto_pruning/compress.yaml and
        run end to end. The test passes when the run finishes without
        raising; no accuracy threshold is asserted.
        """
        if not fluid.core.is_compiled_with_cuda():
            # Report the test as skipped instead of silently passing.
            self.skipTest("auto pruning test requires a CUDA build of paddle")
        class_dim = 10
        image_shape = [1, 28, 28]
        image = fluid.layers.data(
            name='image', shape=image_shape, dtype='float32')
        image.stop_gradient = False
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        out = MobileNet("auto_pruning").net(input=image, class_dim=class_dim)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
        # The evaluation program is cloned before the loss ops are appended,
        # so it only holds the forward pass and the accuracy ops.
        # NOTE(review): cloned with for_test=False -- confirm this is intended
        # for an evaluation program.
        val_program = fluid.default_main_program().clone(for_test=False)

        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

        val_feed_list = [('img', image.name), ('label', label.name)]
        val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5',
                                                        acc_top5.name)]

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=128)
        train_feed_list = [('img', image.name), ('label', label.name)]
        train_fetch_list = [('loss', avg_cost.name)]

        com_pass = Compressor(
            place,
            fluid.global_scope(),
            fluid.default_main_program(),
            train_reader=train_reader,
            train_feed_list=train_feed_list,
            train_fetch_list=train_fetch_list,
            eval_program=val_program,
            eval_reader=val_reader,
            eval_feed_list=val_feed_list,
            eval_fetch_list=val_fetch_list,
            train_optimizer=optimizer)
        com_pass.config('./auto_pruning/compress.yaml')
        com_pass.run()
|
||||
|
||||
|
||||
# Run the test case above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,66 @@
|
||||
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
|
||||
#
|
||||
# licensed under the apache license, version 2.0 (the "license");
|
||||
# you may not use this file except in compliance with the license.
|
||||
# you may obtain a copy of the license at
|
||||
#
|
||||
# http://www.apache.org/licenses/license-2.0
|
||||
#
|
||||
# unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the license is distributed on an "as is" basis,
|
||||
# without warranties or conditions of any kind, either express or implied.
|
||||
# see the license for the specific language governing permissions and
|
||||
# limitations under the license.
|
||||
|
||||
import paddle
|
||||
import unittest
|
||||
import paddle.fluid as fluid
|
||||
from mobilenet import MobileNet
|
||||
from paddle.fluid.contrib.slim.core import Compressor
|
||||
from paddle.fluid.contrib.slim.graph import GraphWrapper
|
||||
import sys
|
||||
sys.path.append("./light_nas")
|
||||
from light_nas_space import LightNASSpace
|
||||
|
||||
|
||||
class TestLightNAS(unittest.TestCase):
    def test_compression(self):
        """Smoke-test the Light-NAS strategy.

        The programs, metrics and readers come from LightNASSpace.create_net();
        the compressor is configured from ./light_nas/compress.yaml and run
        end to end. The test passes when the run finishes without raising;
        no accuracy threshold is asserted.
        """
        if not fluid.core.is_compiled_with_cuda():
            # Report the test as skipped instead of silently passing.
            self.skipTest("light nas test requires a CUDA build of paddle")

        space = LightNASSpace()

        (startup_prog, train_prog, test_prog, train_metrics, test_metrics,
         train_reader, test_reader) = space.create_net()
        # Unpacking also checks that the metrics tuples have the expected
        # number of entries; only some of the entries are used below.
        train_cost, train_acc1, train_acc5, global_lr = train_metrics
        test_cost, test_acc1, test_acc5 = test_metrics

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(startup_prog)

        val_fetch_list = [('acc_top1', test_acc1.name), ('acc_top5',
                                                         test_acc5.name)]
        train_fetch_list = [('loss', train_cost.name)]

        # Feed lists and the optimizer are passed as None here; the search
        # space object is handed to the compressor instead.
        com_pass = Compressor(
            place,
            fluid.global_scope(),
            train_prog,
            train_reader=train_reader,
            train_feed_list=None,
            train_fetch_list=train_fetch_list,
            eval_program=test_prog,
            eval_reader=test_reader,
            eval_feed_list=None,
            eval_fetch_list=val_fetch_list,
            train_optimizer=None,
            search_space=space)
        com_pass.config('./light_nas/compress.yaml')
        com_pass.run()
|
||||
|
||||
|
||||
# Run the test case above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue