|
|
@ -98,7 +98,7 @@ class TestMomentumOptimizer(unittest.TestCase):
|
|
|
|
self.assertEqual(len(opts), 1)
|
|
|
|
self.assertEqual(len(opts), 1)
|
|
|
|
sgd_op = opts[0]
|
|
|
|
sgd_op = opts[0]
|
|
|
|
self.assertEqual(sgd_op.type, "momentum")
|
|
|
|
self.assertEqual(sgd_op.type, "momentum")
|
|
|
|
self.assertFalse(sgd_op.attr('useNesterov'))
|
|
|
|
self.assertFalse(sgd_op.attr('use_nesterov'))
|
|
|
|
|
|
|
|
|
|
|
|
# Check accumulators
|
|
|
|
# Check accumulators
|
|
|
|
accumulators = momentum_optimizer.get_accumulators()
|
|
|
|
accumulators = momentum_optimizer.get_accumulators()
|
|
|
@ -143,7 +143,7 @@ class TestMomentumOptimizer(unittest.TestCase):
|
|
|
|
self.assertEqual(len(opts), 1)
|
|
|
|
self.assertEqual(len(opts), 1)
|
|
|
|
sgd_op = opts[0]
|
|
|
|
sgd_op = opts[0]
|
|
|
|
self.assertEqual(sgd_op.type, "momentum")
|
|
|
|
self.assertEqual(sgd_op.type, "momentum")
|
|
|
|
self.assertTrue(sgd_op.attr('useNesterov'))
|
|
|
|
self.assertTrue(sgd_op.attr('use_nesterov'))
|
|
|
|
|
|
|
|
|
|
|
|
# Check accumulators
|
|
|
|
# Check accumulators
|
|
|
|
accumulators = momentum_optimizer.get_accumulators()
|
|
|
|
accumulators = momentum_optimizer.get_accumulators()
|
|
|
|