|
|
|
@ -119,19 +119,19 @@ class KernelContext {
|
|
|
|
|
: op_(*op), scope_(scope), device_context_(device_context) {}
|
|
|
|
|
|
|
|
|
|
const Variable* Input(int index) const {
|
|
|
|
|
return scope_->GetVariable(op_.inputs_[index]);
|
|
|
|
|
return scope_->FindVar(op_.inputs_[index]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Variable* Output(int index) const {
|
|
|
|
|
return scope_->GetVariable(op_.outputs_[index]);
|
|
|
|
|
return scope_->FindVar(op_.outputs_[index]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const Variable* Input(const std::string& name) const {
|
|
|
|
|
return scope_->GetVariable(op_.Input(name));
|
|
|
|
|
return scope_->FindVar(op_.Input(name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const Variable* Output(const std::string& name) const {
|
|
|
|
|
return scope_->GetVariable(op_.Output(name));
|
|
|
|
|
return scope_->FindVar(op_.Output(name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<const Variable*> Inputs(const std::string& name) const {
|
|
|
|
@ -139,7 +139,7 @@ class KernelContext {
|
|
|
|
|
std::vector<const Variable*> res;
|
|
|
|
|
std::transform(
|
|
|
|
|
names.begin(), names.end(), res.begin(),
|
|
|
|
|
[this](const std::string& name) { return scope_->GetVariable(name); });
|
|
|
|
|
[this](const std::string& name) { return scope_->FindVar(name); });
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -148,7 +148,7 @@ class KernelContext {
|
|
|
|
|
std::vector<const Variable*> res;
|
|
|
|
|
std::transform(
|
|
|
|
|
names.begin(), names.end(), res.begin(),
|
|
|
|
|
[this](const std::string& name) { return scope_->GetVariable(name); });
|
|
|
|
|
[this](const std::string& name) { return scope_->FindVar(name); });
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -244,7 +244,7 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
container->reserve(var_names.size());
|
|
|
|
|
VarToTensor<T> convert;
|
|
|
|
|
for (auto& name : var_names) {
|
|
|
|
|
auto var = scope->GetVariable(name);
|
|
|
|
|
auto var = scope->FindVar(name);
|
|
|
|
|
if (var != nullptr) {
|
|
|
|
|
container->push_back(convert(var));
|
|
|
|
|
} else {
|
|
|
|
|