|
|
|
@ -133,11 +133,18 @@ class LightNASStrategy(Strategy):
|
|
|
|
|
self._retrain_epoch == 0 or
|
|
|
|
|
(context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
|
|
|
|
|
_logger.info("light nas strategy on_epoch_begin")
|
|
|
|
|
min_flops = -1
|
|
|
|
|
for _ in range(self._max_try_times):
|
|
|
|
|
startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
|
|
|
|
|
self._current_tokens)
|
|
|
|
|
context.eval_graph.program = test_p
|
|
|
|
|
flops = context.eval_graph.flops()
|
|
|
|
|
if min_flops == -1:
|
|
|
|
|
min_flops = flops
|
|
|
|
|
min_tokens = self._current_tokens[:]
|
|
|
|
|
else:
|
|
|
|
|
if flops < min_flops:
|
|
|
|
|
min_tokens = self._current_tokens[:]
|
|
|
|
|
if self._max_latency > 0:
|
|
|
|
|
latency = context.search_space.get_model_latency(test_p)
|
|
|
|
|
_logger.info("try [{}] with latency {} flops {}".format(
|
|
|
|
@ -147,7 +154,8 @@ class LightNASStrategy(Strategy):
|
|
|
|
|
self._current_tokens, flops))
|
|
|
|
|
if flops > self._max_flops or (self._max_latency > 0 and
|
|
|
|
|
latency > self._max_latency):
|
|
|
|
|
self._current_tokens = self._search_agent.next_tokens()
|
|
|
|
|
self._current_tokens = self._controller.next_tokens(
|
|
|
|
|
min_tokens)
|
|
|
|
|
else:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|