Remove slim from paddle framework (#25666)
* Remove slim from paddle framework. test=develop. Co-authored-by: wanghaoshuang <wanghaoshuang@baidu.com>. (Branch: fix_copy_if_different)
parent
bca303165a
commit
2131559d08
@ -1,22 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Package initializer: import the submodules and re-export their public
# names so users can access everything from this package directly.
from . import config
from .config import *
from . import compressor
from .compressor import *
from . import strategy
from .strategy import *

# Aggregate the export lists of all submodules into this package's API.
__all__ = config.__all__ + compressor.__all__ + strategy.__all__
|
File diff suppressed because it is too large
Load Diff
@ -1,130 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import funcsigs
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
from ..prune import *
|
||||
from ..quantization import *
|
||||
from .strategy import *
|
||||
from ..distillation import *
|
||||
from ..searcher import *
|
||||
from ..nas import *
|
||||
|
||||
# Public API of this module.
__all__ = ['ConfigFactory']
"""This factory is used to create instances by loading and parsing configure file with yaml format.
"""

# Top-level YAML section names whose entries are instantiated as plugin
# objects by ConfigFactory._new_instance.
PLUGINS = ['pruners', 'quantizers', 'distillers', 'strategies', 'controllers']
|
||||
|
||||
|
||||
class ConfigFactory(object):
    """Build and cache plugin instances described by a YAML configure file.

    The file may contain plugin sections (see ``PLUGINS``), a ``compressor``
    section, a ``version`` field and an ``include`` list of further
    configure files that are parsed recursively.
    """

    def __init__(self, config):
        """Init a factory from configure file.

        Args:
            config(str): Path of a YAML configure file.
        """
        self.instances = {}
        self.compressor = {}
        self.version = None
        self._parse_config(config)

    def instance(self, name):
        """
        Get instance from factory.

        Returns the cached instance registered under ``name``, or None if
        no instance with that name was created.
        """
        if name in self.instances:
            return self.instances[name]
        else:
            return None

    def _new_instance(self, name, attrs):
        """Create (or return the cached) instance registered as ``name``.

        Args:
            name(str): Cache key for the new instance.
            attrs(dict): Must contain 'class' naming a class visible in this
                module's globals; the remaining keys become constructor
                keyword arguments.
        """
        if name not in self.instances:
            class_ = globals()[attrs['class']]
            # Stdlib ``inspect.signature`` replaces the third-party
            # ``funcsigs`` backport, which is only needed on Python 2.
            sig = inspect.signature(class_.__init__)
            # Keyword names accepted by the constructor, excluding `self`.
            keys = [
                param.name for param in sig.parameters.values()
                if (param.kind == param.POSITIONAL_OR_KEYWORD)
            ][1:]
            keys = set(attrs.keys()).intersection(set(keys))
            args = {}
            for key in keys:
                value = attrs[key]
                # The literal string 'none' in YAML means "pass None".
                if isinstance(value, str) and value.lower() == 'none':
                    value = None
                # String values naming an already-built instance are resolved
                # to that instance; lists are resolved element-wise.
                if isinstance(value, str) and value in self.instances:
                    value = self.instances[value]
                if isinstance(value, list):
                    for i in range(len(value)):
                        if isinstance(value[i],
                                      str) and value[i] in self.instances:
                            value[i] = self.instances[value[i]]

                args[key] = value
            self.instances[name] = class_(**args)
        return self.instances.get(name)

    def _parse_config(self, config):
        """Parse one YAML file, recursing into any ``include`` entries."""
        assert config
        with open(config, 'r') as config_file:
            key_values = self._ordered_load(config_file)
            for key in key_values:
                # parse version
                if key == 'version' and self.version is None:
                    self.version = int(key_values['version'])
                # Every parsed file (including includes) must declare the
                # same version as the first one.
                assert self.version == int(key_values['version'])

                # parse pruners
                if key in PLUGINS:
                    instances = key_values[key]
                    for name in instances:
                        self._new_instance(name, instances[name])

                if key == 'compressor':
                    self.compressor['strategies'] = []
                    self.compressor['epoch'] = key_values[key]['epoch']
                    if 'init_model' in key_values[key]:
                        self.compressor['init_model'] = key_values[key][
                            'init_model']
                    if 'checkpoint_path' in key_values[key]:
                        self.compressor['checkpoint_path'] = key_values[key][
                            'checkpoint_path']
                    if 'eval_epoch' in key_values[key]:
                        self.compressor['eval_epoch'] = key_values[key][
                            'eval_epoch']
                    if 'strategies' in key_values[key]:
                        for name in key_values[key]['strategies']:
                            strategy = self.instance(name)
                            self.compressor['strategies'].append(strategy)

                if key == 'include':
                    for config_file in key_values[key]:
                        self._parse_config(config_file.strip())

    def _ordered_load(self,
                      stream,
                      Loader=yaml.Loader,
                      object_pairs_hook=OrderedDict):
        """
        Load YAML preserving mapping order.

        See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts
        """

        class OrderedLoader(Loader):
            pass

        def construct_mapping(loader, node):
            loader.flatten_mapping(node)
            return object_pairs_hook(loader.construct_pairs(node))

        OrderedLoader.add_constructor(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
        return yaml.load(stream, OrderedLoader)
|
@ -1,58 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['Strategy']
|
||||
|
||||
|
||||
class Strategy(object):
    """Common interface for all compression strategies.

    A strategy is a bundle of callbacks that the compressor invokes at
    well-defined points of the training loop; subclasses override only the
    hooks they need.
    """

    def __init__(self, start_epoch=0, end_epoch=0):
        """Record the epoch window in which this strategy is active.

        Args:
            start_epoch: The first epoch to apply the strategy.
            end_epoch: The last epoch to apply the strategy.
        """
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch

    def __getstate__(self):
        # The epoch window is configuration, not runtime state, so it is
        # excluded from the pickled dict.
        excluded = ("start_epoch", "end_epoch")
        return {
            name: value
            for name, value in self.__dict__.items() if name not in excluded
        }

    def on_compression_begin(self, context):
        """Hook invoked once before compression starts."""
        pass

    def on_epoch_begin(self, context):
        """Hook invoked at the start of every epoch."""
        pass

    def on_epoch_end(self, context):
        """Hook invoked at the end of every epoch."""
        pass

    def on_batch_begin(self, context):
        """Hook invoked before every training batch."""
        pass

    def on_batch_end(self, context):
        """Hook invoked after every training batch."""
        pass

    def on_compression_end(self, context):
        """Hook invoked once after compression finishes."""
        pass

    def restore_from_checkpoint(self, context):
        """Hook invoked when resuming from a saved checkpoint."""
        pass
|
@ -1,21 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Package initializer: re-export the distiller implementations and the
# distillation training strategy.
from . import distiller
from .distiller import *
from . import distillation_strategy
from .distillation_strategy import *

# Aggregate the submodules' export lists into this package's API.
__all__ = distiller.__all__
__all__ += distillation_strategy.__all__
@ -1,104 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..core.strategy import Strategy
|
||||
from ....framework import Program, Variable, program_guard
|
||||
from ....log_helper import get_logger
|
||||
from .... import Executor
|
||||
import logging
|
||||
|
||||
__all__ = ['DistillationStrategy']
|
||||
|
||||
_logger = get_logger(
|
||||
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
|
||||
|
||||
|
||||
class DistillationStrategy(Strategy):
    """Strategy that trains the student graph with distillation losses merged
    in from a teacher graph during the [start_epoch, end_epoch) window."""

    def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
        """
        Args:
            distillers(list): A list of distiller used to combine student graph and teacher graph
                              by adding some loss.
            start_epoch(int): The epoch when to merge student graph and teacher graph for
                              distillation training. default: 0
            end_epoch(int): The epoch when to finish distillation training. default: 0

        """
        super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
        self.distillers = distillers

    def restore_from_checkpoint(self, context):
        # load from checkpoint
        # Rebuild the merged distillation graph only if the checkpointed
        # epoch falls strictly inside the distillation window.
        if context.epoch_id > 0:
            if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch:
                _logger.info('Restore DistillationStrategy')
                self._create_distillation_graph(context)
                _logger.info('Restore DistillationStrategy finish.')

    def on_epoch_begin(self, context):
        # Swap in the distillation graph exactly once, at the first epoch
        # of the window.
        if self.start_epoch == context.epoch_id:
            _logger.info('DistillationStrategy::on_epoch_begin.')
            self._create_distillation_graph(context)
            _logger.info('DistillationStrategy set optimize_graph.')

    def _create_distillation_graph(self, context):
        """
        step 1: Merge student graph and teacher graph into distillation graph.
        step 2: Add loss into distillation graph by distillers.
        step 3: Append backward ops and optimize ops into distillation graph for training.
        """
        # step 1
        # NOTE(review): only the first teacher graph is used — confirm
        # multiple teachers are intentionally unsupported here.
        teacher = context.teacher_graphs[0]
        # Freeze all teacher variables so no gradients flow into the teacher.
        for var in teacher.program.list_vars():
            var.stop_gradient = True
        graph = context.train_graph.clone()
        graph.merge(teacher)
        if 'loss' in graph.out_nodes:
            graph.out_nodes['student_loss'] = graph.out_nodes['loss']

        # step 2
        # Each distiller rewrites the graph, adding its own loss terms and
        # updating graph.out_nodes['loss'].
        for distiller in self.distillers:
            graph = distiller.distiller_loss(graph)

        # step 3
        startup_program = Program()
        with program_guard(graph.program, startup_program):
            context.distiller_optimizer._name = 'distillation_optimizer'

            # The learning rate variable may be created in other program.
            # Update information in optimizer to make
            # learning rate variable being accessible in current program.
            optimizer = context.distiller_optimizer
            if isinstance(optimizer._learning_rate, Variable):
                optimizer._learning_rate_map[
                    graph.program] = optimizer._learning_rate

            optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)

        # Run the startup program so the newly created optimizer variables
        # are initialized in the current scope.
        exe = Executor(context.place)
        exe.run(startup_program, scope=context.scope)

        # backup graph for fine-tune after distillation
        context.put('distillation_backup_optimize_graph',
                    context.optimize_graph)
        context.optimize_graph = graph

    def on_epoch_end(self, context):
        # At the last epoch of the window, restore the original graph that
        # was backed up in _create_distillation_graph.
        if context.epoch_id == (self.end_epoch - 1):
            _logger.info('DistillationStrategy::on_epoch_end.')
            # restore optimize_graph for fine-tune or other strategy in next stage.
            context.optimize_graph = context.get(
                'distillation_backup_optimize_graph')
            _logger.info(
                'DistillationStrategy set context.optimize_graph to None.')
|
File diff suppressed because it is too large
Load Diff
@ -1,20 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Package initializer: re-export the slim executor and graph wrapper.
from . import executor
from .executor import *
from . import graph_wrapper
from .graph_wrapper import *
# Aggregate the submodules' export lists into this package's API.
__all__ = executor.__all__
__all__ += graph_wrapper.__all__
|
@ -1,61 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ....compiler import CompiledProgram
|
||||
from ....data_feeder import DataFeeder
|
||||
from .... import executor
|
||||
from .graph_wrapper import GraphWrapper
|
||||
|
||||
__all__ = ['SlimGraphExecutor']
|
||||
|
||||
|
||||
class SlimGraphExecutor(object):
    """
    Thin wrapper around ``executor.Executor`` that knows how to feed and
    fetch a ``GraphWrapper``.
    """

    def __init__(self, place):
        self.exe = executor.Executor(place)
        self.place = place

    def run(self, graph, scope, data=None):
        """
        Execute the graph on one batch of data.

        Args:
            graph(GraphWrapper): The graph to be executed.
            scope(fluid.core.Scope): The scope to be used.
            data(list<tuple>): A batch of data. Each tuple in this list is a sample.
                               It will feed the items of tuple to the in_nodes of graph.
        Returns:
            results(list): A list of result with the same order indicated by graph.out_nodes.
        """
        assert isinstance(graph, GraphWrapper)
        # Build the feed: dict-shaped batches pass through unchanged, raw
        # sample tuples go through a DataFeeder first.
        feed = None
        if data is not None:
            if isinstance(data[0], dict):
                # return list = False
                feed = data
            else:
                data_feeder = DataFeeder(
                    feed_list=list(graph.in_nodes.values()),
                    place=self.place,
                    program=graph.program)
                feed = data_feeder.feed(data)

        targets = list(graph.out_nodes.values())
        # Prefer the compiled graph when one has been built.
        runnable = graph.compiled_graph if graph.compiled_graph else graph.program
        return self.exe.run(runnable,
                            scope=scope,
                            fetch_list=targets,
                            feed=feed)
|
File diff suppressed because it is too large
Load Diff
@ -1,29 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Package initializer: re-export the Light-NAS strategy and its
# supporting client/server and search-space modules.
from . import light_nas_strategy
from .light_nas_strategy import *
from . import controller_server
from .controller_server import *
from . import search_agent
from .search_agent import *
from . import search_space
from .search_space import *
from . import lock
from .lock import *

# Aggregate the submodules' export lists into this package's API.
# NOTE(review): `lock`'s names are not aggregated here — confirm that
# omitting lock/unlock from this package's __all__ is intentional.
__all__ = light_nas_strategy.__all__
__all__ += controller_server.__all__
__all__ += search_agent.__all__
__all__ += search_space.__all__
|
@ -1,107 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import socket
|
||||
from threading import Thread
|
||||
from ....log_helper import get_logger
|
||||
|
||||
__all__ = ['ControllerServer']
|
||||
|
||||
_logger = get_logger(
|
||||
__name__,
|
||||
logging.INFO,
|
||||
fmt='ControllerServer-%(asctime)s-%(levelname)s: %(message)s')
|
||||
|
||||
|
||||
class ControllerServer(object):
    """
    The controller wrapper with a socket server to handle the request of search agent.
    """

    def __init__(self,
                 controller=None,
                 address=('', 0),
                 max_client_num=100,
                 search_steps=None,
                 key=None):
        """
        Args:
            controller(slim.searcher.Controller): The controller used to generate tokens.
            address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
                            which means setting ip automatically
            max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
            search_steps(int): The total steps of searching. None means never stopping. Default: None
            key(str): Shared secret used to reject messages from unknown agents.
        """
        self._controller = controller
        self._address = address
        self._max_client_num = max_client_num
        self._search_steps = search_steps
        # Set to True by close(); checked by the accept loop in run().
        self._closed = False
        # ip/port are refreshed with the actually-bound values in start().
        self._port = address[1]
        self._ip = address[0]
        self._key = key

    def start(self):
        """Bind, listen, and serve requests on a background thread.

        Returns the string form of the serving thread (used as a marker
        that a server is already running).
        """
        self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket_server.bind(self._address)
        self._socket_server.listen(self._max_client_num)
        # Read back the OS-assigned address (the requested port may be 0).
        self._port = self._socket_server.getsockname()[1]
        self._ip = self._socket_server.getsockname()[0]
        _logger.info("listen on: [{}:{}]".format(self._ip, self._port))
        thread = Thread(target=self.run)
        thread.start()
        return str(thread)

    def close(self):
        """Close the server."""
        # Only sets the flag; the accept loop exits after the next request.
        self._closed = True

    def port(self):
        """Get the port."""
        return self._port

    def ip(self):
        """Get the ip."""
        return self._ip

    def run(self):
        """Accept-loop: answer "next_tokens" queries and reward updates.

        Wire protocol (newline-stripped, utf-8):
          request "next_tokens"            -> reply "t0,t1,..."
          request "<key>\t<tokens>\t<reward>" -> update controller, reply next tokens
        """
        _logger.info("Controller Server run...")
        # Serve until close() is called or the controller has completed
        # `search_steps` iterations.
        while ((self._search_steps is None) or
               (self._controller._iter <
                (self._search_steps))) and not self._closed:
            conn, addr = self._socket_server.accept()
            message = conn.recv(1024).decode()
            if message.strip("\n") == "next_tokens":
                tokens = self._controller.next_tokens()
                tokens = ",".join([str(token) for token in tokens])
                conn.send(tokens.encode())
            else:
                _logger.info("recv message from {}: [{}]".format(addr, message))
                messages = message.strip('\n').split("\t")
                # Reject malformed messages or messages with a wrong key.
                if (len(messages) < 3) or (messages[0] != self._key):
                    _logger.info("recv noise from {}: [{}]".format(addr,
                                                                   message))
                    # NOTE(review): this `continue` skips the conn.close()
                    # below, leaking the connection — confirm.
                    continue
                tokens = messages[1]
                reward = messages[2]
                tokens = [int(token) for token in tokens.split(",")]
                self._controller.update(tokens, float(reward))
                tokens = self._controller.next_tokens()
                tokens = ",".join([str(token) for token in tokens])
                conn.send(tokens.encode())
                _logger.info("send message to {}: [{}]".format(addr, tokens))
            conn.close()
        self._socket_server.close()
        _logger.info("server closed!")
|
@ -1,196 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..core.strategy import Strategy
|
||||
from ..graph import GraphWrapper
|
||||
from .controller_server import ControllerServer
|
||||
from .search_agent import SearchAgent
|
||||
from ....executor import Executor
|
||||
from ....log_helper import get_logger
|
||||
import re
|
||||
import logging
|
||||
import functools
|
||||
import socket
|
||||
from .lock import lock, unlock
|
||||
|
||||
__all__ = ['LightNASStrategy']
|
||||
|
||||
_logger = get_logger(
|
||||
__name__,
|
||||
logging.INFO,
|
||||
fmt='LightNASStrategy-%(asctime)s-%(levelname)s: %(message)s')
|
||||
|
||||
|
||||
class LightNASStrategy(Strategy):
    """
    Light-NAS search strategy.
    """

    def __init__(self,
                 controller=None,
                 end_epoch=1000,
                 target_flops=629145600,
                 target_latency=0,
                 retrain_epoch=1,
                 metric_name='top1_acc',
                 server_ip=None,
                 server_port=0,
                 is_server=False,
                 max_client_num=100,
                 search_steps=None,
                 key="light-nas"):
        """
        Args:
            controller(searcher.Controller): The searching controller. Default: None.
            end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 0
            target_flops(int): The constraint of FLOPS.
            target_latency(float): The constraint of latency.
            retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1.
            metric_name(str): The metric used to evaluate the model.
                              It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
            server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
            server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
            is_server(bool): Whether current host is controller server. Default: False.
            max_client_num(int): The maximum number of clients that connect to controller server concurrently. Default: 100.
            search_steps(int): The total steps of searching. Default: None.
            key(str): The key used to identify legal agent for controller server. Default: "light-nas"
        """
        self.start_epoch = 0
        self.end_epoch = end_epoch
        self._max_flops = target_flops
        self._max_latency = target_latency
        self._metric_name = metric_name
        self._controller = controller
        # NOTE(review): this assignment is overwritten by the
        # `retrain_epoch` assignment a few lines below — appears redundant.
        self._retrain_epoch = 0
        self._server_ip = server_ip
        self._server_port = server_port
        self._is_server = is_server
        self._retrain_epoch = retrain_epoch
        self._search_steps = search_steps
        self._max_client_num = max_client_num
        # Maximum number of architectures sampled per epoch while looking
        # for one that satisfies the FLOPS/latency constraints.
        self._max_try_times = 100
        self._key = key

        if self._server_ip is None:
            self._server_ip = self._get_host_ip()

    def _get_host_ip(self):
        # Resolve the local host name to an IP address.
        return socket.gethostbyname(socket.gethostname())

    def on_compression_begin(self, context):
        self._current_tokens = context.search_space.init_tokens()
        self._controller.reset(context.search_space.range_table(),
                               self._current_tokens, None)

        # create controller server
        if self._is_server:
            # A lock file guards server startup so that only one process
            # on this host launches the controller server.
            open("./slim_LightNASStrategy_controller_server.socket",
                 'a').close()
            socket_file = open(
                "./slim_LightNASStrategy_controller_server.socket", 'r+')
            lock(socket_file)
            tid = socket_file.readline()
            # An empty file means no server has been started yet.
            if tid == '':
                _logger.info("start controller server...")
                self._server = ControllerServer(
                    controller=self._controller,
                    address=(self._server_ip, self._server_port),
                    max_client_num=self._max_client_num,
                    search_steps=self._search_steps,
                    key=self._key)
                tid = self._server.start()
                self._server_port = self._server.port()
                socket_file.write(tid)
                _logger.info("started controller server...")
            unlock(socket_file)
            socket_file.close()
        _logger.info("self._server_ip: {}; self._server_port: {}".format(
            self._server_ip, self._server_port))
        # create client
        self._search_agent = SearchAgent(
            self._server_ip, self._server_port, key=self._key)

    def __getstate__(self):
        """Socket can't be pickled."""
        # Drop the live client/server objects from the pickled state.
        d = {}
        for key in self.__dict__:
            if key not in ["_search_agent", "_server"]:
                d[key] = self.__dict__[key]
        return d

    def on_epoch_begin(self, context):
        # Sample a new architecture whenever an evaluation round starts
        # (every `_retrain_epoch` epochs inside the active window).
        if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
            _logger.info("light nas strategy on_epoch_begin")
            min_flops = -1
            for _ in range(self._max_try_times):
                startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
                    self._current_tokens)
                context.eval_graph.program = test_p
                flops = context.eval_graph.flops()
                # Track the lowest-FLOPS tokens seen so far; they seed the
                # controller when a candidate violates the constraints.
                if min_flops == -1:
                    min_flops = flops
                    min_tokens = self._current_tokens[:]
                else:
                    if flops < min_flops:
                        min_tokens = self._current_tokens[:]
                if self._max_latency > 0:
                    latency = context.search_space.get_model_latency(test_p)
                    _logger.info("try [{}] with latency {} flops {}".format(
                        self._current_tokens, latency, flops))
                else:
                    _logger.info("try [{}] with flops {}".format(
                        self._current_tokens, flops))
                # Resample if the candidate violates either constraint.
                if flops > self._max_flops or (self._max_latency > 0 and
                                               latency > self._max_latency):
                    self._current_tokens = self._controller.next_tokens(
                        min_tokens)
                else:
                    break

            context.train_reader = train_reader
            context.eval_reader = test_reader

            # Initialize the freshly created network's parameters.
            exe = Executor(context.place)
            exe.run(startup_p)

            context.optimize_graph.program = train_p
            context.optimize_graph.compile()

            # With retrain_epoch == 0 the architecture is only evaluated,
            # never trained.
            context.skip_training = (self._retrain_epoch == 0)

    def on_epoch_end(self, context):
        # At the end of an evaluation round, compute the reward and ask the
        # controller server for the next tokens to try.
        if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch + 1
             ) % self._retrain_epoch == 0):

            self._current_reward = context.eval_results[self._metric_name][-1]
            flops = context.eval_graph.flops()
            # Constraint violations are punished with zero reward.
            if flops > self._max_flops:
                self._current_reward = 0.0
            if self._max_latency > 0:
                test_p = context.search_space.create_net(self._current_tokens)[
                    2]
                latency = context.search_space.get_model_latency(test_p)
                if latency > self._max_latency:
                    self._current_reward = 0.0
                _logger.info("reward: {}; latency: {}; flops: {}; tokens: {}".
                             format(self._current_reward, latency, flops,
                                    self._current_tokens))
            else:
                _logger.info("reward: {}; flops: {}; tokens: {}".format(
                    self._current_reward, flops, self._current_tokens))
            self._current_tokens = self._search_agent.update(
                self._current_tokens, self._current_reward)
