Fix hang when input is duplicated (#10709)

fix_gru_py
Yu Yang 7 years ago committed by Yang Yang(Tony)
parent d73f2bd6bd
commit 14248a64d7

@ -70,6 +70,14 @@ class OpHandleBase {
const std::vector<VarHandleBase *> &Inputs() const { return inputs_; } const std::vector<VarHandleBase *> &Inputs() const { return inputs_; }
size_t NoDupInputSize() const {
std::unordered_set<VarHandleBase *> res;
for (auto *var : inputs_) {
res.emplace(var);
}
return res.size();
}
const std::vector<VarHandleBase *> &Outputs() const { return outputs_; } const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }
protected: protected:

@ -174,7 +174,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
void ThreadedSSAGraphExecutor::InsertPendingOp( void ThreadedSSAGraphExecutor::InsertPendingOp(
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const { OpHandleBase *op_instance) const {
pending_ops->insert({op_instance, op_instance->Inputs().size()}); pending_ops->insert({op_instance, op_instance->NoDupInputSize()});
} }
void ThreadedSSAGraphExecutor::InsertPendingVar( void ThreadedSSAGraphExecutor::InsertPendingVar(

Loading…
Cancel
Save