Add DGC(Deep Gradient Compression) interface. (#15841)

move-code
gongweibao 6 years ago committed by GitHub
parent b1d2605152
commit eb83abeac3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -193,6 +193,12 @@ if(WITH_GPU)
include(tensorrt)
include(anakin_subgraph)
endif()
if(WITH_GPU AND NOT WIN32)
message(STATUS "add dgc lib.")
include(external/dgc)
endif()
if(WITH_MKL OR WITH_MKLML)
include(external/anakin)
elseif()

@ -0,0 +1,42 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc")
SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc")
SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE)
SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE)
INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
ExternalProject_Add(
extern_dgc
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/PaddlePaddle/Fleet"
GIT_TAG "2d04dc3800cdd0601f1b65d547dabcc60b0cf9dc"
SOURCE_DIR "${DGC_SOURCES_DIR}"
CONFIGURE_COMMAND ""
BUILD_COMMAND cd collective && make -j
INSTALL_COMMAND mkdir -p ${DGC_INSTALL_DIR}/lib/ ${DGC_INCLUDE_DIR}/dgc
&& cp ${DGC_SOURCES_DIR}/collective/build/lib/libdgc.a ${DGC_LIBRARIES}
&& cp ${DGC_SOURCES_DIR}/collective/build/include/dgc.h ${DGC_INCLUDE_DIR}/dgc/
BUILD_IN_SOURCE 1
)
ADD_LIBRARY(dgc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
ADD_DEPENDENCIES(dgc extern_dgc)
LIST(APPEND external_project_dependencies dgc)

@ -131,6 +131,15 @@ elseif (NOT CBLAS_FOUND OR WIN32)
)
endif ()
if (WITH_GPU AND NOT WIN32)
set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc")
copy(dgc_lib
SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include
DSTS ${dgc_dir} ${dgc_dir}
DEPS dgc)
endif()
if (WITH_MKLDNN)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn")
copy(mkldnn_lib

@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()

@ -483,6 +483,11 @@ paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['sel
paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '1a79bd7d10ae54ca763ec81bca36ba24'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))

@ -23,7 +23,7 @@ endif()
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
dynload_cuda variable_visitor dgc)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE)

