API (BuildStrategy) error message enhancement. (#23462)

revert-23830-2.0-beta
liym27 5 years ago committed by GitHub
parent 674355a097
commit 06d4aa4e73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -192,11 +192,24 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
PADDLE_ENFORCE_GE(strategy_.trainer_id_, 0, "trainer_id_ >= 0");
PADDLE_ENFORCE_GE(
strategy_.trainer_id_, 0,
platform::errors::InvalidArgument(
"The trainer_id_ of strategy_ must be greater than or equal to 0, "
"but received strategy_.trainer_id_ = %d.",
strategy_.trainer_id_));
if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
PADDLE_ENFORCE_LT(static_cast<size_t>(strategy_.trainer_id_),
strategy_.trainers_endpoints_.size(),
"trainer_id_ < endpoints_ size");
PADDLE_ENFORCE_LT(
static_cast<size_t>(strategy_.trainer_id_),
strategy_.trainers_endpoints_.size(),
platform::errors::InvalidArgument(
"The trainer_id_ of strategy_ must be less than the "
"size of vector strategy_.trainers_endpoints_, "
"but received strategy_.trainer_id_ = %d, "
"the size of strategy_.trainers_endpoints_ is %d.",
static_cast<size_t>(strategy_.trainer_id_),
strategy_.trainers_endpoints_.size()));
}
VLOG(1) << "CollectiveContext:" << context->String();
}
@ -269,8 +282,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
PADDLE_ENFORCE_NE(FLAGS_use_mkldnn, true,
platform::errors::PreconditionNotMet(
"FLAGS_use_mkldnn has been set to True, but "
"PaddlePaddle is compiled without MKLDNN. "
"Please compile PaddlePaddle with MKLDNN first."));
#endif
}

Loading…
Cancel
Save