|
|
|
@ -14,35 +14,34 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "dataset/util/cond_var.h"
|
|
|
|
|
#include <exception>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include "dataset/util/services.h"
|
|
|
|
|
#include "dataset/util/task_manager.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
|
CondVar::CondVar() : svc_(nullptr), my_name_(std::move(Services::GetUniqueID())) {}
|
|
|
|
|
CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {}
|
|
|
|
|
|
|
|
|
|
Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) {
|
|
|
|
|
// Append an additional condition on top of the given predicate.
|
|
|
|
|
// We will also bail out if this cv got interrupted.
|
|
|
|
|
auto f = [this, &pred]() -> bool { return (pred() || (CurState() == State::kInterrupted)); };
|
|
|
|
|
// If we have interrupt service, just wait on the cv unconditionally.
|
|
|
|
|
// Otherwise fall back to the old way of checking interrupt.
|
|
|
|
|
if (svc_) {
|
|
|
|
|
cv_.wait(*lck, f);
|
|
|
|
|
if (CurState() == State::kInterrupted) {
|
|
|
|
|
Task *my_task = TaskManager::FindMe();
|
|
|
|
|
if (my_task->IsMasterThread() && my_task->CaughtSevereException()) {
|
|
|
|
|
return TaskManager::GetMasterThreadRc();
|
|
|
|
|
} else {
|
|
|
|
|
return Status(StatusCode::kInterrupted);
|
|
|
|
|
try {
|
|
|
|
|
if (svc_ != nullptr) {
|
|
|
|
|
// If this cv registers with a global resource tracking, then wait unconditionally.
|
|
|
|
|
auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); };
|
|
|
|
|
cv_.wait(*lck, f);
|
|
|
|
|
// If we are interrupted, override the return value if this is the master thread.
|
|
|
|
|
// Master thread is being interrupted mostly because of some thread is reporting error.
|
|
|
|
|
RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus()));
|
|
|
|
|
} else {
|
|
|
|
|
// Otherwise we wake up once a while to check for interrupt (for this thread).
|
|
|
|
|
auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); };
|
|
|
|
|
while (!f()) {
|
|
|
|
|
(void)cv_.wait_for(*lck, std::chrono::milliseconds(1));
|
|
|
|
|
}
|
|
|
|
|
RETURN_IF_INTERRUPTED();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
RETURN_IF_NOT_OK(interruptible_wait(&cv_, lck, pred));
|
|
|
|
|
if (CurState() == State::kInterrupted) {
|
|
|
|
|
return Status(StatusCode::kInterrupted);
|
|
|
|
|
}
|
|
|
|
|
} catch (const std::exception &e) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED(e.what());
|
|
|
|
|
}
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) {
|
|
|
|
|
return rc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status CondVar::Interrupt() {
|
|
|
|
|
RETURN_IF_NOT_OK(IntrpResource::Interrupt());
|
|
|
|
|
void CondVar::Interrupt() {
|
|
|
|
|
IntrpResource::Interrupt();
|
|
|
|
|
cv_.notify_all();
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string CondVar::my_name() const { return my_name_; }
|
|
|
|
|