|
|
|
@ -15,6 +15,7 @@ class Variable(object):
|
|
|
|
|
shape=None,
|
|
|
|
|
dtype=None,
|
|
|
|
|
lod_level=None,
|
|
|
|
|
persistable=False,
|
|
|
|
|
**kwargs):
|
|
|
|
|
self.block = block
|
|
|
|
|
|
|
|
|
@ -70,6 +71,17 @@ class Variable(object):
|
|
|
|
|
"lod_level is {2}. They are not "
|
|
|
|
|
"matched".format(self.name, self.lod_level,
|
|
|
|
|
lod_level))
|
|
|
|
|
if persistable is not None:
|
|
|
|
|
if is_new_var:
|
|
|
|
|
self.desc.set_persistable(persistable)
|
|
|
|
|
else:
|
|
|
|
|
if persistable != self.persistable:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Variable {0} has been created before."
|
|
|
|
|
"The previous persistable is {1}; the new "
|
|
|
|
|
"persistable is {2}. They are not matched".format(
|
|
|
|
|
self.name, self.persistable, persistable))
|
|
|
|
|
|
|
|
|
|
self.block.vars[name] = self
|
|
|
|
|
self.op = None
|
|
|
|
|
|
|
|
|
@ -80,6 +92,10 @@ class Variable(object):
|
|
|
|
|
|
|
|
|
|
__repr__ = __str__
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def persistable(self):
|
|
|
|
|
return self.desc.persistable()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def name(self):
|
|
|
|
|
return self.desc.name()
|
|
|
|
@ -445,7 +461,9 @@ class Parameter(Variable):
|
|
|
|
|
if each < 0:
|
|
|
|
|
raise ValueError("Parameter shape should not be related with "
|
|
|
|
|
"batch-size")
|
|
|
|
|
Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
|
|
|
|
|
|
|
|
|
|
Variable.__init__(
|
|
|
|
|
self, block, persistable=True, shape=shape, dtype=dtype, **kwargs)
|
|
|
|
|
self.trainable = kwargs.get('trainable', True)
|
|
|
|
|
self.init_attr = kwargs.get('initialize_attr', {
|
|
|
|
|
'type': 'uniform_random',
|
|
|
|
|