@ -86,7 +86,8 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
}
}
VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl;
VLOG(10) << "dist_ops size:" << dist_ops.size()
<< ", outputs size:" << vars.size() << ", ops size:" << ops.size();
std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
OpHandleBase* op2) {
@ -99,6 +100,10 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
auto l_it = vars.find(i0->name());
auto r_it = vars.find(i1->name());
PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
"can't find var's name %s and %s in opdesc", i0->name(),
i1->name());
if (l_it->second < r_it->second) return true;
if (l_it->second == r_it->second) {

@ -16,6 +16,13 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "dgc/dgc.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
// asynchronous nccl allreduce or synchronous issue:
@ -33,11 +40,14 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
const platform::NCCLContextMap *ctxs,
bool is_encoded, int nranks)
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
nccl_ctxs_(ctxs),
is_encoded_(is_encoded),
nranks_(nranks) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
@ -51,7 +61,185 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void AllReduceOpHandle::RunImplEncoded() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> ins;
std::vector<LoDTensor *> outs;
int k = -1;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope =
local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto original_name =
paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded;
auto *in_var = local_scope->FindVar(encode_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var);
auto &in = in_var->Get<LoDTensor>();
ins.emplace_back(&in);
auto *out = local_scope->FindVar(out_var_handles[i]->name())
->GetMutable<LoDTensor>();
outs.emplace_back(out);
if (k < 0) {
k = GetKValue(in_var_handles[i]->name());
}
}
PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t in_numel = 0;
size_t out_numel = 0;
PADDLE_ENFORCE(nranks_ > 1);
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &place = places_[i];
auto &in = *ins[i];
void *in_tensor_buf = const_cast<void *>(in.data<void>());
auto &out = *outs[i];
float *out_tensor_buf = out.data<float>();
dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
PADDLE_ENFORCE(in_numel % 2 == 0);
PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
int encode_size = 2 * k * sizeof(int);
// dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks.
int buf_size = nranks_ * encode_size;
auto tmp_ious_data = allocator.Allocate(buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
<< ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
<< ", k:" << k << ", place:" << place << ", dtype:" << dtype;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
all_reduce_calls[0]();
} else {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
});
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
}
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno:" << e_get;
}
}
}
}
int AllReduceOpHandle::GetKValue(const std::string &grad_name) {
auto original_name = paddle::framework::GradOriginalVarName(grad_name);
auto var_name = original_name + g_dgc_k;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto var = local_scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<LoDTensor>().data<float>();
return *tensor;
}
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
bool AllReduceOpHandle::IsEncoded() {
if (!is_encoded_) {
return false;
}
auto counter_name = g_dgc_counter_name;
auto step_name = g_dgc_rampup_begin_step;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto count_var = local_scope->FindVar(counter_name);
auto step_var = local_scope->FindVar(step_name);
if (count_var == nullptr || step_var == nullptr) {
PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
step_var);
}
float count = *count_var->Get<LoDTensor>().data<float>();
float step = *step_var->Get<LoDTensor>().data<float>();
if (static_cast<int>(count) < static_cast<int>(step)) {
VLOG(10) << "in all_reduce currentstep:" << count
<< " < rampup_begin_step:" << step
<< " so not use sparse all reduce";
return false;
}
return true;
}
#else
bool AllReduceOpHandle::IsEncoded() { return false; }
#endif
void AllReduceOpHandle::RunImpl() {
if (!IsEncoded()) {
RunImplNormal();
return;
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
RunImplEncoded();
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
}
void AllReduceOpHandle::RunImplNormal() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
@ -72,6 +260,8 @@ void AllReduceOpHandle::RunImpl() {
auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name();
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal.");
}
@ -99,13 +289,17 @@ void AllReduceOpHandle::RunImpl() {
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
<< ", dev_id:" << dev_id << ", dtype:" << dtype
<< ", place:" << p;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device

@ -28,11 +28,19 @@ namespace paddle {
namespace framework {
namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
constexpr char g_dgc_encoded[] = "__dgc_encoded__";
constexpr char g_dgc_k[] = "__dgc_k__";
#endif
struct AllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
const platform::NCCLContextMap *ctxs,
bool is_encoded = false, int nranks = -1);
#else
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
@ -50,8 +58,14 @@ struct AllReduceOpHandle : public OpHandleBase {
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void RunImplEncoded();
const platform::NCCLContextMap *nccl_ctxs_;
bool is_encoded_{false};
int nranks_{-1};
int GetKValue(const std::string &grad_name);
#endif
void RunImplNormal();
bool IsEncoded();
};
} // namespace details

@ -32,6 +32,7 @@
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace framework {
@ -209,7 +210,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name
<< " op_type " << node->Op()->Type();
if (NeedCollectiveForGrad(g_name, sorted_ops)) {
InsertCollectiveOp(&result, p_name, g_name);
}
@ -414,8 +416,9 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
CreateOpHandleIOs(result, node, dev_id);
}
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, const std::string &og) const {
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
const std::string &og,
bool is_encoded) const {
OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&](
@ -424,7 +427,9 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, nccl_ctxs_));
scopes, places, nccl_ctxs_, is_encoded,
static_cast<int>(strategy_.trainers_endpoints_.size()) *
places_.size()));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@ -446,12 +451,15 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad);
VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, places_[i]);
vars.emplace_back(var);
op_handle->AddOutput(var);
VLOG(10) << "all_reduce_op_handle add output " << og
<< ", handle:" << var->DebugString();
}
}
@ -941,6 +949,17 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id;
}
bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
auto u_name = p_name + "__dgc_u__";
auto it = all_vars_.find(u_name);
if (it == all_vars_.end()) {
VLOG(10) << "can't find u_name, so it's not encoded:" << u_name;
return false;
}
return true;
}
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
const std::string &g_name) const {
@ -956,7 +975,11 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
CreateAllReduceOp(result, g_name, IsEncoded(p_name));
#else
PADDLE_ENFORCE(false, "Compiled withoud cuda!");
#endif
}
break;
default:

