@@ -40,6 +40,7 @@ class LightNASStrategy(Strategy):
                  controller=None,
                  end_epoch=1000,
                  target_flops=629145600,
+                 target_latency=0,
                  retrain_epoch=1,
                  metric_name='top1_acc',
                  server_ip=None,
@@ -53,6 +54,7 @@ class LightNASStrategy(Strategy):
             controller(searcher.Controller): The searching controller. Default: None.
             end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 0
             target_flops(int): The constraint of FLOPS.
+            target_latency(float): The constraint of latency.
             retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1.
             metric_name(str): The metric used to evaluate the model.
                 It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
@@ -66,6 +68,7 @@ class LightNASStrategy(Strategy):
         self.start_epoch = 0
         self.end_epoch = end_epoch
         self._max_flops = target_flops
+        self._max_latency = target_latency
         self._metric_name = metric_name
         self._controller = controller
         self._retrain_epoch = 0
@@ -86,8 +89,6 @@ class LightNASStrategy(Strategy):
 
     def on_compression_begin(self, context):
         self._current_tokens = context.search_space.init_tokens()
-        constrain_func = functools.partial(
-            self._constrain_func, context=context)
         self._controller.reset(context.search_space.range_table(),
                                self._current_tokens, None)
@@ -127,15 +128,6 @@ class LightNASStrategy(Strategy):
             d[key] = self.__dict__[key]
         return d
 
-    def _constrain_func(self, tokens, context=None):
-        """Check whether the tokens meet constraint."""
-        _, _, test_prog, _, _, _, _ = context.search_space.create_net(tokens)
-        flops = GraphWrapper(test_prog).flops()
-        if flops <= self._max_flops:
-            return True
-        else:
-            return False
-
     def on_epoch_begin(self, context):
         if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
                 self._retrain_epoch == 0 or
@@ -144,13 +136,20 @@ class LightNASStrategy(Strategy):
             for _ in range(self._max_try_times):
                 startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
                     self._current_tokens)
-                _logger.info("try [{}]".format(self._current_tokens))
                 context.eval_graph.program = test_p
                 flops = context.eval_graph.flops()
-                if flops <= self._max_flops:
-                    break
+                if self._max_latency > 0:
+                    latency = context.search_space.get_model_latency(test_p)
+                    _logger.info("try [{}] with latency {} flops {}".format(
+                        self._current_tokens, latency, flops))
                 else:
+                    _logger.info("try [{}] with flops {}".format(
+                        self._current_tokens, flops))
+                if flops > self._max_flops or (self._max_latency > 0 and
+                                               latency > self._max_latency):
                     self._current_tokens = self._search_agent.next_tokens()
+                else:
+                    break
 
             context.train_reader = train_reader
             context.eval_reader = test_reader
@@ -173,7 +172,17 @@ class LightNASStrategy(Strategy):
             flops = context.eval_graph.flops()
             if flops > self._max_flops:
                 self._current_reward = 0.0
-            _logger.info("reward: {}; flops: {}; tokens: {}".format(
-                self._current_reward, flops, self._current_tokens))
+            if self._max_latency > 0:
+                test_p = context.search_space.create_net(self._current_tokens)[
+                    2]
+                latency = context.search_space.get_model_latency(test_p)
+                if latency > self._max_latency:
+                    self._current_reward = 0.0
+                _logger.info("reward: {}; latency: {}; flops: {}; tokens: {}".
+                             format(self._current_reward, latency, flops,
+                                    self._current_tokens))
+            else:
+                _logger.info("reward: {}; flops: {}; tokens: {}".format(
+                    self._current_reward, flops, self._current_tokens))
             self._current_tokens = self._search_agent.update(
                 self._current_tokens, self._current_reward)
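
Taken together, the patch replaces the FLOPS-only `_constrain_func` with an inline check that gates candidate architectures on two budgets: FLOPS always, latency only when a positive `target_latency` is configured (the new default of 0 leaves it disabled). Below is a minimal sketch of that acceptance rule, not part of the patch itself; the free-standing names stand in for the values the strategy reads from `context.eval_graph.flops()` and `search_space.get_model_latency()`.

```python
# Illustrative sketch only -- mirrors the rejection condition added in
# on_epoch_begin:
#     flops > self._max_flops or (self._max_latency > 0 and
#                                 latency > self._max_latency)
def rejected(flops, max_flops, latency, max_latency):
    """True if a candidate architecture busts either budget."""
    return flops > max_flops or (max_latency > 0 and latency > max_latency)


# A 500 MFLOP candidate passes under the 629145600-FLOP default when the
# latency constraint is disabled (max_latency == 0) ...
assert not rejected(flops=5 * 10**8, max_flops=629145600,
                    latency=12.0, max_latency=0)
# ... but is rejected once a latency budget of 10.0 is set and exceeded.
assert rejected(flops=5 * 10**8, max_flops=629145600,
                latency=12.0, max_latency=10.0)
```

On rejection, the search loop asks the agent for new tokens and retries (up to `self._max_try_times`); the final hunk applies the same budgets to zero out the reward before `self._search_agent.update(...)` reports it back to the controller.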