Rebuild group automatically in dynamic graph distributed (#29255)

* add tensor_indices in AssignGroupBySize

* add rebuild group in reducer
revert-31562-mean
ShenLiang 5 years ago committed by GitHub
parent 3a0558339d
commit 2ef9e0e23c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -86,6 +86,8 @@ class Group {
std::vector<framework::Tensor> dense_tensors_;
std::vector<size_t> length_;
int64_t all_length_{0};
// Global indices of participating variables in the group
std::vector<size_t> variable_indices_;
@ -97,53 +99,15 @@ class Group {
framework::proto::VarType::Type dtype_;
// context is used to select the stream for concat
void ConcatTensors(const platform::CUDADeviceContext& context) {
switch (dtype_) {
case framework::proto::VarType::FP16:
ConcatTensorsForAllReduce<platform::float16>(context, dense_tensors_,
&dense_contents_);
break;
case framework::proto::VarType::FP32:
ConcatTensorsForAllReduce<float>(context, dense_tensors_,
&dense_contents_);
break;
case framework::proto::VarType::FP64:
ConcatTensorsForAllReduce<double>(context, dense_tensors_,
&dense_contents_);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it concats tensors for "
"allreduce.",
framework::DataTypeToString(dtype_)));
}
}
void ConcatTensors(const platform::CUDADeviceContext& context);
// context is used to select the stream for split
void SplitTensors(const platform::CUDADeviceContext& context) {
switch (dtype_) {
case framework::proto::VarType::FP16:
SplitTensorsForAllReduce<platform::float16>(context, &dense_contents_,
&dense_tensors_);
break;
case framework::proto::VarType::FP32:
SplitTensorsForAllReduce<float>(context, &dense_contents_,
&dense_tensors_);
break;
case framework::proto::VarType::FP64:
SplitTensorsForAllReduce<double>(context, &dense_contents_,
&dense_tensors_);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it splits tensors for "
"allreduce.",
framework::DataTypeToString(dtype_)));
}
}
void SplitTensors(const platform::CUDADeviceContext& context);
friend std::ostream& operator<<(std::ostream&, const Group&);
};
struct VariableIndex {
struct VariableLocator {
// record the index in groups_
size_t group_index;
size_t inside_group_index;
@ -155,22 +119,21 @@ class Reducer {
const std::vector<std::shared_ptr<imperative::VarBase>>& vars,
const std::vector<std::vector<size_t>>& group_indices,
const std::vector<bool>& is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx);
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t>& group_size_limits);
virtual ~Reducer() {}
void InitializeGroups(const std::vector<std::vector<size_t>>& group_indices);
int64_t InitializeDenseGroups(const std::vector<size_t>& variable_indices_,
Group* p_group);
void InitializeDenseGroups(const std::vector<size_t>& variable_indices_,
Group* p_group);
void PrepareForBackward();
void AddDistHook(VariableWrapper* var_warpper,
const VariableIndex& var_index);
void AddDistHook(VariableWrapper* var_warpper, size_t var_index);
void MarkVariableReady(const VariableIndex& var_index,
VariableWrapper* var_warpper);
void MarkVariableReady(size_t var_index, VariableWrapper* var_warpper);
void MarkGroupReady(size_t group_index);
@ -178,15 +141,21 @@ class Reducer {
void ReleaseReducer();
std::vector<std::vector<size_t>> RebuildGruops();
void CreateGroupEvents(int group_num);
// Reducer Singleton
static std::shared_ptr<Reducer> SetInstance(
const std::vector<std::shared_ptr<imperative::VarBase>>& vars,
const std::vector<std::vector<size_t>>& group_indices,
const std::vector<bool>& is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx) {
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t>& group_size_limits) {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::imperative::Reducer(
vars, group_indices, is_sparse_gradient, parallel_ctx));
vars, group_indices, is_sparse_gradient, parallel_ctx,
group_size_limits));
}
return s_instance_;
}
@ -208,17 +177,26 @@ class Reducer {
std::once_flag once_flag_;
std::vector<bool> is_sparse_gradient_;
std::shared_ptr<imperative::ParallelContext> parallel_ctx_;
std::vector<VariableLocator> variable_locators_;
// Following variables are to help sync stream
std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
std::shared_ptr<platform::CudaEventObject> comm_enent_;
cudaStream_t compute_stream_;
cudaStream_t comm_stream_;
// Following variables are to help rebuild group
bool has_rebuilt_group_{false};
std::vector<std::shared_ptr<imperative::VarBase>> rebuild_vars_;
std::vector<int64_t> rebuild_var_indices_;
const std::vector<size_t> group_size_limits_;
};
std::vector<std::vector<size_t>> AssignGroupBySize(
const std::vector<std::shared_ptr<imperative::VarBase>>& tensors,
const std::vector<bool>& is_sparse_gradient,
const std::vector<size_t>& group_size_limits);
const std::vector<size_t>& group_size_limits,
const std::vector<int64_t>& tensor_indices = {});
#endif
} // namespace imperative