@ -75,7 +75,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool IsSparseGradient(const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og,
bool is_encoded = false) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
@ -171,6 +172,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
bool IsEncoded(const std::string &p_name) const;
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();

@ -24,7 +24,8 @@ VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
std::string VarHandle::DebugString() const {
std::stringstream ss;
ss << name_ << ":" << place_;
ss << "name:" << name_ << ", place:" << place_ << ", version:" << version_
<< ", scope_idx:" << scope_idx_;
return ss.str();
}

@ -373,6 +373,11 @@ std::vector<std::string> OpDesc::AttrNames() const {
return retv;
}
void OpDesc::RemoveAttr(const std::string &name) {
attrs_.erase(name);
need_update_ = true;
}
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(minqiyang): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
@ -644,6 +649,7 @@ void OpDesc::CheckAttrs() {
// not by users.
return;
}
VLOG(10) << "begin to check attribute of " << Type();
checker->Check(&attrs_);
}

@ -72,6 +72,7 @@ class OpDesc {
std::vector<std::string> AttrNames() const;
void SetAttr(const std::string &name, const Attribute &v);
void RemoveAttr(const std::string &name);
void SetBlockAttr(const std::string &name, BlockDesc *block);

@ -1110,8 +1110,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
Type(), DataTypeToString(data_type), DataTypeToString(tmp));
"DataType of Paddle Op %s %s must be the same. Get (%d) != (%d)",
Type(), input.first, DataTypeToString(data_type),
DataTypeToString(tmp));
data_type = tmp;
}
}

@ -49,6 +49,11 @@ set(SHARED_INFERENCE_SRCS
${mkldnn_quantizer_src}
${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc)
# FIXME(gongwb): hidden libdgc.a
if(WITH_GPU AND NOT WIN32)
set(fluid_modules ${fluid_modules} dgc)
endif()
if(WIN32)
sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array
analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)

@ -48,7 +48,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
@ -72,6 +72,12 @@ endif()
set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
if (WITH_GPU AND NOT WIN32)
op_library(dgc_op DEPS dgc)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n")
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)

@ -14,69 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/clip_by_norm_op.h"
namespace paddle {
namespace operators {
class ClipByNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null.");
auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9].");
AddOutput("Out",
"(Tensor) The output of clip_by_norm op with shape as input(X)");
AddAttr<float>("max_norm", "(float) The maximum norm value.");
AddComment(R"DOC(
ClipByNorm Operator.
This operator limits the L2 norm of the input $X$ within $max\_norm$.
If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be
the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will
be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as
shown in the following formula:
$$
Out = \\frac{max\\_norm * X}{norm(X)},
$$
where $norm(X)$ represents the L2 norm of $X$.
Examples:
.. code-block:: python
data = fluid.layer.data(
name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.clip_by_norm(
x=data, max_norm=0.5)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
ops::ClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(
clip_by_norm,
ops::ClipByNormKernel<paddle::platform::CPUDeviceContext, float>);

@ -83,5 +83,59 @@ class ClipByNormKernel : public framework::OpKernel<T> {
}
};
class ClipByNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null.");
auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9].");
AddOutput("Out",
"(Tensor) The output of clip_by_norm op with shape as input(X)");
AddAttr<float>("max_norm", "(float) The maximum norm value.");
AddComment(R"DOC(
ClipByNorm Operator.
This operator limits the L2 norm of the input $X$ within $max\_norm$.
If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be
the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will
be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as
shown in the following formula:
$$
Out = \\frac{max\\_norm * X}{norm(X)},
$$
where $norm(X)$ represents the L2 norm of $X$.
Examples:
.. code-block:: python
data = fluid.layer.data(
name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.clip_by_norm(
x=data, max_norm=0.5)
)DOC");
}
};
} // namespace operators
} // namespace paddle

