Adding a framework for variable initializers (#5232)
parent 9b70b6a1bb
commit 61eafbe09d

@@ -0,0 +1,109 @@
import paddle.v2.framework.framework as framework

__all__ = ['ConstantInitializer', 'UniformInitializer']


class Initializer(object):
    """Base class for variable initializers

    Defines the common interface of variable initializers.
    They add operations to the init program that are used
    to initialize variables. Users should not use this class
    directly, but need to use one of its implementations.
    """

    def __init__(self):
        pass

    def __call__(self, param, block):
        """Add corresponding initialization operations to the network
        """
        raise NotImplementedError()
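

# Illustrative sketch: a custom initializer only needs to subclass Initializer
# and prepend its own init op in __call__.  The "gaussian_random" op type and
# the "mean"/"std"/"seed" attribute names used here are assumptions made for
# this example; they mirror how the built-in initializers below wire up
# "fill_constant" and "uniform_random".
class NormalInitializer(Initializer):
    """Sketch of a normal (Gaussian) distribution initializer"""

    def __init__(self, mean=0.0, std=1.0, seed=0):
        super(NormalInitializer, self).__init__()
        self._mean = mean
        self._std = std
        self._seed = seed

    def __call__(self, var, block):
        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)
        # Prepend the (assumed) gaussian_random op, following the same
        # pattern as the built-in initializers in this module
        op = block.prepend_op(
            type="gaussian_random",
            outputs={"Out": var},
            attrs={
                "shape": var.shape,
                "data_type": int(var.data_type),
                "mean": self._mean,
                "std": self._std,
                "seed": self._seed
            })
        var.op = op
        return op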


class ConstantInitializer(Initializer):
    """Implements the constant initializer
    """

    def __init__(self, value=0.0):
        """Constructor for ConstantInitializer

        Args:
            value: constant value to initialize the variable
        """
        assert value is not None
        super(ConstantInitializer, self).__init__()
        self._value = value

    def __call__(self, var, block):
        """Add constant initialization ops for a variable

        Args:
            var: Variable that needs to be initialized
            block: The block in which initialization ops
                   should be added

        Returns:
            the initialization op
        """
        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)
        # Initialization Ops should be prepended and not appended
        op = block.prepend_op(
            type="fill_constant",
            outputs={"Out": var},
            attrs={
                "shape": var.shape,
                "data_type": int(var.data_type),
                "value": self._value
            })
        var.op = op
        return op
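
# Usage sketch (assumptions noted): given a parameter variable `w` and the
# init block that owns it -- both assumed to have been created elsewhere via
# framework.Block / framework.Variable -- applying the initializer prepends a
# fill_constant op to that block, e.g.:
#
#     init_op = ConstantInitializer(value=1.0)(w, w.block)
#
# `w.block` is assumed here to be the Block the variable was created in.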


class UniformInitializer(Initializer):
    """Implements the random uniform distribution initializer
    """

    def __init__(self, low=-1.0, high=1.0, seed=0):
        """Constructor for UniformInitializer

        Args:
            low: lower boundary of the uniform distribution
            high: upper boundary of the uniform distribution
            seed: random seed
        """
        assert low is not None
        assert high is not None
        assert seed is not None
        super(UniformInitializer, self).__init__()
        self._low = low
        self._high = high
        self._seed = seed

    def __call__(self, var, block):
        """Add uniform distribution initialization ops for a variable

        Args:
            var: Variable that needs to be initialized
            block: The block in which initialization ops
                   should be added

        Returns:
            the initialization op
        """
        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)
        # Initialization Ops should be prepended and not appended
        op = block.prepend_op(
            type="uniform_random",
            outputs={"Out": var},
            attrs={
                "shape": var.shape,
                "data_type": int(var.data_type),
                "min": self._low,
                "max": self._high,
                "seed": self._seed
            })
        var.op = op
        return op
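

# Minimal usage sketch.  The Program / global_block / create_var calls below
# are assumptions about the surrounding framework module (they are not defined
# in this file); the sketch only shows how an initializer is applied to a
# variable inside a block.
if __name__ == '__main__':
    prog = framework.Program()      # assumed program container
    block = prog.global_block()     # assumed accessor for the root block
    w = block.create_var(           # assumed variable factory on Block
        name="w", shape=[5, 10], dtype="float32")
    # Prepend a uniform_random init op for `w` to the block
    init_op = UniformInitializer(low=-0.5, high=0.5, seed=10)(w, block)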