|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
|
|
|
|
from .wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
from .core import VarDesc
|
|
|
|
|
from . import unique_name
|
|
|
|
|
from .imperative import base
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
|
|
|
|
@ -165,7 +166,8 @@ class ConstantInitializer(Initializer):
|
|
|
|
|
'force_cpu': self._force_cpu or force_init_on_cpu()
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -244,7 +246,8 @@ class UniformInitializer(Initializer):
|
|
|
|
|
attrs={"in_dtype": out_var.dtype,
|
|
|
|
|
"out_dtype": var.dtype})
|
|
|
|
|
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -322,7 +325,8 @@ class NormalInitializer(Initializer):
|
|
|
|
|
outputs={"Out": var},
|
|
|
|
|
attrs={"in_dtype": out_var.dtype,
|
|
|
|
|
"out_dtype": var.dtype})
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -400,7 +404,8 @@ class TruncatedNormalInitializer(Initializer):
|
|
|
|
|
outputs={"Out": var},
|
|
|
|
|
attrs={"in_dtype": out_var.dtype,
|
|
|
|
|
"out_dtype": var.dtype})
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -505,7 +510,8 @@ class XavierInitializer(Initializer):
|
|
|
|
|
"seed": self._seed
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -605,7 +611,8 @@ class MSRAInitializer(Initializer):
|
|
|
|
|
"seed": self._seed
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -703,7 +710,8 @@ class BilinearInitializer(Initializer):
|
|
|
|
|
'shape': list(shape),
|
|
|
|
|
value_name: values
|
|
|
|
|
})
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -761,7 +769,8 @@ class NumpyArrayInitializer(Initializer):
|
|
|
|
|
value_name: values
|
|
|
|
|
},
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
var.op = op
|
|
|
|
|
if not base.enabled():
|
|
|
|
|
var.op = op
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|