Merge pull request #11288 from JiayiFeng/fix_bug_of_ExecutionContext

fix bugs in the implementation of 'HasInput' and 'HasOutput'
wangkuiyi-patch-1
fengjiayi 7 years ago committed by GitHub
commit 5803115720
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -293,6 +293,38 @@ static Tensor* GetMutableTensorFromVar(Variable* var) {
}
}
bool ExecutionContext::HasInput(const std::string& name) const {
if (!op_.HasInputs(name)) {
return false;
}
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Input %s should not have more than one inputs", name);
auto arg = ins[0];
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
return var != nullptr;
}
bool ExecutionContext::HasOutput(const std::string& name) const {
if (!op_.HasOutputs(name)) {
return false;
}
auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Output %s should not have more than one inputs", name);
auto arg = outs[0];
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
return var != nullptr;
}
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);

@ -191,9 +191,9 @@ class ExecutionContext {
return op_.Attr<T>(name);
}
bool HasInput(const std::string& name) const { return op_.HasInputs(name); }
bool HasInput(const std::string& name) const;
bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); }
bool HasOutput(const std::string& name) const;
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();

Loading…
Cancel
Save