Fix param base trainable set failed (#28756)

* fix param base trainable set failed

* add unittest

* fix typo

* polish comment
musl/fix_failed_unittests_in_musl
Chen Weihang 4 years ago committed by GitHub
parent b969c32ab1
commit 0ed80e09fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2858,6 +2858,12 @@ class Block(object):
param = ParamBase(*args, **kwargs)
else:
param = Parameter(global_block, *args, **kwargs)
# NOTE: Why only set stop_gradient=False in static mode
# Because in dygraph mode, the `stop_gradient` and `trainable`
# are related, and `trainable` default vallue is `True` or
# it is specified by users, there is no need to set
# `stop_gradient` for ParamBase here.
param.stop_gradient = False
if 'initializer' in kwargs:
def _is_inited_by(block, var):
@ -2884,7 +2890,6 @@ class Block(object):
pass
else:
initializer(param, self)
param.stop_gradient = False
return param
def append_op(self, *args, **kwargs):

@ -3683,5 +3683,24 @@ class TestMetricsDetectionMap(unittest.TestCase):
print(str(program))
class ExampleNet(paddle.nn.Layer):
def __init__(self):
super(ExampleNet, self).__init__()
self.weight = self.create_parameter(
shape=[1, 1], attr=paddle.ParamAttr(trainable=False))
def forward(self):
# only for test parameter trainable attr
pass
class TestLayerParameterTrainableSet(unittest.TestCase):
def test_layer_parameter_set(self):
with fluid.dygraph.guard():
net = ExampleNet()
self.assertFalse(net.weight.trainable)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()

Loading…
Cancel
Save