Support sync batch norm. (#16121)
* Support Sync Batch Norm.
* Note: do not enable it when running on a single device; synchronized statistics only matter across multiple devices.
Usage:
build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = True
binary = fluid.compiler.CompiledProgram(tp).with_data_parallel(
    loss_name=loss_mean.name,
    build_strategy=build_strategy)
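
Here tp and loss_mean are the caller's main program and loss variable. A fuller minimal sketch (hypothetical layer and variable names; assumes a CUDA build with more than one visible device):

import paddle.fluid as fluid

image = fluid.layers.data(name='image', shape=[3, 224, 224], dtype='float32')
conv = fluid.layers.conv2d(input=image, num_filters=16, filter_size=3)
bn = fluid.layers.batch_norm(conv)  # rewritten to sync_batch_norm by the pass
loss_mean = fluid.layers.reduce_mean(bn)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss_mean)

build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = True
binary = fluid.compiler.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(
        loss_name=loss_mean.name, build_strategy=build_strategy)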
@@ -0,0 +1,45 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <memory>
#include <string>
#include <utility>

namespace paddle {
namespace framework {
namespace ir {

std::unique_ptr<ir::Graph> SyncBatchNormPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  VLOG(3) << "Use synchronous batch norm";
  for (const Node* n : graph->Nodes()) {
    if (n->IsOp()) {
      auto* op = n->Op();
      if (op->Type() == "batch_norm") {
        op->SetType("sync_batch_norm");
      }
      if (op->Type() == "batch_norm_grad") {
        op->SetType("sync_batch_norm_grad");
      }
    }
  }
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(sync_batch_norm_pass, paddle::framework::ir::SyncBatchNormPass);
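
The pass is a pure type rewrite: it walks every node in the graph and renames batch_norm ops to sync_batch_norm (and batch_norm_grad to sync_batch_norm_grad), leaving inputs, outputs, and attributes untouched; the actual cross-device statistics reduction lives in the sync_batch_norm kernels. The effect can be emulated from Python at the ProgramDesc level (a sketch for illustration only; the real rewrite runs on the ir::Graph inside CompiledProgram):

import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[16, 8, 8], dtype='float32')
    fluid.layers.batch_norm(x)

# Apply the same rename the pass performs on the graph.
for op in prog.global_block().ops:
    if op.type == 'batch_norm':
        op.desc.set_type('sync_batch_norm')

print([op.type for op in prog.global_block().ops])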
@@ -0,0 +1,32 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class SyncBatchNormPass : public Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
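
The header is the pass's entire public surface: SyncBatchNormPass derives from ir::Pass and overrides ApplyImpl, and REGISTER_PASS in the implementation file binds it to the name "sync_batch_norm_pass", the handle that the unit test below (and the sync_batch_norm build strategy) uses to fetch it from the registry.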
@@ -0,0 +1,80 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <gtest/gtest.h>

namespace paddle {
namespace framework {
namespace ir {

void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("name", name);
  op->SetInput("X", inputs);
  op->SetOutput("Out", outputs);
}

// (a, conv_w)->conv2d->b
// (b, bn_scale, bn_bias, mean, var)->batch_norm
//     ->(c, mean, var, save_mean, save_inv_var)
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;
  for (auto& v : std::vector<std::string>({"a", "conv_w", "b", "bn_scale",
                                           "bn_bias", "mean", "var", "c",
                                           "save_mean", "save_inv_var"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    if (v == "conv_w" || v == "bn_scale" || v == "bn_bias" || v == "mean" ||
        v == "var") {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "conv2d", "conv", std::vector<std::string>({"a", "conv_w"}),
        std::vector<std::string>({"b"}));
  SetOp(&prog, "batch_norm", "bn",
        std::vector<std::string>({"b", "bn_scale", "bn_bias", "mean", "var"}),
        std::vector<std::string>(
            {"c", "mean", "var", "save_mean", "save_inv_var"}));
  return prog;
}

TEST(SyncBatchNormPass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("sync_batch_norm_pass");

  graph = pass->Apply(std::move(graph));

  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      auto* op = node->Op();
      auto op_name = boost::get<std::string>(op->GetAttr("name"));
      if (op_name == "bn") {
        ASSERT_EQ(op->Type(), "sync_batch_norm");
      }
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(sync_batch_norm_pass);
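
The test builds a two-op program (conv2d feeding batch_norm), converts it to an ir::Graph, fetches the pass from the registry by name, applies it, and asserts that the op tagged "bn" now has type sync_batch_norm.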
@@ -0,0 +1,20 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/batch_norm_op.h"

namespace ops = paddle::operators;
REGISTER_OPERATOR(sync_batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
REGISTER_OPERATOR(sync_batch_norm_grad, ops::BatchNormGradOp);
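
sync_batch_norm deliberately reuses BatchNormOp, BatchNormOpMaker, BatchNormOpInferVarType, and BatchNormGradMaker, so its operator definition, shape inference, and gradient wiring are identical to plain batch_norm; only the compute kernels differ, and those presumably make up the large CUDA diff suppressed below.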
File diff suppressed because it is too large
@@ -0,0 +1,159 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import os
import six
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler


class TestSyncBatchNormOpTraining(unittest.TestCase):
    def setUp(self):
        #self.dtype = np.float32
        self.dtype = np.float64
        self.N = 32
        self.C = 16
        self.H = 64
        self.W = 32
        self.dshape = [self.N, self.C, self.H, self.W]

    def build_program(self,
                      place,
                      layout,
                      seed,
                      sync_bn=False,
                      only_forward=False):
        main = fluid.Program()
        startup = fluid.Program()
        main.random_seed = seed
        startup.random_seed = seed
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                data = fluid.layers.data(
                    name='input',
                    shape=self.dshape,
                    dtype=self.dtype,
                    append_batch_size=False)
                conv = fluid.layers.conv2d(
                    input=data,
                    num_filters=32,
                    filter_size=1,
                    param_attr=fluid.ParamAttr(name='conv2d_weight'),
                    bias_attr=False,
                    use_cudnn=False)
                bn = fluid.layers.batch_norm(
                    conv,
                    param_attr=fluid.ParamAttr(name='bn_scale'),
                    bias_attr=fluid.ParamAttr(name='bn_bias'),
                    moving_mean_name='bn_moving_mean',
                    moving_variance_name='bn_moving_variance',
                    data_layout=layout,
                    is_test=only_forward)
                sigmoid = fluid.layers.sigmoid(bn)
                out = fluid.layers.reduce_sum(sigmoid)
                if not sync_bn:
                    out = out / core.get_cuda_device_count()
                if not only_forward:
                    sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
                    sgd_opt.backward(out)
        return main, startup, [out, conv, bn]

    def compare(self, place, layout, only_forward):
        seed = 10
        os.environ['FLAGS_cudnn_deterministic'] = "1"
        data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2
        # Single-GPU, N = 32 per GPU
        main, startup, outs = self.build_program(place, layout, seed, False,
                                                 only_forward)
        exe = fluid.Executor(place)
        exe.run(startup)
        fetch_names = [v.name for v in outs] + [
            'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
        ]
        if not only_forward:
            others = [
                'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
                'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
            ]
            fetch_names += others
        bn_fetches = exe.run(program=main,
                             feed={'input': data},
                             fetch_list=fetch_names)

        #####################################################################
        # Multi-GPUs, self.N / core.get_cuda_device_count() per GPU
        main, startup, outs = self.build_program(place, layout, seed, True,
                                                 only_forward)
        exe = fluid.Executor(place)
        exe.run(startup)
        fetch_names = [v.name for v in outs] + [
            'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
        ]
        if not only_forward:
            others = [
                'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
                'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
            ]
            fetch_names += others
        for nm in fetch_names:
            fv = fluid.framework._get_var(str(nm), program=main)
            fv.persistable = True
        build_strategy = fluid.BuildStrategy()
        build_strategy.sync_batch_norm = True
        build_strategy.enable_inplace = False
        build_strategy.memory_optimize = False
        comp_prog = compiler.CompiledProgram(main).with_data_parallel(
            outs[0].name if not only_forward else None,
            build_strategy=build_strategy)
        sync_bn_fetches = exe.run(program=comp_prog,
                                  feed={'input': data},
                                  fetch_list=fetch_names)

        for i in six.moves.xrange(1, len(sync_bn_fetches)):
            bn_val = bn_fetches[i]
            sync_bn_val = sync_bn_fetches[i]
            if sync_bn_val.shape != bn_val.shape:
                sync_bn_val = sync_bn_val[:bn_val.shape[0]]
            self.assertTrue(
                np.allclose(
                    bn_val, sync_bn_val, atol=1e-3),
                "Output (" + fetch_names[i] + ") has diff. \n" + "\nBN " +
                str(bn_val) + "\n" + "Sync BN " + str(sync_bn_val))

    def test_train(self):
        if not core.is_compiled_with_cuda():
            return

        places = [core.CUDAPlace(0)]
        for place in places:
            for layout in ["NCHW", "NHWC"]:
                self.compare(place, layout, False)

    def test_infer(self):
        if not core.is_compiled_with_cuda():
            return

        places = [core.CUDAPlace(0)]
        for place in places:
            for layout in ["NCHW", "NHWC"]:
                self.compare(place, layout, True)


if __name__ == '__main__':
    unittest.main()
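
The comparison strategy: run the same seeded network twice, once as a plain single-device program and once compiled with sync_batch_norm = True under data parallelism, with FLAGS_cudnn_deterministic set so the convolution results are reproducible. Dividing the reference loss by core.get_cuda_device_count() in the non-sync run keeps the gradients comparable, since the data-parallel executor effectively averages gradients across devices. Multi-device fetches come back concatenated across devices, which is why a mismatched shape is sliced down to the first device's copy before the np.allclose check at atol=1e-3.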