|
|
|
@ -327,6 +327,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
bool HasInput(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& input_names = op_.Input(name);
|
|
|
|
|
auto length = input_names.size();
|
|
|
|
|
if (length == 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL,
|
|
|
|
|
"Input(%s) should have only one value, "
|
|
|
|
|
"but it have %d now",
|
|
|
|
@ -337,6 +340,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& output_names = op_.Output(name);
|
|
|
|
|
auto length = output_names.size();
|
|
|
|
|
if (length == 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL,
|
|
|
|
|
"Output(%s) should have only one value, "
|
|
|
|
|
"but it have %d now",
|
|
|
|
@ -346,7 +352,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
|
|
|
|
|
bool HasInputs(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& input_names = op_.Input(name);
|
|
|
|
|
PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name);
|
|
|
|
|
if (input_names.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto& input : input_names) {
|
|
|
|
|
if (!block_.HasVar(input)) return false;
|
|
|
|
|
}
|
|
|
|
@ -355,7 +363,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
|
|
|
|
|
bool HasOutputs(const std::string& name) const override {
|
|
|
|
|
const std::vector<std::string>& output_names = op_.Output(name);
|
|
|
|
|
PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name);
|
|
|
|
|
if (output_names.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto& output : output_names) {
|
|
|
|
|
if (!block_.HasVar(output)) return false;
|
|
|
|
|
}
|
|
|
|
@ -421,13 +431,27 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
: op_(op), scope_(scope) {}
|
|
|
|
|
|
|
|
|
|
bool HasInput(const std::string& name) const override {
|
|
|
|
|
auto ipt = op_.Input(name);
|
|
|
|
|
auto& ins = Inputs(name);
|
|
|
|
|
size_t length = ins.size();
|
|
|
|
|
if (length == 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
|
|
|
|
|
name);
|
|
|
|
|
auto ipt = ins[0];
|
|
|
|
|
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
|
|
|
|
|
return var != nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
|
auto ipt = op_.Output(name);
|
|
|
|
|
auto& outs = Outputs(name);
|
|
|
|
|
size_t length = outs.size();
|
|
|
|
|
if (length == 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
|
|
|
|
|
name);
|
|
|
|
|
auto ipt = outs[0];
|
|
|
|
|
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
|
|
|
|
|
return var != nullptr;
|
|
|
|
|
}
|
|
|
|
|