|
|
|
@ -228,9 +228,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
with program_guard(Program(), Program()):
|
|
|
|
|
compare_model = paddle.nn.Sequential(
|
|
|
|
|
paddle.nn.Conv2D(3, 5, 3), paddle.nn.BatchNorm2D(5))
|
|
|
|
|
paddle.nn.Conv2D(3, 5, 3),
|
|
|
|
|
paddle.nn.BatchNorm2D(5), paddle.nn.BatchNorm2D(5))
|
|
|
|
|
model = paddle.nn.Sequential(
|
|
|
|
|
paddle.nn.Conv2D(3, 5, 3), paddle.nn.BatchNorm2D(5))
|
|
|
|
|
paddle.nn.Conv2D(3, 5, 3),
|
|
|
|
|
paddle.nn.BatchNorm2D(5),
|
|
|
|
|
paddle.nn.BatchNorm2D(
|
|
|
|
|
5,
|
|
|
|
|
weight_attr=fluid.ParamAttr(name='bn.scale'),
|
|
|
|
|
bias_attr=fluid.ParamAttr(name='bn.bias')))
|
|
|
|
|
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
|
|
|
|
for idx, sublayer in enumerate(compare_model.sublayers()):
|
|
|
|
|
if isinstance(sublayer, paddle.nn.BatchNorm2D):
|
|
|
|
|