refine pe when exception raises, test=develop (#20894)

custom_op_abi
Zeng Jinle 6 years ago committed by GitHub
parent 20cdff0e02
commit b0c0ffb9ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -77,9 +77,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_.IsCaught()) {
bool is_exception_free =
RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (!is_exception_free) {
ExecutionFinal(&fetch_ops);
}
} else {
@ -259,25 +259,25 @@ void FastThreadedSSAGraphExecutor::ExecutionFinal(
exception_.ReThrow();
}
void FastThreadedSSAGraphExecutor::RunTracedOps(
bool FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_.IsCaught()) {
return;
}
RunOpSync(op);
if (!RunOpSync(op)) return false;
}
return true;
}
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) {
exception_.Catch(std::current_exception());
return false;
}
}

@ -78,9 +78,9 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,

@ -81,9 +81,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) {
bool is_exception_free =
RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (!is_exception_free) {
ExecutionFinal(&fetch_ops);
}
} else {
@ -308,25 +308,25 @@ void ThreadedSSAGraphExecutor::RunOp(
RecordOps(op);
}
void ThreadedSSAGraphExecutor::RunTracedOps(
bool ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) {
return;
}
RunOpSync(op);
if (!RunOpSync(op)) return false;
}
return true;
}
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) {
exception_holder_.Catch(std::current_exception());
return false;
}
}

@ -109,9 +109,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
};
} // namespace details

Loading…
Cancel
Save