|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
|
|
|
|
from .wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
from .core import VarDesc
|
|
|
|
|
from . import unique_name
|
|
|
|
|
from .imperative import base
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
|
|
|
|
@ -165,6 +166,7 @@ class ConstantInitializer(Initializer):
|
|
|
|
|
'force_cpu': self._force_cpu or force_init_on_cpu()
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
@ -244,6 +246,7 @@ class UniformInitializer(Initializer):
|
|
|
|
|
attrs={"in_dtype": out_var.dtype,
|
|
|
|
|
"out_dtype": var.dtype})
|
|
|
|
|
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
@ -322,6 +325,7 @@ class NormalInitializer(Initializer):
|
|
|
|
|
outputs={"Out": var},
|
|
|
|
|
attrs={"in_dtype": out_var.dtype,
|
|
|
|
|
"out_dtype": var.dtype})
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
@ -400,6 +404,7 @@ class TruncatedNormalInitializer(Initializer):
|
|
|
|
|
outputs={"Out": var},
|
|
|
|
|
attrs={"in_dtype": out_var.dtype,
|
|
|
|
|
"out_dtype": var.dtype})
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
@ -505,6 +510,7 @@ class XavierInitializer(Initializer):
|
|
|
|
|
"seed": self._seed
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
@ -605,6 +611,7 @@ class MSRAInitializer(Initializer):
|
|
|
|
|
"seed": self._seed
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
@ -703,6 +710,7 @@ class BilinearInitializer(Initializer):
|
|
|
|
|
'shape': list(shape),
|
|
|
|
|
value_name: values
|
|
|
|
|
})
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
@ -761,6 +769,7 @@ class NumpyArrayInitializer(Initializer):
|
|
|
|
|
value_name: values
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|