|
|
|
@ -197,8 +197,31 @@ class ExecutionContext {
|
|
|
|
|
|
|
|
|
|
const std::vector<const Variable*> MultiInputVar(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
auto names = op_.Inputs(name);
|
|
|
|
|
auto it = ctx_.inputs.find(name);
|
|
|
|
|
if (it == ctx_.inputs.end()) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
std::vector<const Variable*> res;
|
|
|
|
|
res.reserve(it->second.size());
|
|
|
|
|
std::transform(it->second.begin(), it->second.end(),
|
|
|
|
|
std::back_inserter(res),
|
|
|
|
|
[this](Variable* var) { return var; });
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<Variable*> MultiOutputVar(const std::string& name) const {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
auto it = ctx_.outputs.find(name);
|
|
|
|
|
if (it == ctx_.outputs.end()) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<Variable*> LegacyMultiInputVar(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
auto names = op_.Inputs(name);
|
|
|
|
|
std::vector<Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
@ -208,7 +231,7 @@ class ExecutionContext {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<Variable*> MultiOutputVar(const std::string& name) const {
|
|
|
|
|
std::vector<Variable*> LegacyMultiOutputVar(const std::string& name) const {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
std::vector<Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
@ -250,6 +273,38 @@ class ExecutionContext {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const std::vector<const T*> MultiInput(const std::string& name) const {
|
|
|
|
|
auto it = ctx_.inputs.find(name);
|
|
|
|
|
if (it == ctx_.inputs.end()) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
const std::vector<Variable*>& vars = it->second;
|
|
|
|
|
std::vector<const T*> res;
|
|
|
|
|
res.reserve(vars.size());
|
|
|
|
|
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
|
|
|
|
|
[&](Variable* var) -> const T* {
|
|
|
|
|
return var == nullptr ? nullptr : &var->Get<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<T*> MultiOutput(const std::string& name) const {
|
|
|
|
|
auto it = ctx_.outputs.find(name);
|
|
|
|
|
if (it == ctx_.outputs.end()) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
const std::vector<Variable*>& vars = it->second;
|
|
|
|
|
std::vector<T*> res;
|
|
|
|
|
res.reserve(vars.size());
|
|
|
|
|
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
|
|
|
|
|
[&](Variable* var) -> T* {
|
|
|
|
|
return var == nullptr ? nullptr : var->GetMutable<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const std::vector<const T*> LegacyMultiInput(const std::string& name) const {
|
|
|
|
|
auto names = op_.Inputs(name);
|
|
|
|
|
std::vector<const T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
@ -262,7 +317,7 @@ class ExecutionContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<T*> MultiOutput(const std::string& name) const {
|
|
|
|
|
std::vector<T*> LegacyMultiOutput(const std::string& name) const {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
std::vector<T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
@ -321,6 +376,10 @@ template <>
|
|
|
|
|
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|