@@ -62,13 +62,23 @@ def init_on_cpu():
     _force_init_on_cpu_ = pre_state


-def _is_inited_by(block, var, init_op_type):
+def _is_inited_by(block, var, init_op_types):
     for op in block.ops:
-        if var.name in op.output_arg_names and op.type == init_op_type:
+        if var.name in op.output_arg_names and op.type in init_op_types:
             return op
     return None


+def _is_duplicated_init_op(op1, op2):
+    if op1.block == op2.block and \
+            op1.type == op2.type and \
+            op1.output_arg_names == op2.output_arg_names and \
+            op1.idx != op2.idx and \
+            op1.all_attrs() == op2.all_attrs():
+        return True
+    return False
+
+
 class Initializer(object):
     """Base class for variable initializers
@@ -154,9 +164,7 @@ class ConstantInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'fill_constant')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var, ['fill_constant'])
         # Initialization Ops should be prepended and not appended
         op = block._prepend_op(
             type="fill_constant",
@@ -167,6 +175,9 @@ class ConstantInitializer(Initializer):
                 "value": float(self._value),
                 'force_cpu': self._force_cpu or force_init_on_cpu()
             })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
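
With the two ConstantInitializer hunks above applied, applying the same initializer twice to one variable is expected to leave a single fill_constant op in the block: the second call prepends a new op, detects it as a duplicate of the existing one, removes it again, and returns the original op. A hedged usage sketch; the Program/create_var scaffolding below is only illustrative and not part of the change.

import paddle.fluid as fluid

startup = fluid.Program()
block = startup.global_block()
w = block.create_var(name="w", shape=[2, 2], dtype="float32",
                     persistable=True)

init = fluid.initializer.Constant(value=1.0)
op1 = init(w, block)  # prepends a fill_constant op for w
op2 = init(w, block)  # duplicate is removed again; op1 is expected back
assert op1 is op2
assert len([op for op in block.ops if op.type == "fill_constant"]) == 1
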
@@ -209,9 +220,7 @@ class UniformInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'uniform_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var, ['uniform_random'])
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
             self._seed = block.program.random_seed
@@ -225,6 +234,9 @@ class UniformInitializer(Initializer):
                 "max": self._high,
                 "seed": self._seed
             })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
@@ -266,9 +278,7 @@ class NormalInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'gaussian_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var, ['gaussian_random'])
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
             self._seed = block.program.random_seed
@@ -282,6 +292,9 @@ class NormalInitializer(Initializer):
                 "std": self._std_dev,
                 "seed": self._seed
             })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
@@ -351,9 +364,8 @@ class XavierInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'uniform_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var,
+                                ['uniform_random', 'gaussian_random'])

         f_in, f_out = self._compute_fans(var)
@@ -389,6 +401,9 @@ class XavierInitializer(Initializer):
                     "std": std,
                     "seed": self._seed
                 })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
@@ -454,13 +469,8 @@ class MSRAInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'uniform_random')
-        if init_op is not None:
-            return init_op
-
-        init_op = _is_inited_by(block, var, 'gaussian_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var,
+                                ['uniform_random', 'gaussian_random'])

         f_in, f_out = self._compute_fans(var)
@@ -495,6 +505,9 @@ class MSRAInitializer(Initializer):
                     "std": std,
                     "seed": self._seed
                 })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
@@ -561,8 +574,6 @@ class BilinearInitializer(Initializer):
             raise ValueError("block must be framework.Block.")

         init_op = _is_inited_by(block, var, 'assign_value')
-        if init_op is not None:
-            return init_op

         shape = var.shape
         if len(shape) != 4:
@@ -597,6 +608,9 @@ class BilinearInitializer(Initializer):
                 'shape': list(shape),
                 value_name: values
             })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op