|
|
@ -151,16 +151,16 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
|
|
|
|
void OperatorBase::CheckAllInputOutputSet() const {
|
|
|
|
void OperatorBase::CheckAllInputOutputSet() const {
|
|
|
|
auto& info_map = OpInfoMap::Instance();
|
|
|
|
auto& info_map = OpInfoMap::Instance();
|
|
|
|
auto* op_info = info_map.GetNullable(Type());
|
|
|
|
auto* op_info = info_map.GetNullable(Type());
|
|
|
|
if (op_info == nullptr) return;
|
|
|
|
if (op_info == nullptr || op_info->proto_ == nullptr) return;
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& in : op_info->Proto().inputs()) {
|
|
|
|
for (auto& in : op_info->Proto().inputs()) {
|
|
|
|
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
|
|
|
|
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
|
|
|
|
"input %s is not set", in.name());
|
|
|
|
"Type %s's input %s is not set", Type(), in.name());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& out : op_info->Proto().outputs()) {
|
|
|
|
for (auto& out : op_info->Proto().outputs()) {
|
|
|
|
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
|
|
|
|
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
|
|
|
|
"output %s is not set", out.name());
|
|
|
|
"Type %s's output %s is not set", Type(), out.name());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|