fix InferShapeContext Has interface (#4994)

revert-4814-Add_sequence_project_op
Qiao Longfei 8 years ago committed by GitHub
parent d0cfbba429
commit e7f627036a

@ -327,6 +327,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name); const std::vector<std::string>& input_names = op_.Input(name);
auto length = input_names.size(); auto length = input_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have only one value, " "Input(%s) should have only one value, "
"but it have %d now", "but it have %d now",
@ -337,6 +340,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name); const std::vector<std::string>& output_names = op_.Output(name);
auto length = output_names.size(); auto length = output_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have only one value, " "Output(%s) should have only one value, "
"but it have %d now", "but it have %d now",
@ -346,7 +352,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name); const std::vector<std::string>& input_names = op_.Input(name);
PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name); if (input_names.empty()) {
return false;
}
for (auto& input : input_names) { for (auto& input : input_names) {
if (!block_.HasVar(input)) return false; if (!block_.HasVar(input)) return false;
} }
@ -355,7 +363,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name); const std::vector<std::string>& output_names = op_.Output(name);
PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name); if (output_names.empty()) {
return false;
}
for (auto& output : output_names) { for (auto& output : output_names) {
if (!block_.HasVar(output)) return false; if (!block_.HasVar(output)) return false;
} }
@ -421,13 +431,27 @@ class RuntimeInferShapeContext : public InferShapeContext {
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
auto ipt = op_.Input(name); auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
auto ipt = op_.Output(name); auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }

Loading…
Cancel
Save