@ -26,7 +26,7 @@ from ..initializer import Normal, Constant
from ..framework import Variable, OpProtoHolder
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat
from .tensor import concat, assign
from . import utils
from .. import unique_name
from functools import reduce
@ -310,7 +310,10 @@ def assign(input, output=None, init_once=False):
if output is None:
if init_once:
output = helper.create_parameter(
attr=ParamAttr(), shape=input.shape, dtype=input.dtype)
attr=ParamAttr(),
shape=input.shape,
dtype=input.dtype,
default_initializer=Constant(0.0))
else:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)