|
|
|
@ -35,7 +35,8 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
|
|
|
|
def check_dgc_momentum_optimizer(self,
|
|
|
|
def check_dgc_momentum_optimizer(self,
|
|
|
|
dims=[5, 10, 8],
|
|
|
|
dims=[5, 10, 8],
|
|
|
|
name="momentum",
|
|
|
|
name="momentum",
|
|
|
|
regularization=None):
|
|
|
|
regularization=None,
|
|
|
|
|
|
|
|
use_recompute=False):
|
|
|
|
init_program = framework.Program()
|
|
|
|
init_program = framework.Program()
|
|
|
|
program = framework.Program()
|
|
|
|
program = framework.Program()
|
|
|
|
block = program.global_block()
|
|
|
|
block = program.global_block()
|
|
|
|
@ -72,6 +73,13 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
|
|
|
|
local_grad_clip_norm=1.0,
|
|
|
|
local_grad_clip_norm=1.0,
|
|
|
|
num_trainers=2,
|
|
|
|
num_trainers=2,
|
|
|
|
regularization=regularization)
|
|
|
|
regularization=regularization)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if use_recompute:
|
|
|
|
|
|
|
|
dgc_momentum_optimizer = optimizer.RecomputeOptimizer(
|
|
|
|
|
|
|
|
dgc_momentum_optimizer)
|
|
|
|
|
|
|
|
dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators
|
|
|
|
|
|
|
|
dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str
|
|
|
|
|
|
|
|
|
|
|
|
mean_out = block.create_var(
|
|
|
|
mean_out = block.create_var(
|
|
|
|
dtype="float32", shape=[1], lod_level=0, name="mean.out")
|
|
|
|
dtype="float32", shape=[1], lod_level=0, name="mean.out")
|
|
|
|
block.append_op(
|
|
|
|
block.append_op(
|
|
|
|
@ -112,8 +120,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
|
|
|
|
self.assertAlmostEqual(op.attr('regular_coeff'), coeff)
|
|
|
|
self.assertAlmostEqual(op.attr('regular_coeff'), coeff)
|
|
|
|
print("dgc regular_coeff=" + str(coeff))
|
|
|
|
print("dgc regular_coeff=" + str(coeff))
|
|
|
|
|
|
|
|
|
|
|
|
with open("test_dgc_optimizer_" + name + ".log", "w") as f:
|
|
|
|
# for local test debug
|
|
|
|
program_to_code(program, fout=f)
|
|
|
|
#with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f:
|
|
|
|
|
|
|
|
# program_to_code(program, fout=f)
|
|
|
|
|
|
|
|
|
|
|
|
def test_momentum_without_dgc(self):
|
|
|
|
def test_momentum_without_dgc(self):
|
|
|
|
self.check_dgc_momentum_optimizer(
|
|
|
|
self.check_dgc_momentum_optimizer(
|
|
|
|
@ -130,6 +139,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
|
|
|
|
self.check_dgc_momentum_optimizer(
|
|
|
|
self.check_dgc_momentum_optimizer(
|
|
|
|
dims=[16, 1024, 8], name="dgc_momentum")
|
|
|
|
dims=[16, 1024, 8], name="dgc_momentum")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_momentum_with_dgc_recompute(self):
|
|
|
|
|
|
|
|
# 16 * 1024 = 16384, use dgc momentum
|
|
|
|
|
|
|
|
self.check_dgc_momentum_optimizer(
|
|
|
|
|
|
|
|
dims=[16, 1024, 8],
|
|
|
|
|
|
|
|
name="dgc_momentum",
|
|
|
|
|
|
|
|
regularization=regularizer.L2Decay(1e-4),
|
|
|
|
|
|
|
|
use_recompute=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
unittest.main()
|
|
|
|
unittest.main()
|
|
|
|
|