|
|
|
@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
|
|
|
|
|
RunImpl(scope, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool OperatorBase::HasInputs(const std::string& name) const {
|
|
|
|
|
if (inputs_.find(name) != inputs_.end()) {
|
|
|
|
|
return true;
|
|
|
|
|
} else {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string OperatorBase::Input(const std::string& name) const {
|
|
|
|
|
auto& ins = Inputs(name);
|
|
|
|
|
PADDLE_ENFORCE_LE(ins.size(), 1UL,
|
|
|
|
@ -109,6 +117,14 @@ const std::vector<std::string>& OperatorBase::Inputs(
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool OperatorBase::HasOutputs(const std::string& name) const {
|
|
|
|
|
if (outputs_.find(name) != outputs_.end()) {
|
|
|
|
|
return true;
|
|
|
|
|
} else {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string OperatorBase::Output(const std::string& name) const {
|
|
|
|
|
auto& outs = Outputs(name);
|
|
|
|
|
PADDLE_ENFORCE_LE(outs.size(), 1UL,
|
|
|
|
@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const {
|
|
|
|
|
if (op_info == nullptr || op_info->proto_ == nullptr) return;
|
|
|
|
|
|
|
|
|
|
for (auto& in : op_info->Proto().inputs()) {
|
|
|
|
|
if (!in.dispensable()) {
|
|
|
|
|
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
|
|
|
|
|
"Type %s's input %s is not set", Type(), in.name());
|
|
|
|
|
"Operator %s's input, %s, is not set", Type(), in.name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto& out : op_info->Proto().outputs()) {
|
|
|
|
|
if (!out.dispensable()) {
|
|
|
|
|
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
|
|
|
|
|
"Type %s's output %s is not set", Type(), out.name());
|
|
|
|
|
"Operator %s's output, %s, is not set", Type(),
|
|
|
|
|
out.name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
: op_(op), scope_(scope) {}
|
|
|
|
|
|
|
|
|
|
bool HasInput(const std::string& name) const override {
|
|
|
|
|
if (!op_.HasInputs(name)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto& ins = Inputs(name);
|
|
|
|
|
size_t length = ins.size();
|
|
|
|
|
if (length == 0) {
|
|
|
|
@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
|
if (!op_.HasOutputs(name)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto& outs = Outputs(name);
|
|
|
|
|
size_t length = outs.size();
|
|
|
|
|
if (length == 0) {
|
|
|
|
@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasInputs(const std::string& name) const override {
|
|
|
|
|
if (!op_.HasInputs(name)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto inputs = op_.Inputs(name);
|
|
|
|
|
if (inputs.empty()) {
|
|
|
|
|
return false;
|
|
|
|
@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutputs(const std::string& name) const override {
|
|
|
|
|
if (!op_.HasOutputs(name)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto outputs = op_.Outputs(name);
|
|
|
|
|
if (outputs.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|