@ -0,0 +1,67 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/dgc_clip_by_norm_op.h"
namespace paddle {
namespace operators {
class DGCClipByNormOp : public ClipByNormOp {
public:
using ClipByNormOp::ClipByNormOp;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("current_step"),
"current_step should be set.");
return ClipByNormOp::InferShape(ctx);
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "current_step") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class DGCClipByNormOpMaker : public ClipByNormOpMaker {
public:
void Make() override {
AddInput("current_step", "(Tensor) Current step.");
AddAttr<float>("rampup_begin_step",
"(float, -1.0)"
"The period when begin k_select.")
.SetDefault(-1.0);
return ClipByNormOpMaker::Make();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc_clip_by_norm, ops::DGCClipByNormOp,
ops::DGCClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(
dgc_clip_by_norm,
ops::DGCClipByNormKernel<paddle::platform::CPUDeviceContext, float>);

@ -0,0 +1,20 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_clip_by_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
dgc_clip_by_norm,
ops::DGCClipByNormKernel<paddle::platform::CUDADeviceContext, float>);

@ -0,0 +1,46 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/clip_by_norm_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
if (static_cast<int>(rampup_begin_step) >= 0) {
auto current_step_tensor =
context.Input<framework::Tensor>("current_step");
auto* current_step = current_step_tensor->data<T>();
if (static_cast<int>(*current_step) <
static_cast<int>(rampup_begin_step)) {
VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step
<< " so does't use dgc_clip_by_norm";
return;
}
}
return ClipByNormKernel<DeviceContext, T>::Compute(context);
};
};
} // namespace operators
} // namespace paddle

@ -0,0 +1,138 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class DGCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("current_step"),
"Input(current_step) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("U_out"),
"Output(U_out) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("V_out"),
"Output(V_out) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("k"),
"Output(k) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("EncodeGrad"),
"Output(EncodeGrad) of DGCop should not be null.");
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "current_step" || var_name == "rampup_step" ||
var_name == "k") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class DGCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("U", "(Tensor) Middle tensor of DGC");
AddInput("V", "(Tensor) Middle tensor of DGC");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("current_step", "(Tensor) Current step.");
AddOutput("U_out",
"(Tensor) "
"Output encoded gradient");
AddOutput("V_out",
"(Tensor) "
"Output encoded gradient");
AddOutput("EncodeGrad",
"(Tensor) "
"Output encoded gradient");
AddOutput("Grad_out",
"(Tensor) "
"Output grad gradient");
AddOutput("k",
"(Tensor) "
"Output top-k value");
AddAttr<float>("m",
"(float, 0.9) "
"The momentum of learning rate.")
.SetDefault(0.9);
AddAttr<bool>("use_nesterov",
"(bool, true)"
"The momentum of learning rate.")
.SetDefault(true);
AddAttr<std::vector<float>>("sparsity",
"(vecotr, float)"
"The period sparsity of k_select.");
AddAttr<float>("rampup_begin_step",
"(float, 0.0)"
"The period when begin k_select.")
.SetDefault(0.0);
AddAttr<float>("rampup_step",
"(float, 0.0)"
"The period when begin k_select.");
AddComment(R"DOC(
Original paper is https://arxiv.org/abs/1712.01887
DGC reduce the communication bandwidth by sending only the important gradients (sparse update):\
only gradients larger than a threshold are transmitted.
To avoid losing information, DGC accumulate the rest of the gradients locally.
Eventually, these gradients become large enough to be transmitted.
Thus, DGC send the large gradients immediately but eventually send all of the gradients over time.
To ensure no loss of accuracy, DGC employs momentum correc-tionandlocal gradient clipping on top of the gradient sparsification to maintain model performance.
DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication.
This optimizer will do two things:
1. Compress the gradient by get TopK import value from tensor \
and use it for allreduce to reduce network bandwidth.
2. Call momentum to optimize on the cost.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc, ops::DGCOp, ops::DGCOpMaker);

