|
|
|
@ -16,9 +16,11 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <atomic>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "glog/logging.h" // For VLOG
|
|
|
|
@ -253,31 +255,6 @@ class ExecutionContext {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<Variable*> LegacyMultiInputVar(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
auto names = op_.Inputs(name);
|
|
|
|
|
std::vector<Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
|
return name == kEmptyVarName ? nullptr
|
|
|
|
|
: scope_.FindVar(name);
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<Variable*> LegacyMultiOutputVar(const std::string& name) const {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
std::vector<Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
|
return name == kEmptyVarName ? nullptr
|
|
|
|
|
: scope_.FindVar(name);
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* Input(const std::string& name) const {
|
|
|
|
|
auto* var = InputVar(name);
|
|
|
|
@ -290,22 +267,6 @@ class ExecutionContext {
|
|
|
|
|
return var == nullptr ? nullptr : var->GetMutable<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* LegacyInput(const std::string& name) const {
|
|
|
|
|
auto* var = LegacyInputVar(name);
|
|
|
|
|
return var == nullptr ? nullptr : &var->Get<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* LegacyOutput(const std::string& name) const {
|
|
|
|
|
auto var = LegacyOutputVar(name);
|
|
|
|
|
return var == nullptr ? nullptr : var->GetMutable<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const Variable* LegacyInputVar(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
Variable* LegacyOutputVar(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const std::vector<const T*> MultiInput(const std::string& name) const {
|
|
|
|
|
auto it = ctx_.inputs.find(name);
|
|
|
|
@ -338,32 +299,6 @@ class ExecutionContext {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const std::vector<const T*> LegacyMultiInput(const std::string& name) const {
|
|
|
|
|
auto names = op_.Inputs(name);
|
|
|
|
|
std::vector<const T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) -> const T* {
|
|
|
|
|
auto var = scope_.FindVar(sub_name);
|
|
|
|
|
return var == nullptr ? nullptr : &var->Get<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<T*> LegacyMultiOutput(const std::string& name) const {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
std::vector<T*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) -> T* {
|
|
|
|
|
auto var = scope_.FindVar(sub_name);
|
|
|
|
|
return var == nullptr ? nullptr : var->GetMutable<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::Place GetPlace() const { return device_context_.GetPlace(); }
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContextType>
|
|
|
|
@ -436,24 +371,13 @@ class ExecutionContext {
|
|
|
|
|
template <>
|
|
|
|
|
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const Tensor* ExecutionContext::LegacyInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|