|
|
|
@ -230,14 +230,14 @@ class DygraphExecutionContext : public framework::ExecutionContext {
|
|
|
|
|
var_base_map_out_(var_base_map_out),
|
|
|
|
|
attrs_(attrs) {}
|
|
|
|
|
|
|
|
|
|
std::string InputName(const std::string& name) const {
|
|
|
|
|
std::string InputName(const std::string& name) const override {
|
|
|
|
|
auto it = var_base_map_in_.find(name);
|
|
|
|
|
PADDLE_ENFORCE_NE(it, var_base_map_in_.end(),
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Can not find [%s] in Input", name));
|
|
|
|
|
return it->second[0]->Name();
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> InputNames(const std::string& name) const {
|
|
|
|
|
std::vector<std::string> InputNames(const std::string& name) const override {
|
|
|
|
|
auto it = var_base_map_in_.find(name);
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
it, var_base_map_in_.end(),
|
|
|
|
@ -250,7 +250,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
|
|
|
|
|
return vec_res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string OuputName(const std::string& name) const {
|
|
|
|
|
std::string OutputName(const std::string& name) const override {
|
|
|
|
|
auto it = var_base_map_out_.find(name);
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
it, var_base_map_out_.end(),
|
|
|
|
@ -258,7 +258,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
|
|
|
|
|
return it->second[0]->Name();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> OutputNames(const std::string& name) const {
|
|
|
|
|
std::vector<std::string> OutputNames(const std::string& name) const override {
|
|
|
|
|
auto it = var_base_map_out_.find(name);
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
it, var_base_map_out_.end(),
|
|
|
|
@ -271,11 +271,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
|
|
|
|
|
return vec_res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasAttr(const std::string& name) const { return attrs_->count(name); }
|
|
|
|
|
bool HasAttr(const std::string& name) const override {
|
|
|
|
|
return attrs_->count(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const framework::AttributeMap& Attrs() const { return *attrs_; }
|
|
|
|
|
const framework::AttributeMap& Attrs() const override { return *attrs_; }
|
|
|
|
|
|
|
|
|
|
const framework::Attribute& GetAttr(const std::string& name) const {
|
|
|
|
|
const framework::Attribute& GetAttr(const std::string& name) const override {
|
|
|
|
|
auto it = attrs_->find(name);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
@ -285,7 +287,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> InNameList() const {
|
|
|
|
|
std::vector<std::string> InNameList() const override {
|
|
|
|
|
std::vector<std::string> vec_temp;
|
|
|
|
|
vec_temp.reserve(var_base_map_in_.size());
|
|
|
|
|
|
|
|
|
@ -295,21 +297,21 @@ class DygraphExecutionContext : public framework::ExecutionContext {
|
|
|
|
|
|
|
|
|
|
return vec_temp;
|
|
|
|
|
}
|
|
|
|
|
bool HasInput(const std::string& name) const {
|
|
|
|
|
bool HasInput(const std::string& name) const override {
|
|
|
|
|
auto it = var_base_map_in_.find(name);
|
|
|
|
|
return (it != var_base_map_in_.end() && it->second.size() > 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual bool HasOutput(const std::string& name) const {
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
|
auto it = var_base_map_out_.find(name);
|
|
|
|
|
return (it != var_base_map_out_.end() && it->second.size() > 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t InputSize(const std::string& name) const {
|
|
|
|
|
size_t InputSize(const std::string& name) const override {
|
|
|
|
|
return InputNames(name).size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t OutputSize(const std::string& name) const {
|
|
|
|
|
size_t OutputSize(const std::string& name) const override {
|
|
|
|
|
return OutputNames(name).size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -322,7 +324,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
|
|
|
|
|
return it->second.empty() ? nullptr : it->second[0]->MutableVar();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Variable* OutputVar(const std::string& name) const {
|
|
|
|
|
Variable* OutputVar(const std::string& name) const override {
|
|
|
|
|
auto it = var_base_map_out_.find(name);
|
|
|
|
|
if (it == var_base_map_out_.end()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|