@ -0,0 +1,20 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
dgc, ops::DGCOpKernel<paddle::platform::CUDADeviceContext, float>);

@ -0,0 +1,132 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
namespace paddle {
namespace operators {
inline float get_period_sparcity(const std::vector<float>& sparsity,
float cur_step, float rampup_steps) {
PADDLE_ENFORCE(static_cast<int>(cur_step) >= 0);
size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
if (idx >= sparsity.size()) {
return 0.999;
}
PADDLE_ENFORCE(idx < sparsity.size());
return sparsity[idx];
}
template <typename DeviceContext, typename T>
class DGCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto u = ctx.Input<framework::Tensor>("U");
auto v = ctx.Input<framework::Tensor>("V");
auto g = ctx.Input<framework::Tensor>("Grad");
// attrs
float m = ctx.Attr<float>("m");
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto sparsity = ctx.Attr<std::vector<float>>("sparsity");
auto rampup_begin_step = ctx.Attr<float>("rampup_begin_step");
auto rampup_step = ctx.Attr<float>("rampup_step");
// current step
auto current_step_tensor = ctx.Input<framework::Tensor>("current_step");
const float* current_step = current_step_tensor->data<float>();
if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step
<< " so does't use dgc";
return;
}
float ratio =
1 - get_period_sparcity(sparsity, static_cast<float>(*current_step),
rampup_step);
PADDLE_ENFORCE(ratio > 0.0 && ratio < 1.0);
int k = static_cast<int>(g->numel() * ratio);
VLOG(10) << "m:" << m << ", use_nesterov:" << use_nesterov
<< ", rampup_begin_step:" << rampup_begin_step
<< ", rampup_step:" << rampup_step
<< ", current_step:" << *current_step << ", ratio:" << ratio
<< ", k:" << k;
auto k_out = ctx.Output<framework::Tensor>("k");
T* k_out_data = k_out->data<T>();
*k_out_data = k;
auto u_out = ctx.Output<framework::Tensor>("U_out");
auto v_out = ctx.Output<framework::Tensor>("V_out");
auto encode_grad_out = ctx.Output<framework::Tensor>("EncodeGrad");
// FIXME(gongwb): use cublas.
auto u_out_e = framework::EigenVector<T>::Flatten(*u_out);
auto u_e = framework::EigenVector<T>::Flatten(*u);
auto g_e = framework::EigenVector<T>::Flatten(*g);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto& eigen_ctx = *dev_ctx.eigen_device();
if (use_nesterov) {
// u = m * (u + g)
u_out_e.device(eigen_ctx) = m * (u_e + g_e);
// v = u + v + g
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, u, v, 0, AddFunctor<T>(), v_out);
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, g, v, 0, AddFunctor<T>(), v_out);
} else {
// u = m * u + g
u_out_e.device(eigen_ctx) = m * u_e + g_e;
// v = u + v
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, u, v, 0, AddFunctor<T>(), v_out);
}
T* v_out_data = v_out->mutable_data<T>(ctx.GetPlace());
T* u_out_data = u_out->mutable_data<T>(ctx.GetPlace());
T* encode_grad_out_data = encode_grad_out->mutable_data<T>(
framework::DDim{2 * k}, ctx.GetPlace());
int buf_size = paddle::communication::dgc::get_buffer_size(k);
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(
ctx.GetPlace(), dev_ctx.stream());
auto tmp_ious_data = allocator.Allocate(buf_size);
void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());
if (!paddle::communication::dgc::k_select(
static_cast<void*>(encode_grad_out_data), k, v_out_data,
static_cast<int>(v_out->numel()), buf, dev_ctx.stream(),
u_out_data)) {
LOG(FATAL) << "v_out numel:" << v_out->numel();
}
auto grad_out = ctx.Output<framework::Tensor>("Grad_out");
math::SetConstant<DeviceContext, T> tset;
tset(dev_ctx, grad_out, static_cast<T>(0));
}
};
} // namespace operators
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save