|
|
|
@ -108,11 +108,11 @@ class OperatorContext {
|
|
|
|
|
|
|
|
|
|
size_t OutputSize() const { return op_.outputs_.size(); }
|
|
|
|
|
|
|
|
|
|
const Variable* InputVar(const size_t& index) const {
|
|
|
|
|
const Variable* InputVar(const size_t index) const {
|
|
|
|
|
return scope_->GetVariable(op_.inputs_.at(index));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Variable* OutputVar(const size_t& index) const {
|
|
|
|
|
Variable* OutputVar(const size_t index) const {
|
|
|
|
|
return scope_->GetVariable(op_.outputs_.at(index));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -146,23 +146,31 @@ class OperatorContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* Input(const size_t& index) const {
|
|
|
|
|
return &(InputVar(index)->Get<T>());
|
|
|
|
|
const T* Input(const size_t index) const {
|
|
|
|
|
auto var = InputVar(index);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index);
|
|
|
|
|
return &var->Get<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* Output(const size_t& index) const {
|
|
|
|
|
return OutputVar(index)->GetMutable<T>();
|
|
|
|
|
T* Output(const size_t index) const {
|
|
|
|
|
auto var = OutputVar(index);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index);
|
|
|
|
|
return var->GetMutable<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* Input(const std::string& name) const {
|
|
|
|
|
return &(InputVar(name)->Get<T>());
|
|
|
|
|
auto var = InputVar(name);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name);
|
|
|
|
|
return &var->Get<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* Output(const std::string& name) const {
|
|
|
|
|
return OutputVar(name)->GetMutable<T>();
|
|
|
|
|
auto var = OutputVar(name);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name);
|
|
|
|
|
return var->GetMutable<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -171,8 +179,12 @@ class OperatorContext {
|
|
|
|
|
std::vector<const T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
|
return &scope_->GetVariable(name)->Get<T>();
|
|
|
|
|
[&](const std::string& sub_name) {
|
|
|
|
|
auto var = scope_->GetVariable(sub_name);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr,
|
|
|
|
|
"MultiInput(%s:%s) should not be nullptr",
|
|
|
|
|
name, sub_name);
|
|
|
|
|
return &var->Get<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
@ -183,8 +195,12 @@ class OperatorContext {
|
|
|
|
|
std::vector<const T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
|
return scope_->GetVariable(name)->GetMutable<T>();
|
|
|
|
|
[&](const std::string& sub_name) {
|
|
|
|
|
auto var = scope_->GetVariable(sub_name);
|
|
|
|
|
PADDLE_ENFORCE(var != nullptr,
|
|
|
|
|
"MultiOutput(%s:%s) should not be nullptr",
|
|
|
|
|
name, sub_name);
|
|
|
|
|
return var->GetMutable<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|