|
|
@ -13,7 +13,7 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
from ..core.strategy import Strategy
|
|
|
|
from ..core.strategy import Strategy
|
|
|
|
from ....framework import Program, program_guard
|
|
|
|
from ....framework import Program, Variable, program_guard
|
|
|
|
from .... import Executor
|
|
|
|
from .... import Executor
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
|
@ -74,8 +74,17 @@ class DistillationStrategy(Strategy):
|
|
|
|
startup_program = Program()
|
|
|
|
startup_program = Program()
|
|
|
|
with program_guard(graph.program, startup_program):
|
|
|
|
with program_guard(graph.program, startup_program):
|
|
|
|
context.distiller_optimizer._name = 'distillation_optimizer'
|
|
|
|
context.distiller_optimizer._name = 'distillation_optimizer'
|
|
|
|
context.distiller_optimizer.minimize(
|
|
|
|
|
|
|
|
graph.var(graph.out_nodes['loss'])._var)
|
|
|
|
# The learning rate variable may be created in other program.
|
|
|
|
|
|
|
|
# Update information in optimizer to make
|
|
|
|
|
|
|
|
# learning rate variable being accessible in current program.
|
|
|
|
|
|
|
|
optimizer = context.distiller_optimizer
|
|
|
|
|
|
|
|
if isinstance(optimizer._learning_rate, Variable):
|
|
|
|
|
|
|
|
optimizer._learning_rate_map[
|
|
|
|
|
|
|
|
graph.program] = optimizer._learning_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)
|
|
|
|
|
|
|
|
|
|
|
|
exe = Executor(context.place)
|
|
|
|
exe = Executor(context.place)
|
|
|
|
exe.run(startup_program, scope=context.scope)
|
|
|
|
exe.run(startup_program, scope=context.scope)
|
|
|
|
|
|
|
|
|
|
|
|