|
|
|
@ -20,6 +20,7 @@ from .... import profiler
|
|
|
|
|
from .... import scope_guard
|
|
|
|
|
from ....data_feeder import DataFeeder
|
|
|
|
|
from ....log_helper import get_logger
|
|
|
|
|
from ....reader import PyReader
|
|
|
|
|
from ..graph import *
|
|
|
|
|
from .config import ConfigFactory
|
|
|
|
|
import numpy as np
|
|
|
|
@ -185,10 +186,16 @@ class Context(object):
|
|
|
|
|
s_time = time.time()
|
|
|
|
|
reader = self.eval_reader
|
|
|
|
|
if sampled_rate:
|
|
|
|
|
assert (not isinstance(reader, Variable))
|
|
|
|
|
assert (sampled_rate > 0)
|
|
|
|
|
assert (self.cache_path is not None)
|
|
|
|
|
_logger.info('sampled_rate: {}; cached_id: {}'.format(sampled_rate,
|
|
|
|
|
cached_id))
|
|
|
|
|
reader = cached_reader(reader, sampled_rate, self.cache_path,
|
|
|
|
|
cached_id)
|
|
|
|
|
|
|
|
|
|
if isinstance(reader, Variable):
|
|
|
|
|
if isinstance(reader, Variable) or (isinstance(reader, PyReader) and
|
|
|
|
|
(not reader._iterable)):
|
|
|
|
|
reader.start()
|
|
|
|
|
try:
|
|
|
|
|
while True:
|
|
|
|
@ -249,7 +256,8 @@ class Compressor(object):
|
|
|
|
|
checkpoint_path=None,
|
|
|
|
|
train_optimizer=None,
|
|
|
|
|
distiller_optimizer=None,
|
|
|
|
|
search_space=None):
|
|
|
|
|
search_space=None,
|
|
|
|
|
log_period=20):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
place(fluid.Place): The device place where the compression job running.
|
|
|
|
@ -294,6 +302,7 @@ class Compressor(object):
|
|
|
|
|
student-net in fine-tune stage.
|
|
|
|
|
search_space(slim.nas.SearchSpace): The instance that define the searching space. It must inherite
|
|
|
|
|
slim.nas.SearchSpace class and overwrite the abstract methods.
|
|
|
|
|
log_period(int): The period of print log of training.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
assert train_feed_list is None or isinstance(
|
|
|
|
@ -329,6 +338,8 @@ class Compressor(object):
|
|
|
|
|
self.init_model = None
|
|
|
|
|
|
|
|
|
|
self.search_space = search_space
|
|
|
|
|
self.log_period = log_period
|
|
|
|
|
assert (log_period > 0)
|
|
|
|
|
|
|
|
|
|
def _add_strategy(self, strategy):
|
|
|
|
|
"""
|
|
|
|
@ -357,6 +368,7 @@ class Compressor(object):
|
|
|
|
|
|
|
|
|
|
if 'eval_epoch' in factory.compressor:
|
|
|
|
|
self.eval_epoch = factory.compressor['eval_epoch']
|
|
|
|
|
assert (self.eval_epoch > 0)
|
|
|
|
|
|
|
|
|
|
def _init_model(self, context):
|
|
|
|
|
"""
|
|
|
|
@ -414,7 +426,7 @@ class Compressor(object):
|
|
|
|
|
else:
|
|
|
|
|
strategies = pickle.load(
|
|
|
|
|
strategy_file, encoding='bytes')
|
|
|
|
|
|
|
|
|
|
assert (len(self.strategies) == len(strategies))
|
|
|
|
|
for s, s1 in zip(self.strategies, strategies):
|
|
|
|
|
s1.__dict__.update(s.__dict__)
|
|
|
|
|
|
|
|
|
@ -472,7 +484,9 @@ class Compressor(object):
|
|
|
|
|
context.optimize_graph.program).with_data_parallel(
|
|
|
|
|
loss_name=context.optimize_graph.out_nodes['loss'])
|
|
|
|
|
|
|
|
|
|
if isinstance(context.train_reader, Variable):
|
|
|
|
|
if isinstance(context.train_reader, Variable) or (
|
|
|
|
|
isinstance(context.train_reader,
|
|
|
|
|
PyReader) and (not context.train_reader._iterable)):
|
|
|
|
|
context.train_reader.start()
|
|
|
|
|
try:
|
|
|
|
|
while True:
|
|
|
|
@ -482,7 +496,7 @@ class Compressor(object):
|
|
|
|
|
results = executor.run(context.optimize_graph,
|
|
|
|
|
context.scope)
|
|
|
|
|
results = [float(np.mean(result)) for result in results]
|
|
|
|
|
if context.batch_id % 20 == 0:
|
|
|
|
|
if context.batch_id % self.log_period == 0:
|
|
|
|
|
_logger.info("epoch:{}; batch_id:{}; {} = {}".format(
|
|
|
|
|
context.epoch_id, context.batch_id,
|
|
|
|
|
context.optimize_graph.out_nodes.keys(
|
|
|
|
@ -502,7 +516,7 @@ class Compressor(object):
|
|
|
|
|
context.scope,
|
|
|
|
|
data=data)
|
|
|
|
|
results = [float(np.mean(result)) for result in results]
|
|
|
|
|
if context.batch_id % 20 == 0:
|
|
|
|
|
if context.batch_id % self.log_period == 0:
|
|
|
|
|
_logger.info("epoch:{}; batch_id:{}; {} = {}".format(
|
|
|
|
|
context.epoch_id, context.batch_id,
|
|
|
|
|
context.optimize_graph.out_nodes.keys(
|
|
|
|
|