@ -12,3 +12,7 @@ cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
if (WITH_NCCL)
cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
endif()

@ -0,0 +1,66 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/imperative/reducer.h"
#endif
namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL)
TEST(TestGroup, TestPrintGroupMessage) {
Group group;
std::stringstream stream1, stream2;
stream1 << group;
ASSERT_STREQ(stream1.str().c_str(),
"numul: 0 ;is_sparse: 0 ;var number: 0\n[]\n");
std::vector<size_t> vars;
size_t vars_num = 102;
for (size_t i = 0; i < vars_num; ++i) {
vars.push_back(i);
}
group.variable_indices_ = vars;
group.all_length_ = 102;
group.is_sparse_ = false;
std::string head = "numul: 102 ;is_sparse: 0 ;var number: 102\n";
head = head + "[";
auto begin = vars.begin();
auto end = vars.end();
for (int i = 0; begin != end && i < 100; ++i, ++begin) {
if (i > 0) head += ' ';
head += std::to_string(*begin);
}
if (begin != end) {
head += " ...";
}
head += "]\n";
stream2 << group;
ASSERT_STREQ(stream2.str().c_str(), head.c_str());
}
#endif
} // namespace imperative
} // namespace paddle

@ -1289,9 +1289,11 @@ void BindImperative(py::module *m_ptr) {
[](const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
const std::vector<std::vector<size_t>> &group_indices,
const std::vector<bool> &is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx) {
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t> &group_size_limits) {
return imperative::Reducer::SetInstance(
vars, group_indices, is_sparse_gradient, parallel_ctx);
vars, group_indices, is_sparse_gradient, parallel_ctx,
group_size_limits);
}))
.def("prepare_for_backward", &imperative::Reducer::PrepareForBackward,
py::call_guard<py::gil_scoped_release>());
@ -1299,6 +1301,7 @@ void BindImperative(py::module *m_ptr) {
m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"),
py::arg("is_sparse_gradient"),
py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
py::arg("tensor_indices") = std::vector<int64_t>{},
py::call_guard<py::gil_scoped_release>());
#endif
}

@ -18,7 +18,6 @@ from paddle.fluid.framework import Variable, set_flags, core
from paddle.fluid.wrapped_decorator import wrap_decorator
import google.protobuf.text_format
import google.protobuf
from paddle.fluid.framework import dygraph_only
__all__ = ["DistributedStrategy"]

@ -441,10 +441,11 @@ class DataParallel(layers.Layer):
"ParallelContext must be initialized before. You should use init_parallel_env() before" \
"constructing the DataParallel."
self._reducer = core.Reducer(trainable_parameters,
list(reversed(self.group_indices)),
is_sparse_gradient,
parallel_helper.__parallel_ctx__clz__)
self._reducer = core.Reducer(
trainable_parameters,
list(reversed(self.group_indices)), is_sparse_gradient,
parallel_helper.__parallel_ctx__clz__,
[self.last_comm_buffer_size, self.comm_buffer_size])
def forward(self, *inputs, **kwargs):
if self._strategy.nranks > 1:

@ -155,6 +155,30 @@ class TestDataParallelGroup(unittest.TestCase):
var_list, [True, False, False, False, False, True], [200, 400])
self.assertEqual([[0], [1], [2], [3], [4], [5]], res)
def test_construct_group8(self):
# one dtype & one limit capability & have tensor_indices
var_list = []
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(
self.create_varbase(core.VarDesc.VarType.FP32, [2, 100]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
res = core.assign_group_by_size(var_list, [False, False, False, False],
[400], [3, 0, 1, 2])
self.assertEqual([[3, 0], [1], [2]], res)
def test_construct_group9(self):
# one dtype & one limit capability & have tensor_indices
var_list = []
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(
self.create_varbase(core.VarDesc.VarType.FP32, [2, 1000]))
res = core.assign_group_by_size(var_list, [False, False, False, True],
[300], [1, 0, 2, 3])
self.assertEqual([[1, 0], [3], [2]], res)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save