refine pe when exception raises, test=develop (#20894)

custom_op_abi
Zeng Jinle 6 years ago committed by GitHub
parent 20cdff0e02
commit b0c0ffb9ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -77,9 +77,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
// run the recorded operators directly. This strategy could make the // run the recorded operators directly. This strategy could make the
// execution faster. // execution faster.
VLOG(3) << "Run the traced ops."; VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_); bool is_exception_free =
RunTracedOps(fetch_ops); RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (exception_.IsCaught()) { if (!is_exception_free) {
ExecutionFinal(&fetch_ops); ExecutionFinal(&fetch_ops);
} }
} else { } else {
@ -259,25 +259,25 @@ void FastThreadedSSAGraphExecutor::ExecutionFinal(
exception_.ReThrow(); exception_.ReThrow();
} }
void FastThreadedSSAGraphExecutor::RunTracedOps( bool FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) { const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) { for (auto &op : traced_ops) {
if (exception_.IsCaught()) { if (!RunOpSync(op)) return false;
return;
}
RunOpSync(op);
} }
return true;
} }
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try { try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) { if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_); op->Run(strategy_.use_cuda_);
} }
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) { } catch (...) {
exception_.Catch(std::current_exception()); exception_.Catch(std::current_exception());
return false;
} }
} }

@ -78,9 +78,9 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops); inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op); inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops); bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps( void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches, const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,

@ -81,9 +81,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
// run the recorded operators directly. This strategy could make the // run the recorded operators directly. This strategy could make the
// execution faster. // execution faster.
VLOG(3) << "Run the traced ops."; VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_); bool is_exception_free =
RunTracedOps(fetch_ops); RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) { if (!is_exception_free) {
ExecutionFinal(&fetch_ops); ExecutionFinal(&fetch_ops);
} }
} else { } else {
@ -308,25 +308,25 @@ void ThreadedSSAGraphExecutor::RunOp(
RecordOps(op); RecordOps(op);
} }
void ThreadedSSAGraphExecutor::RunTracedOps( bool ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) { const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) { for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) { if (!RunOpSync(op)) return false;
return;
}
RunOpSync(op);
} }
return true;
} }
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try { try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) { if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_); op->Run(strategy_.use_cuda_);
} }
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
return false;
} }
} }

@ -109,9 +109,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops); inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op); inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops); bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
}; };
} // namespace details } // namespace details

Loading…
Cancel
Save