Optimization grad merge performance (#29784)
parent e891f4da1b
commit ee16006b5d
@@ -0,0 +1,132 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"

#include <algorithm>

#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"

#ifdef PADDLE_WITH_NCCL
DECLARE_bool(sync_nccl_allreduce);
#endif

namespace paddle {
namespace framework {
namespace details {

#if defined(PADDLE_WITH_NCCL)
GradMergeAllReduceOpHandle::GradMergeAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    const std::string &grad_merge_cond_name,
    const platform::NCCLCommunicator *ctxs)
    : AllReduceOpHandle(node, local_scopes, places, ctxs),
      grad_merge_cond_name_(grad_merge_cond_name) {}
#elif defined(PADDLE_WITH_XPU_BKCL)
GradMergeAllReduceOpHandle::GradMergeAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    const std::string &grad_merge_cond_name,
    const platform::BKCLCommunicator *ctxs)
    : AllReduceOpHandle(node, local_scopes, places, ctxs),
      grad_merge_cond_name_(grad_merge_cond_name) {}
#else
GradMergeAllReduceOpHandle::GradMergeAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    const std::string &grad_merge_cond_name)
    : AllReduceOpHandle(node, local_scopes, places),
      grad_merge_cond_name_(grad_merge_cond_name) {}
#endif

// Run the underlying all-reduce only when the gradient-merge condition
// variable (a bool tensor in the local scope) is true.
void GradMergeAllReduceOpHandle::RunImpl() {
  PADDLE_ENFORCE_GT(
      local_scopes_.size(), 0,
      platform::errors::PreconditionNotMet(
          "The number of local scopes should be > 0, but got %zu.",
          local_scopes_.size()));

  auto *local_scope = local_exec_scopes_[0];
  auto cond_var = local_scope->FindVar(grad_merge_cond_name_);
  PADDLE_ENFORCE_NOT_NULL(
      cond_var, platform::errors::NotFound(
                    "Variable %s is not found in scope.",
                    grad_merge_cond_name_));
  bool cond = *cond_var->Get<LoDTensor>().data<bool>();

  if (cond) {
    AllReduceOpHandle::RunImpl();
  }
}

std::string GradMergeAllReduceOpHandle::Name() const {
  return "grad_merge_all_reduce";
}

#if defined(PADDLE_WITH_NCCL)
FusedGradMergeAllReduceOpHandle::FusedGradMergeAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
    const std::string &grad_merge_cond_name,
    const platform::NCCLCommunicator *ctxs)
    : FusedAllReduceOpHandle(node, local_scopes, places, num_of_all_reduce,
                             ctxs),
      grad_merge_cond_name_(grad_merge_cond_name) {}
#elif defined(PADDLE_WITH_XPU_BKCL)
FusedGradMergeAllReduceOpHandle::FusedGradMergeAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
    const std::string &grad_merge_cond_name,
    const platform::BKCLCommunicator *ctxs)
    : FusedAllReduceOpHandle(node, local_scopes, places, num_of_all_reduce,
                             ctxs),
      grad_merge_cond_name_(grad_merge_cond_name) {}
#else
FusedGradMergeAllReduceOpHandle::FusedGradMergeAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
    const std::string &grad_merge_cond_name)
    : FusedAllReduceOpHandle(node, local_scopes, places, num_of_all_reduce),
      grad_merge_cond_name_(grad_merge_cond_name) {}
#endif

// Same gating as above, applied to the fused (grouped) all-reduce path.
void FusedGradMergeAllReduceOpHandle::RunImpl() {
  PADDLE_ENFORCE_GT(
      local_scopes_.size(), 0,
      platform::errors::PreconditionNotMet(
          "The number of local scopes should be > 0, but got %zu.",
          local_scopes_.size()));

  auto *local_scope = local_exec_scopes_[0];
  auto cond_var = local_scope->FindVar(grad_merge_cond_name_);
  PADDLE_ENFORCE_NOT_NULL(
      cond_var, platform::errors::NotFound(
                    "Variable %s is not found in scope.",
                    grad_merge_cond_name_));
  bool cond = *cond_var->Get<LoDTensor>().data<bool>();

  if (cond) {
    VLOG(10) << "run fused grad merge all reduce";
    FusedAllReduceOpHandle::RunImpl();
  }
}

std::string FusedGradMergeAllReduceOpHandle::Name() const {
  return "fused_grad_merge_all_reduce";
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
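Conceptually, both handles above just skip the collective on steps where gradients are still being merged locally, and fall through to the parent all-reduce otherwise. A minimal, framework-free Python sketch of that gating idea (illustrative only; `allreduce` is a placeholder callback, not Paddle's API):

import numpy as np

def merged_step(grad, acc, step_id, k_steps, allreduce):
    """Accumulate grad locally; communicate only every k_steps steps."""
    acc += grad
    if (step_id + 1) % k_steps == 0:   # plays the role of the cond variable
        synced = allreduce(acc)        # the expensive collective runs here only
        acc[:] = 0.0
        return synced
    return None                        # merge-only step: no communication

With k_steps = 2, as in the test below, roughly half of the all-reduce calls are skipped, which is where the performance gain comes from.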
@@ -0,0 +1,111 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace ir {
class Node;
}  // namespace ir
}  // namespace framework
namespace platform {
class NCCLCommunicator;
}  // namespace platform
}  // namespace paddle

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace framework {
namespace details {

// All-reduce op handle that performs the collective only when the
// gradient-merge condition variable is true.
class GradMergeAllReduceOpHandle : public AllReduceOpHandle {
 public:
#if defined(PADDLE_WITH_NCCL)
  GradMergeAllReduceOpHandle(ir::Node *node,
                             const std::vector<Scope *> &local_scopes,
                             const std::vector<platform::Place> &places,
                             const std::string &grad_merge_cond_name,
                             const platform::NCCLCommunicator *ctxs);
#elif defined(PADDLE_WITH_XPU_BKCL)
  GradMergeAllReduceOpHandle(ir::Node *node,
                             const std::vector<Scope *> &local_scopes,
                             const std::vector<platform::Place> &places,
                             const std::string &grad_merge_cond_name,
                             const platform::BKCLCommunicator *ctxs);
#else
  GradMergeAllReduceOpHandle(ir::Node *node,
                             const std::vector<Scope *> &local_scopes,
                             const std::vector<platform::Place> &places,
                             const std::string &grad_merge_cond_name);
#endif
  std::string Name() const override;

  std::string GradMergeCondName() { return grad_merge_cond_name_; }

 protected:
  void RunImpl() override;

 private:
  std::string grad_merge_cond_name_;
};

// Fused variant: gates FusedAllReduceOpHandle on the same condition variable.
class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
 public:
#if defined(PADDLE_WITH_NCCL)
  FusedGradMergeAllReduceOpHandle(ir::Node *node,
                                  const std::vector<Scope *> &local_scopes,
                                  const std::vector<platform::Place> &places,
                                  const size_t num_of_all_reduce,
                                  const std::string &grad_merge_cond_name,
                                  const platform::NCCLCommunicator *ctxs);
#elif defined(PADDLE_WITH_XPU_BKCL)
  FusedGradMergeAllReduceOpHandle(ir::Node *node,
                                  const std::vector<Scope *> &local_scopes,
                                  const std::vector<platform::Place> &places,
                                  const size_t num_of_all_reduce,
                                  const std::string &grad_merge_cond_name,
                                  const platform::BKCLCommunicator *ctxs);
#else
  FusedGradMergeAllReduceOpHandle(ir::Node *node,
                                  const std::vector<Scope *> &local_scopes,
                                  const std::vector<platform::Place> &places,
                                  const size_t num_of_all_reduce,
                                  const std::string &grad_merge_cond_name);
#endif

  std::string Name() const override;

 protected:
  void RunImpl() override;

 private:
  std::string grad_merge_cond_name_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
File diff suppressed because it is too large
@@ -0,0 +1,62 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
from dist_mnist import cnn_model

DTYPE = "float32"
paddle.dataset.mnist.fetch()

# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1


class TestDistMnist2x2(TestDistRunnerBase):
    def get_model(self, batch_size=2):
        # Input data
        images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')

        # Train program
        predict = cnn_model(images)
        cost = fluid.layers.cross_entropy(input=predict, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        # Evaluator
        batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
        batch_acc = fluid.layers.accuracy(
            input=predict, label=label, total=batch_size_tensor)

        inference_program = fluid.default_main_program().clone()
        # Optimization
        opt = fluid.optimizer.MomentumOptimizer(
            learning_rate=0.001, momentum=0.9)
        opt = fluid.optimizer.GradientMergeOptimizer(opt, 2)

        # Reader
        train_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        opt.minimize(avg_cost)
        return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict


if __name__ == "__main__":
    runtime_main(TestDistMnist2x2)
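For reference, a minimal single-process sketch of the same optimizer wiring outside the distributed test harness. It assumes the fluid 1.x static-graph API used above; the network and layer names here are illustrative, not part of the commit:

import paddle.fluid as fluid

image = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
predict = fluid.layers.fc(input=image, size=10, act='softmax')
loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=predict, label=label))

# Wrap any optimizer; gradients are merged for 2 mini-batches per update,
# which (with this patch) also defers the distributed all-reduce.
inner_opt = fluid.optimizer.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
opt = fluid.optimizer.GradientMergeOptimizer(inner_opt, 2)
opt.minimize(loss)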
@@ -0,0 +1,57 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import unittest

from test_dist_base import TestDistBase

flag_name = os.path.splitext(__file__)[0]


class TestDistMnistGradMerge(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_reduce = False
        self._nccl2_mode = True

    def test_dist_train(self):
        import paddle.fluid as fluid
        if fluid.core.is_compiled_with_cuda():
            self.check_with_place(
                "dist_mnist_gradient_merge.py",
                delta=1e-5,
                check_error_log=True,
                log_name=flag_name)


class TestDistMnistGradMergeNoFuse(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_reduce = False
        self._nccl2_mode = True
        self._fuse_all_reduce = False

    def test_dist_train(self):
        import paddle.fluid as fluid
        if fluid.core.is_compiled_with_cuda():
            self.check_with_place(
                "dist_mnist_gradient_merge.py",
                delta=1e-5,
                check_error_log=True,
                log_name=flag_name + "_no_fuse")


if __name__ == "__main__":
    unittest.main()