Support sync batch norm. (#16121)
* Support Sync Batch Norm.
* Note: do not enable it when running on a single device.

Usage (a fuller illustrative sketch follows the commit metadata below):

    build_strategy = fluid.BuildStrategy()
    build_strategy.sync_batch_norm = True
    binary = fluid.compiler.CompiledProgram(tp).with_data_parallel(
        loss_name=loss_mean.name, build_strategy=build_strategy)
parent 4ae23cc3c5
commit 8ad672a287
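The usage snippet above references `tp` (the training program) and `loss_mean` without defining them. Below is a minimal, illustrative sketch of the full flow under the fluid 1.x API used elsewhere in this PR; it assumes a CUDA build with more than one visible GPU, and the layer choices and variable names are placeholders rather than code from the PR itself.

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.core as core

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        # A toy conv + batch_norm network; any network containing batch_norm works.
        data = fluid.layers.data(
            name='input', shape=[32, 16, 64, 32], dtype='float32',
            append_batch_size=False)
        conv = fluid.layers.conv2d(input=data, num_filters=32, filter_size=1)
        bn = fluid.layers.batch_norm(conv)  # rewritten to sync_batch_norm by the pass
        loss_mean = fluid.layers.reduce_mean(bn)
        fluid.optimizer.SGD(learning_rate=0.01).minimize(loss_mean)

    exe = fluid.Executor(core.CUDAPlace(0))
    exe.run(startup)

    build_strategy = fluid.BuildStrategy()
    build_strategy.sync_batch_norm = True  # only meaningful with more than one device
    binary = fluid.compiler.CompiledProgram(main).with_data_parallel(
        loss_name=loss_mean.name, build_strategy=build_strategy)

    feed = {'input': np.random.random([32, 16, 64, 32]).astype('float32')}
    loss, = exe.run(binary, feed=feed, fetch_list=[loss_mean.name])

With `sync_batch_norm` enabled, the batch-norm statistics are computed over the whole multi-device batch instead of per device, which is why the flag is pointless (and should stay off) on a single device.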
@@ -0,0 +1,45 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <memory>
#include <string>
#include <utility>

namespace paddle {
namespace framework {
namespace ir {

std::unique_ptr<ir::Graph> SyncBatchNormPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  VLOG(3) << "Use synchronous batch norm";
  for (const Node* n : graph->Nodes()) {
    if (n->IsOp()) {
      auto* op = n->Op();
      if (op->Type() == "batch_norm") {
        op->SetType("sync_batch_norm");
      }
      if (op->Type() == "batch_norm_grad") {
        op->SetType("sync_batch_norm_grad");
      }
    }
  }
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(sync_batch_norm_pass, paddle::framework::ir::SyncBatchNormPass);
@@ -0,0 +1,32 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class SyncBatchNormPass : public Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,80 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <gtest/gtest.h>

namespace paddle {
namespace framework {
namespace ir {

void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("name", name);
  op->SetInput("X", inputs);
  op->SetOutput("Out", outputs);
}

// (a, conv_w)->conv2d->b
// (b, bn_scale, bn_bias, mean, var)->batch_norm
//     ->(c, mean, var, save_mean, save_inv_var)
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;
  for (auto& v : std::vector<std::string>({"a", "conv_w", "b", "bn_scale",
                                           "bn_bias", "mean", "var", "c",
                                           "save_mean", "save_inv_var"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    if (v == "conv_w" || v == "bn_scale" || v == "bn_bias" || v == "mean" ||
        v == "var") {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "conv2d", "conv", std::vector<std::string>({"a", "conv_w"}),
        std::vector<std::string>({"b"}));
  SetOp(&prog, "batch_norm", "bn",
        std::vector<std::string>({"b", "bn_scale", "bn_bias", "mean", "var"}),
        std::vector<std::string>(
            {"c", "mean", "var", "save_mean", "save_inv_var"}));
  return prog;
}

TEST(IsTestPass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("sync_batch_norm_pass");

  graph = pass->Apply(std::move(graph));

  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      auto* op = node->Op();
      auto op_name = boost::get<std::string>(op->GetAttr("name"));
      if (op_name == "bn") {
        ASSERT_EQ(op->Type(), "sync_batch_norm");
      }
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(sync_batch_norm_pass);
(File diff suppressed because it is too large.)
@@ -0,0 +1,20 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/batch_norm_op.h"

namespace ops = paddle::operators;
REGISTER_OPERATOR(sync_batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
REGISTER_OPERATOR(sync_batch_norm_grad, ops::BatchNormGradOp);
(File diff suppressed because it is too large.)
@@ -0,0 +1,159 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import os
import six
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler


class TestSyncBatchNormOpTraining(unittest.TestCase):
    def setUp(self):
        #self.dtype = np.float32
        self.dtype = np.float64
        self.N = 32
        self.C = 16
        self.H = 64
        self.W = 32
        self.dshape = [self.N, self.C, self.H, self.W]

    def build_program(self,
                      place,
                      layout,
                      seed,
                      sync_bn=False,
                      only_forward=False):
        main = fluid.Program()
        startup = fluid.Program()
        main.random_seed = seed
        startup.random_seed = seed
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                data = fluid.layers.data(
                    name='input',
                    shape=self.dshape,
                    dtype=self.dtype,
                    append_batch_size=False)
                conv = fluid.layers.conv2d(
                    input=data,
                    num_filters=32,
                    filter_size=1,
                    param_attr=fluid.ParamAttr(name='conv2d_weight'),
                    bias_attr=False,
                    use_cudnn=False)
                bn = fluid.layers.batch_norm(
                    conv,
                    param_attr=fluid.ParamAttr(name='bn_scale'),
                    bias_attr=fluid.ParamAttr(name='bn_bias'),
                    moving_mean_name='bn_moving_mean',
                    moving_variance_name='bn_moving_variance',
                    data_layout=layout,
                    is_test=only_forward)
                sigmoid = fluid.layers.sigmoid(bn)
                out = fluid.layers.reduce_sum(sigmoid)
                if not sync_bn:
                    out = out / core.get_cuda_device_count()
                if not only_forward:
                    sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
                    sgd_opt.backward(out)
        return main, startup, [out, conv, bn]

    def compare(self, place, layout, only_forward):
        seed = 10
        os.environ['FLAGS_cudnn_deterministic'] = "1"
        data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2
        # Single-GPU, N = 32 per GPU
        main, startup, outs = self.build_program(place, layout, seed, False,
                                                 only_forward)
        exe = fluid.Executor(place)
        exe.run(startup)
        fetch_names = [v.name for v in outs] + [
            'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
        ]
        if not only_forward:
            others = [
                'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
                'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
            ]
            fetch_names += others
        bn_fetches = exe.run(program=main,
                             feed={'input': data},
                             fetch_list=fetch_names)

        #####################################################################
        # Multi-GPUs, self.N / core.get_cuda_device_count() per GPU
        main, startup, outs = self.build_program(place, layout, seed, True,
                                                 only_forward)
        exe = fluid.Executor(place)
        exe.run(startup)
        fetch_names = [v.name for v in outs] + [
            'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
        ]
        if not only_forward:
            others = [
                'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
                'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
            ]
            fetch_names += others
        for nm in fetch_names:
            fv = fluid.framework._get_var(str(nm), program=main)
            fv.persistable = True
        build_strategy = fluid.BuildStrategy()
        build_strategy.sync_batch_norm = True
        build_strategy.enable_inplace = False
        build_strategy.memory_optimize = False
        comp_prog = compiler.CompiledProgram(main).with_data_parallel(
            outs[0].name if not only_forward else None,
            build_strategy=build_strategy)
        sync_bn_fetches = exe.run(program=comp_prog,
                                  feed={'input': data},
                                  fetch_list=fetch_names)

        for i in six.moves.xrange(1, len(sync_bn_fetches)):
            bn_val = bn_fetches[i]
            sync_bn_val = sync_bn_fetches[i]
            if sync_bn_val.shape != bn_val.shape:
                sync_bn_val = sync_bn_val[:bn_val.shape[0]]
            self.assertTrue(
                np.allclose(
                    bn_val, sync_bn_val, atol=1e-3),
                "Output (" + fetch_names[i] + ") has diff. \n" + "\nBN " +
                str(bn_val) + "\n" + "Sync BN " + str(sync_bn_val))

    def test_train(self):
        if not core.is_compiled_with_cuda():
            return

        places = [core.CUDAPlace(0)]
        for place in places:
            for layout in ["NCHW", "NHWC"]:
                self.compare(place, layout, False)

    def test_infer(self):
        if not core.is_compiled_with_cuda():
            return

        places = [core.CUDAPlace(0)]
        for place in places:
            for layout in ["NCHW", "NHWC"]:
                self.compare(place, layout, True)


if __name__ == '__main__':
    unittest.main()