add initial_accumulator_value for adagrad

test=develop
revert-15774-anakin_subgraph_engine
xuezhong 6 years ago
parent c1092374fc
commit 20e579ef2a

@ -662,7 +662,8 @@ class AdagradOptimizer(Optimizer):
learning_rate, learning_rate,
epsilon=1.0e-6, epsilon=1.0e-6,
regularization=None, regularization=None,
name=None): name=None,
initial_accumulator_value=0.1):
assert learning_rate is not None assert learning_rate is not None
assert epsilon is not None assert epsilon is not None
super(AdagradOptimizer, self).__init__( super(AdagradOptimizer, self).__init__(
@ -671,6 +672,7 @@ class AdagradOptimizer(Optimizer):
name=name) name=name)
self.type = "adagrad" self.type = "adagrad"
self._epsilon = epsilon self._epsilon = epsilon
self.initial_accumulator_value = initial_accumulator_value
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
@ -683,6 +685,16 @@ class AdagradOptimizer(Optimizer):
moment_acc = self._get_accumulator(self._moment_acc_str, moment_acc = self._get_accumulator(self._moment_acc_str,
param_and_grad[0]) param_and_grad[0])
startup_block = framework.default_startup_program().global_block()
startup_block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [moment_acc]},
attrs={
'dtype': moment_acc.dtype,
'value': self.initial_accumulator_value,
'shape': moment_acc.shape,
})
# Create the adagrad optimizer op # Create the adagrad optimizer op
adagrad_op = block.append_op( adagrad_op = block.append_op(

@ -274,7 +274,7 @@ class TestAdagradOptimizer(unittest.TestCase):
# Check init_program # Check init_program
init_ops = init_program.global_block().ops init_ops = init_program.global_block().ops
self.assertEqual(len(init_ops), 2) self.assertEqual(len(init_ops), 3)
self.assertEqual(init_ops[0].type, "fill_constant") self.assertEqual(init_ops[0].type, "fill_constant")
self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate) self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
self.assertEqual(init_ops[1].type, "fill_constant") self.assertEqual(init_ops[1].type, "fill_constant")

Loading…
Cancel
Save