|
@ -1,36 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os

# Advisory file-locking helpers used to serialize controller-server startup.
# NOTE: fixed the misspelled ``__All__`` (capital A); the misspelling meant
# ``from ... import *`` did not honor this export list.
__all__ = ['lock', 'unlock']

if os.name == 'nt':

    def lock(file):
        """Lock the file in local file system (unsupported on Windows)."""
        raise NotImplementedError('Windows is not supported.')

    def unlock(file):
        """Unlock the file in local file system (unsupported on Windows)."""
        raise NotImplementedError('Windows is not supported.')

elif os.name == 'posix':
    from fcntl import flock, LOCK_EX, LOCK_UN

    def lock(file):
        """Lock the file in local file system."""
        # Blocks until the exclusive lock is acquired.
        flock(file.fileno(), LOCK_EX)

    def unlock(file):
        """Unlock the file in local file system."""
        flock(file.fileno(), LOCK_UN)
else:
    raise RuntimeError("File Locker only support NT and Posix platforms!")
|
@ -1,67 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import socket
|
||||
from ....log_helper import get_logger
|
||||
|
||||
__all__ = ['SearchAgent']
|
||||
|
||||
_logger = get_logger(
|
||||
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
|
||||
|
||||
|
||||
class SearchAgent(object):
    """
    Client-side agent for neural architecture search: exchanges token
    lists with a remote controller server over short-lived TCP
    connections.
    """

    def __init__(self, server_ip=None, server_port=None, key=None):
        """
        Args:
            server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
            server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
            key(str): The key used to identify legal agent for controller server. Default: "light-nas"
        """
        self.server_ip = server_ip
        self.server_port = server_port
        # NOTE(review): this socket is never used by update()/next_tokens()
        # (each call opens its own connection); kept only for backward
        # compatibility with code that may inspect the attribute.
        self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._key = key

    def _request(self, message):
        """Send one request to the controller and parse the reply.

        Opens a fresh TCP connection to (server_ip, server_port), sends
        ``message``, reads at most 1024 bytes of reply and decodes it into
        a list of integer tokens. The socket is closed in all cases to
        avoid leaking file descriptors.

        Args:
            message(str): The payload to send.

        Returns:
            list<int>: The tokens parsed from the server reply.
        """
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client.connect((self.server_ip, self.server_port))
            client.send(message.encode())
            reply = client.recv(1024).decode()
        finally:
            client.close()
        return [int(token) for token in reply.strip("\n").split(",")]

    def update(self, tokens, reward):
        """
        Update the controller according to latest tokens and reward.

        Args:
            tokens(list<int>): The tokens generated in last step.
            reward(float): The reward of tokens.

        Returns:
            list<int>: The next tokens suggested by the controller.
        """
        tokens = ",".join([str(token) for token in tokens])
        return self._request("{}\t{}\t{}".format(self._key, tokens, reward))

    def next_tokens(self):
        """
        Get next tokens.

        Returns:
            list<int>: The next tokens suggested by the controller.
        """
        return self._request("next_tokens")
|
@ -1,52 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The search space used to search neural architecture"""
|
||||
|
||||
__all__ = ['SearchSpace']
|
||||
|
||||
|
||||
class SearchSpace(object):
|
||||
"""Controller for Neural Architecture Search.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def init_tokens(self):
|
||||
"""Get init tokens in search space.
|
||||
"""
|
||||
raise NotImplementedError('Abstract method.')
|
||||
|
||||
def range_table(self):
|
||||
"""Get range table of current search space.
|
||||
"""
|
||||
raise NotImplementedError('Abstract method.')
|
||||
|
||||
def create_net(self, tokens):
|
||||
"""Create networks for training and evaluation according to tokens.
|
||||
Args:
|
||||
tokens(list<int>): The tokens which represent a network.
|
||||
Return:
|
||||
(tuple): startup_program, train_program, evaluation_program, train_metrics, test_metrics
|
||||
"""
|
||||
raise NotImplementedError('Abstract method.')
|
||||
|
||||
def get_model_latency(self, program):
|
||||
"""Get model latency according to program.
|
||||
Args:
|
||||
program(Program): The program to get latency.
|
||||
Return:
|
||||
(float): model latency.
|
||||
"""
|
||||
raise NotImplementedError('Abstract method.')
|
@ -1,24 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import pruner
from .pruner import *
from . import prune_strategy
from .prune_strategy import *
from . import auto_prune_strategy
from .auto_prune_strategy import *

# Build this package's export list as a NEW list. The previous code
# (`__all__ = pruner.__all__` followed by `__all__ += ...`) aliased
# pruner.__all__ and extended it in place, polluting the submodule's
# own export list as a side effect.
__all__ = (pruner.__all__ + prune_strategy.__all__ +
           auto_prune_strategy.__all__)
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,107 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import collections
|
||||
from .... import layers
|
||||
|
||||
__all__ = ['Pruner', 'StructurePruner']
|
||||
|
||||
|
||||
class Pruner(object):
    """
    Base class of all pruners. A pruner decides which elements of a
    parameter should be removed.
    """

    def __init__(self):
        pass

    def prune(self, param):
        # Interface hook; concrete pruners override this.
        pass


class StructurePruner(Pruner):
    """
    Pruner used to pruning parameters by groups.
    """

    def __init__(self, pruning_axis, criterions):
        """
        Args:
            pruning_axis(dict): The key is the name of parameter to be pruned,
                                '*' means all the parameters.
                                The value is the axis to be used. Given a parameter
                                with shape [3, 4], the result of pruning 50% on axis 1
                                is a parameter with shape [3, 2].
            criterions(dict): The key is the name of parameter to be pruned,
                              '*' means all the parameters.
                              The value is the criterion used to sort groups for pruning.
                              It only supports 'l1_norm' currently.
        """
        self.pruning_axis = pruning_axis
        self.criterions = criterions

    def cal_pruned_idx(self, name, param, ratio, axis=None):
        """
        Calculate the index to be pruned on axis by given pruning ratio.

        Args:
            name(str): The name of parameter to be pruned.
            param(np.array): The data of parameter to be pruned.
            ratio(float): The ratio to be pruned.
            axis(int): The axis to be used for pruning given parameter.
                       If it is None, the value in self.pruning_axis will be used.
                       default: None.

        Returns:
            list<int>: The indexes to be pruned on axis.

        Raises:
            ValueError: If the configured criterion is not supported.
        """
        # Fall back to the wildcard entry when the parameter has no
        # dedicated configuration.
        criterion = self.criterions[
            name] if name in self.criterions else self.criterions['*']
        if axis is None:
            assert self.pruning_axis is not None, "pruning_axis should set if axis is None."
            axis = self.pruning_axis[
                name] if name in self.pruning_axis else self.pruning_axis['*']
        prune_num = int(round(param.shape[axis] * ratio))
        reduce_dims = [i for i in range(len(param.shape)) if i != axis]
        if criterion == 'l1_norm':
            criterions = np.sum(np.abs(param), axis=tuple(reduce_dims))
        else:
            # The original code fell through here to an unbound local
            # (`criterions`), raising a confusing NameError; fail fast
            # with an explicit message instead.
            raise ValueError(
                "Unsupported pruning criterion: {}".format(criterion))
        pruned_idx = criterions.argsort()[:prune_num]
        return pruned_idx

    def prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
        """
        Pruning a array by indexes on given axis.

        Args:
            tensor(numpy.array): The target array to be pruned.
            pruned_idx(list<int>): The indexes to be pruned.
            pruned_axis(int): The axis of given array to be pruned on.
            lazy(bool): True means setting the pruned elements to zero.
                        False means remove the pruned elements from memory.
                        default: False.

        Returns:
            numpy.array: The pruned array.
        """
        # Boolean mask over the pruned axis: True marks pruned positions.
        mask = np.zeros(tensor.shape[pruned_axis], dtype=bool)
        mask[pruned_idx] = True

        def func(data):
            # Drop the pruned entries of one 1-D slice.
            return data[~mask]

        def lazy_func(data):
            # Zero out the pruned entries of one 1-D slice in place.
            data[mask] = 0
            return data

        if lazy:
            return np.apply_along_axis(lazy_func, pruned_axis, tensor)
        else:
            return np.apply_along_axis(func, pruned_axis, tensor)
|
@ -1,113 +0,0 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
import logging
import six
import numpy as np

from .... import core
from ..core.strategy import Strategy
from ....log_helper import get_logger

# Public API of this module.
__all__ = ['MKLDNNPostTrainingQuantStrategy']

# Module-level logger; INFO level with timestamps.
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
|
||||
|
||||
|
||||
class MKLDNNPostTrainingQuantStrategy(Strategy):
    """
    The strategy for MKL-DNN Post Training quantization strategy.

    Converts a fp32 ProgramDesc with fp32 weights into an int8 ProgramDesc
    with fp32 weights that can only be executed with MKL-DNN enabled.
    """

    def __init__(self,
                 int8_model_save_path=None,
                 fp32_model_path=None,
                 cpu_math_library_num_threads=1):
        """
        Args:
            int8_model_save_path(str): int8_model_save_path is used to save an int8 ProgramDesc
                        with fp32 weights which is used for MKL-DNN int8 inference. For post training quantization,
                        MKLDNNPostTrainingQuantStrategy only supports converting a fp32 ProgramDesc
                        with fp32 weights to an int8 ProgramDesc with fp32 weights now. The saved
                        int8 ProgramDesc with fp32 weights only can be executed with MKL-DNN enabled.
                        None means it doesn't save int8 ProgramDesc with fp32 weights. default: None.
            fp32_model_path(str): fp32_model_path is used to load an original fp32 ProgramDesc with fp32 weights.
                        None means it doesn't have a fp32 ProgramDesc with fp32 weights. default: None.
            cpu_math_library_num_threads(int): The number of cpu math library threads which is used on
                        MKLDNNPostTrainingQuantStrategy. 1 means it only uses one cpu math library
                        thread. default: 1

        Raises:
            ValueError: If fp32_model_path is None (the fp32 model is the
                mandatory input of the conversion).
        """

        # (0, 0): this strategy runs entirely in on_compression_begin.
        super(MKLDNNPostTrainingQuantStrategy, self).__init__(0, 0)
        self.int8_model_save_path = int8_model_save_path
        if fp32_model_path is None:
            # ValueError (a subclass of Exception, so existing broad
            # handlers still catch it) instead of a bare Exception.
            raise ValueError("fp32_model_path is None")
        self.fp32_model_path = fp32_model_path
        self.cpu_math_library_num_threads = cpu_math_library_num_threads

    def on_compression_begin(self, context):
        """
        Prepare the data and quantify the model

        Reads one warm-up batch from ``context.eval_reader``, feeds it to
        an MKL-DNN analysis predictor to calibrate quantization scales,
        and optionally saves the resulting int8 model.
        """

        super(MKLDNNPostTrainingQuantStrategy,
              self).on_compression_begin(context)
        _logger.info('InferQuantStrategy::on_compression_begin')

        # Prepare the Analysis Config
        infer_config = core.AnalysisConfig("AnalysisConfig")
        infer_config.switch_ir_optim(True)
        infer_config.disable_gpu()
        infer_config.set_model(self.fp32_model_path)
        infer_config.enable_mkldnn()
        infer_config.set_cpu_math_library_num_threads(
            self.cpu_math_library_num_threads)

        # Prepare the data for calculating the quantization scales
        warmup_reader = context.eval_reader()
        # next() works on both Python 2 and 3 iterators; the previous
        # six.PY2/six.PY3 branches left `data` unbound if neither matched.
        data = next(warmup_reader)

        num_images = len(data)
        image_data = [img.tolist() for (img, _) in data]
        image_data = np.array(image_data).astype("float32").reshape(
            [num_images, ] + list(data[0][0].shape))
        # Flatten for PaddleTensor, then restore the logical shape below.
        image_data = image_data.ravel()
        images = core.PaddleTensor(image_data, "x")
        images.shape = [num_images, ] + list(data[0][0].shape)

        label_data = [label for (_, label) in data]
        labels = core.PaddleTensor(
            np.array(label_data).astype("int64").reshape([num_images, 1]), "y")

        warmup_data = [images, labels]

        # Enable the INT8 Quantization
        infer_config.enable_quantizer()
        infer_config.quantizer_config().set_quant_data(warmup_data)
        infer_config.quantizer_config().set_quant_batch_size(num_images)

        # Run INT8 MKL-DNN Quantization
        predictor = core.create_paddle_predictor(infer_config)
        if self.int8_model_save_path:
            if not os.path.exists(self.int8_model_save_path):
                os.makedirs(self.int8_model_save_path)
            predictor.SaveOptimModel(self.int8_model_save_path)

        # Fixed typo in the log message ('compresseion' -> 'compression').
        _logger.info(
            'Finish MKLDNNPostTrainingQuantStrategy::on_compression_begin')
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue