@@ -103,6 +103,51 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
        # recompute
        self.assertIn('subprog', ''.join(outs))

    def test_amp_recompute_lars_optimizer(self):
        """ test amp + recompute + lars """
        train_prog, startup_prog = fluid.Program(), fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'amp')
        self.set_strategy(strategy, 'recompute')
        self.set_strategy(strategy, 'lars')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        strategy = fleet._final_strategy()

        ops = [op.type for op in avg_cost.block.ops]
        outs = [
            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
        ]
        # amp
        self.assertIn('cast', ops)
        self.assertIn('check_finite_and_unscale', ops)

        # recompute
        self.assertIn('subprog', ''.join(outs))

        # lars
        self.assertIn('lars_momentum', ops)
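
For context, the combination exercised above maps onto the public switches of `fleet.DistributedStrategy`. Below is a minimal sketch, assuming a collective launch environment; the checkpoint name and numeric config values are illustrative placeholders, not part of this patch:

```python
# Hedged sketch: enabling amp + recompute + lars through the public fleet API.
# "fc_1.tmp_0" and the numeric configs are illustrative placeholders.
import paddle
import paddle.distributed.fleet as fleet

fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True        # inserts cast and check_finite_and_unscale ops
strategy.recompute = True  # re-runs checkpointed segments in the backward pass
strategy.recompute_configs = {"checkpoints": ["fc_1.tmp_0"]}
strategy.lars = True       # layer-wise adaptive rate scaling
strategy.lars_configs = {"lars_coeff": 0.001, "lars_weight_decay": 0.0005}

# lars applies on top of a Momentum inner optimizer, which is presumably what
# self.optimizer() builds by default in the test above (no name is passed)
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(loss) would then emit the lars_momentum ops asserted above
```

This mirrors what the test verifies at the op level: `cast`/`check_finite_and_unscale` from amp, `subprog`-suffixed outputs from recompute, and `lars_momentum` as the final update op.
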
    def test_amp_recompute_lamb_optimizer(self):
        """ test amp + recompute + lamb """
        train_prog, startup_prog = fluid.Program(), fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'amp')
        self.set_strategy(strategy, 'recompute')
        self.set_strategy(strategy, 'lamb')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog, 'adam')

        ops = [op.type for op in avg_cost.block.ops]
        outs = [
            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
        ]
        # amp
        self.assertIn('cast', ops)
        self.assertIn('check_finite_and_unscale', ops)

        # recompute
        self.assertIn('subprog', ''.join(outs))

        # lamb
        self.assertIn('lamb', ops)
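
The lamb variant is analogous; the test passes 'adam' because the lamb meta optimizer only applies when the inner optimizer is Adam. A minimal sketch under the same assumptions (placeholder checkpoint name, illustrative config values):

```python
# Hedged sketch: enabling amp + recompute + lamb through the public fleet API.
import paddle
import paddle.distributed.fleet as fleet

fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.recompute = True
strategy.recompute_configs = {"checkpoints": ["fc_1.tmp_0"]}  # placeholder name
strategy.lamb = True  # layer-wise adaptive moments on top of the adam update
strategy.lamb_configs = {"lamb_weight_decay": 0.01}  # illustrative value

optimizer = paddle.optimizer.Adam(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(loss) would then emit the 'lamb' op asserted above
```
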
if __name__ == "__main__":
    unittest.main()