Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into gen_nccl_id_op

typhoonzero 7 years ago
commit 928418a9ac

@@ -155,7 +155,7 @@ into offsets
 3 2+3 4+5 1+9 2+10 3+12
 ```
-so we know that the first sentence is from word 0 to word 3, and the second sentence from work 3 to word 5.
+so we know that the first sentence is from word 0 to word 3, and the second sentence from word 3 to word 5.
 Similarly, the lengths in the top level LoD
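For reference, the length-to-offset conversion this doc excerpt describes is a plain prefix sum. A minimal standalone sketch (illustrative only, not Paddle's LoD API), reproducing the numbers in the example above:

```cpp
// Convert length-based LoD into offset-based LoD by prefix-summing the lengths.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<size_t> LengthsToOffsets(const std::vector<size_t> &lengths) {
  std::vector<size_t> offsets(lengths.size() + 1, 0);
  for (size_t i = 0; i < lengths.size(); ++i) {
    offsets[i + 1] = offsets[i] + lengths[i];  // running sum, e.g. 3, 2+3, 4+5, ...
  }
  return offsets;
}

int main() {
  // Sentence lengths from the example above.
  std::vector<size_t> lengths = {3, 2, 4, 1, 2, 3};
  for (size_t off : LengthsToOffsets(lengths)) std::cout << off << " ";
  std::cout << "\n";
  // Prints: 0 3 5 9 10 12 15
  // so sentence 0 spans words [0, 3) and sentence 1 spans words [3, 5).
  return 0;
}
```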

@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
+    platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
+    bool balance_parameter_opt_between_cards)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
-      nccl_ctxs_(nccl_ctxs) {
+      nccl_ctxs_(nccl_ctxs),
+      balance_parameter_opt_between_cards_(
+          balance_parameter_opt_between_cards) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool use_default_grad_scale)
+    const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
+    bool balance_parameter_opt_between_cards)
     : loss_var_name_(loss_var_name),
       places_(places),
-      local_scopes_(local_scopes) {
+      local_scopes_(local_scopes),
+      balance_parameter_opt_between_cards_(
+          balance_parameter_opt_between_cards) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   // Find "send" op first for split is in front of send.
   OpDesc *send_op = GetSendOpDesc(program);
 
+  size_t cur_device_id = 0;
+  std::vector<std::unordered_set<std::string>> var_name_on_devices;
+  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
+  var_name_on_devices.resize(places_.size());
+  bcast_var_name_set.resize(places_.size());
+
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
     if (op->Type() == "send") {
@@ -139,17 +151,33 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
       is_forwarding = false;
     } else {
-      CreateComputationalOps(&result, *op, places_.size());
+      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
+      if (op_dev_id == -1) {  // var on all device
+        CreateComputationalOps(&result, *op, places_.size());
+      } else {
+        CreateComputationalOp(&result, *op, op_dev_id);
+        for (auto &var_name : op->OutputArgumentNames()) {
+          var_name_on_devices[op_dev_id].emplace(var_name);
+        }
+      }
       if (!is_forwarding && places_.size() > 1) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once.
         for (auto &og : op->OutputArgumentNames()) {
           if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
-            if (IsSparseGradient(var_types, og)) {
-              CreateReduceOp(&result, og, 0);
-              CreateBroadcastOp(&result, og, 0);
+            if (balance_parameter_opt_between_cards_) {
+              CreateReduceOp(&result, og, cur_device_id);
+              var_name_on_devices[cur_device_id].emplace(og);
+              bcast_var_name_set[cur_device_id].emplace(
+                  og.substr(0, og.size() - strlen(kGradVarSuffix)));
+              cur_device_id = (cur_device_id + 1) % places_.size();
             } else {
-              InsertNCCLAllReduceOp(&result, og);
+              if (IsSparseGradient(var_types, og)) {
+                CreateReduceOp(&result, og, 0);
+                CreateBroadcastOp(&result, og, 0);
+              } else {
+                InsertNCCLAllReduceOp(&result, og);
+              }
             }
           }
         }
@@ -157,6 +185,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     }
   }
 
+  // Insert BCast Ops
+  for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
+    auto &to_bcast_set = bcast_var_name_set[dev_id];
+    for (auto &bcast_name : to_bcast_set) {
+      CreateBroadcastOp(&result, bcast_name, dev_id);
+    }
+  }
   /*
   Dependency graph has been constructed. However, there are still data
   harzaeds need to be handled.
@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }
 
+int MultiDevSSAGraphBuilder::GetOpDeviceID(
+    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
+    const OpDesc &op) const {
+  if (!balance_parameter_opt_between_cards_) {
+    return -1;
+  }
+  int var_dev_id = -1;
+  for (auto &var_name : op.InputArgumentNames()) {
+    if (var_dev_id != -1) break;
+    for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
+      if (var_name_on_devices[i].count(var_name)) {
+        var_dev_id = static_cast<int>(i);
+        break;
+      }
+    }
+  }
+  return var_dev_id;
+}
+
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
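Taken together, the changes above implement a round-robin placement policy when balance_parameter_opt_between_cards is enabled: each parameter gradient is reduced onto one device, the optimizer op for that parameter follows it there (GetOpDeviceID picks the device that already owns one of the op's inputs), and the updated parameters are broadcast back at the end of graph construction. A standalone sketch of that policy, using hypothetical names rather than the builder's real classes:

```cpp
// Round-robin gradient placement plus device lookup, in miniature.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Mirrors the idea of GetOpDeviceID: the first input already pinned to a
// device decides where the op runs; -1 means "run a copy on every device".
int LookupDevice(const std::vector<std::unordered_set<std::string>> &var_on_dev,
                 const std::vector<std::string> &op_inputs) {
  for (const auto &name : op_inputs) {
    for (size_t i = 0; i < var_on_dev.size(); ++i) {
      if (var_on_dev[i].count(name)) return static_cast<int>(i);
    }
  }
  return -1;
}

int main() {
  const size_t num_devices = 4;
  std::vector<std::unordered_set<std::string>> var_on_dev(num_devices);
  std::vector<std::string> grads = {"fc_0.w@GRAD", "fc_0.b@GRAD",
                                    "fc_1.w@GRAD", "fc_1.b@GRAD",
                                    "fc_2.w@GRAD"};
  size_t cur_device_id = 0;
  for (const auto &g : grads) {
    var_on_dev[cur_device_id].insert(g);  // reduce g onto this device
    std::cout << g << " -> device " << cur_device_id << "\n";
    cur_device_id = (cur_device_id + 1) % num_devices;  // round-robin
  }
  // An optimizer op consuming fc_2.w@GRAD would be placed on device 0:
  std::cout << LookupDevice(var_on_dev, {"fc_2.w@GRAD"}) << "\n";
  return 0;
}
```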

@@ -36,13 +36,15 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
                           platform::NCCLContextMap *nccl_ctxs,
-                          bool use_default_grad_scale);
+                          bool use_default_grad_scale,
+                          bool balance_parameter_opt_between_cards);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          bool use_default_grad_scale);
+                          bool use_default_grad_scale,
+                          bool balance_parameter_opt_between_cards);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -60,6 +62,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
 #endif
+  bool balance_parameter_opt_between_cards_;
   bool use_default_grad_scale_;
 
   bool IsScaleLossOp(const OpDesc &op) const;
@@ -84,6 +87,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
       const std::string &og,
       std::unordered_set<std::string> *og_has_been_broadcast) const;
 
+  int GetOpDeviceID(
+      const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
+      const OpDesc &op) const;
+
   void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
 
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,

@@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor(
     const std::unordered_set<std::string> &bcast_vars,
     const ProgramDesc &main_program, const std::string &loss_var_name,
     Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
-    bool use_default_grad_scale, size_t num_trainers, size_t trainer_id)
+    bool use_default_grad_scale, bool balance_parameter_opt_between_cards,
+    size_t num_trainers, size_t trainer_id)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;
@@ -99,11 +100,12 @@ ParallelExecutor::ParallelExecutor(
 #ifdef PADDLE_WITH_CUDA
   details::MultiDevSSAGraphBuilder builder(
       member_->places_, loss_var_name, params, member_->local_scopes_,
-      member_->nccl_ctxs_.get(), use_default_grad_scale);
+      member_->nccl_ctxs_.get(), use_default_grad_scale,
+      balance_parameter_opt_between_cards);
 #else
-  details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
-                                           params, member_->local_scopes_,
-                                           use_default_grad_scale);
+  details::MultiDevSSAGraphBuilder builder(
+      member_->places_, loss_var_name, params, member_->local_scopes_,
+      use_default_grad_scale, balance_parameter_opt_between_cards);
 #endif
 
   auto graph = builder.Build(main_program);

@@ -41,6 +41,7 @@ class ParallelExecutor {
                    const std::string& loss_var_name, Scope* scope,
                    const std::vector<Scope*>& local_scopes,
                    bool allow_op_delay, bool use_default_grad_scale,
+                   bool balance_parameter_opt_between_cards,
                    size_t num_trainers = 0, size_t trainer_id = 0);
 
   ~ParallelExecutor();

@@ -276,6 +276,11 @@ foreach(src ${READER_LIBRARY})
   set(OP_LIBRARY ${src} ${OP_LIBRARY})
 endforeach()
 
+add_subdirectory(detection)
+foreach(src ${DETECTION_LIBRARY})
+  set(OP_LIBRARY ${src} ${OP_LIBRARY})
+endforeach()
+
 set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
 
 cc_test(gather_test SRCS gather_test.cc DEPS tensor)

@@ -0,0 +1,29 @@
+set(LOCAL_DETECTION_LIBS)
+
+function(detection_library TARGET_NAME)
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  set(options "")
+  set(common_deps op_registry)
+  set(pybind_flag 0)
+  cmake_parse_arguments(detection_library "${options}" "${oneValueArgs}"
+      "${multiValueArgs}" ${ARGN})
+  op_library(${TARGET_NAME} SRCS ${detection_library_SRCS} DEPS ${common_deps} ${detection_library_DEPS})
+  set(LOCAL_DETECTION_LIBS
+      ${TARGET_NAME}
+      ${LOCAL_DETECTION_LIBS}
+      PARENT_SCOPE)
+endfunction()
+
+detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
+detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
+detection_library(iou_similarity_op SRCS iou_similarity_op.cc
+    iou_similarity_op.cu)
+detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
+detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc)
+detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
+detection_library(target_assign_op SRCS target_assign_op.cc
+    target_assign_op.cu)
+
+# Export local libraries to parent
+set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)

@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/box_coder_op.h"
+#include "paddle/fluid/operators/detection/box_coder_op.h"
 
 namespace paddle {
 namespace operators {

@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/box_coder_op.h"
+#include "paddle/fluid/operators/detection/box_coder_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
 namespace paddle {

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/iou_similarity_op.h"
+#include "paddle/fluid/operators/detection/iou_similarity_op.h"
 
 namespace paddle {
 namespace operators {

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/iou_similarity_op.h"
+#include "paddle/fluid/operators/detection/iou_similarity_op.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/prior_box_op.h"
+#include "paddle/fluid/operators/detection/prior_box_op.h"
 
 namespace paddle {
 namespace operators {

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/prior_box_op.h"
+#include "paddle/fluid/operators/detection/prior_box_op.h"
 
 namespace paddle {
 namespace operators {

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/target_assign_op.h"
+#include "paddle/fluid/operators/detection/target_assign_op.h"
 
 namespace paddle {
 namespace operators {

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/target_assign_op.h"
+#include "paddle/fluid/operators/detection/target_assign_op.h"
 
 namespace paddle {
 namespace operators {

@@ -83,8 +83,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
     rpc_service_->SetProgram(&empty_program);
     rpc_service_->SetExecutor(&executor);
 
-    server_thread_.reset(new std::thread(
-        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_)));
+    std::thread server_thread(
+        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_));
     rpc_service_->SetCond(0);
     VLOG(3) << "start getting nccl id from trainer 0...";
     auto recv = rpc_service_->Get();
@@ -92,13 +92,12 @@ class GenNCCLIdOp : public framework::OperatorBase {
     rpc_service_->ShutDown();
     VLOG(3) << "rpc server stopped";
     // TODO(wuyi): reinit nccl communicators
-    server_thread_->join();
+    server_thread.join();
     delete rpc_service_;
   }
 
 protected:
   mutable detail::AsyncGRPCServer* rpc_service_ = nullptr;
-  mutable std::shared_ptr<std::thread> server_thread_;
 };
 
 class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
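This change keeps the gRPC server thread as a function-local std::thread that is joined before the op finishes, instead of stashing it in a mutable shared_ptr member of the op. A minimal sketch of the same scope-bound pattern, with a hypothetical Server type standing in for AsyncGRPCServer:

```cpp
// Thread lifetime bounded by the function that starts it: join before the
// objects it touches go out of scope. Illustrative only, not the gRPC classes
// used in the op above.
#include <iostream>
#include <thread>

struct Server {
  void Run() { std::cout << "serving...\n"; }   // the real server blocks until ShutDown()
  void ShutDown() { std::cout << "shutdown\n"; }
};

void RunOnce() {
  Server server;
  std::thread server_thread([&server] { server.Run(); });  // start serving
  // ... wait for the NCCL ID to arrive, then stop the server ...
  server.ShutDown();
  server_thread.join();  // thread is guaranteed finished before `server` is destroyed
}

int main() {
  RunOnce();
  return 0;
}
```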

@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> {
         reinterpret_cast<__half *>(C), ldc));
   }
 
-  template <typename... ARGS>
-  static void GEMM_BATCH(ARGS... args) {
+  static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
+                         cublasOperation_t transb, int m, int n, int k,
+                         const float16 *alpha, const float16 *A, int lda,
+                         long long int strideA, const float16 *B,  // NOLINT
+                         int ldb, long long int strideB,           // NOLINT
+                         const float16 *beta, float16 *C, int ldc,
+                         long long int strideC,                    // NOLINT
+                         int batchCount) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...));
+    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+        handle, transa, transb, m, n, k,
+        reinterpret_cast<const __half *>(alpha),
+        reinterpret_cast<const __half *>(A), lda, strideA,
+        reinterpret_cast<const __half *>(B), ldb, strideB,
+        reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
+        ldc, strideC, batchCount));
 #else
     PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
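For readers unfamiliar with strided-batched GEMM, the explicit parameters above spell out batchCount independent C = alpha*A*B + beta*C products whose matrices are found by advancing the base pointers by the given strides. A naive single-threaded reference sketch of that indexing (float, row-major, no transposes; purely illustrative, not cuBLAS's column-major layout):

```cpp
// Reference semantics of a strided-batched GEMM: loop over batches, offsetting
// each base pointer by its stride, then do a plain matrix multiply.
#include <vector>

void StridedBatchedGemmRef(int m, int n, int k, float alpha, const float *A,
                           long long strideA, const float *B, long long strideB,
                           float beta, float *C, long long strideC,
                           int batch_count) {
  for (int b = 0; b < batch_count; ++b) {
    const float *Ab = A + b * strideA;
    const float *Bb = B + b * strideB;
    float *Cb = C + b * strideC;
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int p = 0; p < k; ++p) acc += Ab[i * k + p] * Bb[p * n + j];
        Cb[i * n + j] = alpha * acc + beta * Cb[i * n + j];
      }
    }
  }
}

int main() {
  const int m = 2, n = 2, k = 3, batch = 4;
  std::vector<float> A(batch * m * k, 1.f), B(batch * k * n, 1.f),
      C(batch * m * n, 0.f);
  // Contiguous batches: each stride equals the element count of one matrix.
  StridedBatchedGemmRef(m, n, k, 1.f, A.data(), m * k, B.data(), k * n, 0.f,
                        C.data(), m * n, batch);
  return 0;  // every C entry is 3 (= k) here
}
```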

@@ -172,9 +172,9 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
       c_array.data(), &ldc, 1 /* group_count */, &batchCount);
 #else
   for (int k = 0; k < batchCount; ++k) {
-    const float *Ak = &A[k * strideA];
-    const float *Bk = &B[k * strideB];
-    float *Ck = &C[k * M * N];
+    auto *Ak = &A[k * strideA];
+    auto *Bk = &B[k * strideB];
+    auto *Ck = &C[k * M * N];
     this->template GEMM<T>(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck);
   }
 #endif
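The switch to auto* matters because this fallback loop sits in a template instantiated for float and double alike; hard-coded const float * pointers only compile for the float instantiation. A tiny standalone sketch of the same pattern, with a stand-in PerMatrixGemm rather than Paddle's Blas API:

```cpp
// Why the pointer type must follow the template parameter T in a batched loop.
#include <vector>

template <typename T>
void PerMatrixGemm(const T *A, const T *B, T *C /* , dims... */) { /* ... */ }

template <typename T>
void BatchedGemmFallback(const T *A, const T *B, T *C, int batchCount,
                         int strideA, int strideB, int strideC) {
  for (int k = 0; k < batchCount; ++k) {
    auto *Ak = &A[k * strideA];  // deduced as const T *, whatever T is
    auto *Bk = &B[k * strideB];
    auto *Ck = &C[k * strideC];
    PerMatrixGemm(Ak, Bk, Ck);
  }
}

int main() {
  std::vector<float> fa(8), fb(8), fc(8);
  std::vector<double> da(8), db(8), dc(8);
  BatchedGemmFallback(fa.data(), fb.data(), fc.data(), 2, 4, 4, 4);
  BatchedGemmFallback(da.data(), db.data(), dc.data(), 2, 4, 4, 4);  // also compiles
  return 0;
}
```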
