"add init seed" (#6221)

* "add init seed"

* "fix compile error"

* "add program level seed setting"

* "fixed based on comments"
release/0.11.0
dzhwinter 8 years ago committed by GitHub
parent a0c1190f76
commit 4eac85c60b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -512,6 +512,7 @@ class Program(object):
self.desc = core.ProgramDesc()
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
self._seed = 0
def __str__(self):
return self.to_string(True)
@ -564,6 +565,16 @@ class Program(object):
p.sync_with_cpp()
return p
@property
def random_seed(self):
return self._seed
@random_seed.setter
def random_seed(self, seed):
if not isinstance(seed, int):
raise ValueError("Seed must be a integer.")
self._seed = seed
def __repr__(self):
return str(self)

@ -132,6 +132,8 @@ class UniformInitializer(Initializer):
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
op = block.prepend_op(
type="uniform_random",
outputs={"Out": var},
@ -180,6 +182,8 @@ class NormalInitializer(Initializer):
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
op = block.prepend_op(
type="gaussian_random",
outputs={"Out": var},
@ -255,6 +259,9 @@ class XavierInitializer(Initializer):
fan_in = f_in if self._fan_in is None else self._fan_in
fan_out = f_out if self._fan_out is None else self._fan_out
if self._seed == 0:
self._seed = block.program.random_seed
if self._uniform:
limit = np.sqrt(6.0 / float(fan_in + fan_out))
op = block.prepend_op(
@ -338,6 +345,9 @@ class MSRAInitializer(Initializer):
# If fan_in is passed, use it
fan_in = f_in if self._fan_in is None else self._fan_in
if self._seed == 0:
self._seed = block.program.random_seed
if self._uniform:
limit = np.sqrt(6.0 / float(fan_in))
op = block.prepend_op(

@ -60,6 +60,29 @@ class TestUniformInitializer(unittest.TestCase):
self.assertAlmostEqual(init_op.attr('max'), 1.0, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
def test_uniform_initializer_random_seed(self):
"""Test the uniform initializer with manually setting seed
"""
program = framework.Program()
program.random_seed = 123
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer())
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer(seed=456))
init_op = block.ops[1]
self.assertEqual(init_op.attr("seed"), 123)
init_op1 = block.ops[0]
self.assertEqual(init_op1.attr("seed"), 456)
def test_uniform_initializer(self):
"""Test uniform initializer with supplied attributes
"""

Loading…
Cancel
Save