fix syncbn convert (#30158)

* fix syncbn convet

* add unittest
revert-31562-mean
ceci3 4 years ago committed by GitHub
parent adac38c506
commit 6a19e41f1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,6 +25,7 @@ import six
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.nn as nn
from paddle.fluid import compiler from paddle.fluid import compiler
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
@ -244,5 +245,34 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
isinstance(model[idx], paddle.nn.SyncBatchNorm), True) isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
class TestConvertSyncBatchNormCase2(unittest.TestCase):
def test_convert(self):
if not core.is_compiled_with_cuda():
return
class Net(nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2D(3, 5, 3)
self.bn = []
bn = self.add_sublayer('bn', nn.BatchNorm2D(5))
self.bn.append(bn)
def forward(self, x):
x = self.conv1(x)
for bn in self.bn:
x = bn(x)
return x
model = nn.Sequential()
model.add_sublayer('net1', Net())
model.add_sublayer('net2', Net())
compare_model = nn.Sequential()
compare_model.add_sublayer('net1', Net())
compare_model.add_sublayer('net2', Net())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

@ -1142,7 +1142,7 @@ class SyncBatchNorm(_BatchNormBase):
layer_output._mean = layer._mean layer_output._mean = layer._mean
layer_output._variance = layer._variance layer_output._variance = layer._variance
for name, sublayer in layer.named_sublayers(): for name, sublayer in layer.named_children():
layer_output.add_sublayer(name, layer_output.add_sublayer(name,
cls.convert_sync_batchnorm(sublayer)) cls.convert_sync_batchnorm(sublayer))
del layer del layer

Loading…
Cancel
Save