|
|
|
@ -62,6 +62,13 @@ def init_on_cpu():
|
|
|
|
|
_force_init_on_cpu_ = pre_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_inited_by(block, var, init_op_type):
|
|
|
|
|
for op in block.ops:
|
|
|
|
|
if var.name in op.output_arg_names and op.type == init_op_type:
|
|
|
|
|
return op
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Initializer(object):
|
|
|
|
|
"""Base class for variable initializers
|
|
|
|
|
|
|
|
|
@ -147,6 +154,9 @@ class ConstantInitializer(Initializer):
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(var, framework.Variable)
|
|
|
|
|
assert isinstance(block, framework.Block)
|
|
|
|
|
init_op = _is_inited_by(block, var, 'uniform_random')
|
|
|
|
|
if init_op is not None:
|
|
|
|
|
return init_op
|
|
|
|
|
# Initialization Ops should be prepended and not appended
|
|
|
|
|
op = block._prepend_op(
|
|
|
|
|
type="fill_constant",
|
|
|
|
@ -199,6 +209,9 @@ class UniformInitializer(Initializer):
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(var, framework.Variable)
|
|
|
|
|
assert isinstance(block, framework.Block)
|
|
|
|
|
init_op = _is_inited_by(block, var, 'uniform_random')
|
|
|
|
|
if init_op is not None:
|
|
|
|
|
return init_op
|
|
|
|
|
# Initialization Ops should be prepended and not appended
|
|
|
|
|
if self._seed == 0:
|
|
|
|
|
self._seed = block.program.random_seed
|
|
|
|
@ -253,6 +266,9 @@ class NormalInitializer(Initializer):
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(var, framework.Variable)
|
|
|
|
|
assert isinstance(block, framework.Block)
|
|
|
|
|
init_op = _is_inited_by(block, var, 'uniform_random')
|
|
|
|
|
if init_op is not None:
|
|
|
|
|
return init_op
|
|
|
|
|
# Initialization Ops should be prepended and not appended
|
|
|
|
|
if self._seed == 0:
|
|
|
|
|
self._seed = block.program.random_seed
|
|
|
|
@ -335,6 +351,10 @@ class XavierInitializer(Initializer):
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(var, framework.Variable)
|
|
|
|
|
assert isinstance(block, framework.Block)
|
|
|
|
|
init_op = _is_inited_by(block, var, 'uniform_random')
|
|
|
|
|
if init_op is not None:
|
|
|
|
|
return init_op
|
|
|
|
|
|
|
|
|
|
f_in, f_out = self._compute_fans(var)
|
|
|
|
|
|
|
|
|
|
# If fan_in and fan_out are passed, use them
|
|
|
|
@ -434,6 +454,10 @@ class MSRAInitializer(Initializer):
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(var, framework.Variable)
|
|
|
|
|
assert isinstance(block, framework.Block)
|
|
|
|
|
init_op = _is_inited_by(block, var, 'uniform_random')
|
|
|
|
|
if init_op is not None:
|
|
|
|
|
return init_op
|
|
|
|
|
|
|
|
|
|
f_in, f_out = self._compute_fans(var)
|
|
|
|
|
|
|
|
|
|
# If fan_in is passed, use it
|
|
|
|
@ -532,6 +556,10 @@ class BilinearInitializer(Initializer):
|
|
|
|
|
if not isinstance(block, framework.Block):
|
|
|
|
|
raise ValueError("block must be framework.Block.")
|
|
|
|
|
|
|
|
|
|
init_op = _is_inited_by(block, var, 'uniform_random')
|
|
|
|
|
if init_op is not None:
|
|
|
|
|
return init_op
|
|
|
|
|
|
|
|
|
|
shape = var.shape
|
|
|
|
|
if len(shape) != 4:
|
|
|
|
|
raise ValueError("the length of shape must be 4.")
|
|
|
|
|