|
|
@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
|
|
|
|
const Scope& scope) {
|
|
|
|
const Scope& scope) {
|
|
|
|
for (auto& var_name_item : innames) {
|
|
|
|
for (auto& var_name_item : innames) {
|
|
|
|
std::vector<Variable*>& input_vars = inputs[var_name_item.first];
|
|
|
|
std::vector<Variable*>& input_vars = inputs[var_name_item.first];
|
|
|
|
|
|
|
|
input_vars.reserve(var_name_item.second.size());
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
input_vars.push_back(scope.FindVar(var_name));
|
|
|
|
input_vars.push_back(scope.FindVar(var_name));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto& var_name_item : outnames) {
|
|
|
|
for (auto& var_name_item : outnames) {
|
|
|
|
std::vector<Variable*>& output_vars = outputs[var_name_item.first];
|
|
|
|
std::vector<Variable*>& output_vars = outputs[var_name_item.first];
|
|
|
|
|
|
|
|
output_vars.reserve(var_name_item.second.size());
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
output_vars.push_back(scope.FindVar(var_name));
|
|
|
|
output_vars.push_back(scope.FindVar(var_name));
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
|
|
|
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
bool HasOutput(const std::string& name) const override {
|
|
|
|
// has only one output
|
|
|
|
// has only one output
|
|
|
|
const auto& outs = op_.Outputs();
|
|
|
|
const auto& outs = ctx_.outputs;
|
|
|
|
auto it = outs.find(name);
|
|
|
|
auto it = outs.find(name);
|
|
|
|
if (it == outs.end()) {
|
|
|
|
if (it == outs.end()) {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
const auto& out = it->second;
|
|
|
|
const auto& out = it->second;
|
|
|
|
if (out.size() == 0 || out[0] == kEmptyVarName) {
|
|
|
|
if (out.size() == 0) {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
PADDLE_ENFORCE_EQ(out.size(), 1UL,
|
|
|
|
PADDLE_ENFORCE_EQ(out.size(), 1UL,
|
|
|
|
"Output %s should not have more than one outputs", name);
|
|
|
|
"Output %s should not have more than one outputs", name);
|
|
|
|
return scope_.FindVar(out[0]) != nullptr;
|
|
|
|
return out[0] != nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool HasInputs(const std::string& name) const override {
|
|
|
|
bool HasInputs(const std::string& name) const override {
|
|
|
|
if (!op_.HasInputs(name)) {
|
|
|
|
const auto& ins = ctx_.inputs;
|
|
|
|
return false;
|
|
|
|
auto it = ins.find(name);
|
|
|
|
}
|
|
|
|
if (it == ins.end()) {
|
|
|
|
auto inputs = op_.Inputs(name);
|
|
|
|
|
|
|
|
if (inputs.empty()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto& input : inputs) {
|
|
|
|
for (auto& input : it->second) {
|
|
|
|
if (scope_.FindVar(input) == nullptr) {
|
|
|
|
if (input == nullptr) {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool HasOutputs(const std::string& name) const override {
|
|
|
|
bool HasOutputs(const std::string& name) const override {
|
|
|
|
if (!op_.HasOutputs(name)) {
|
|
|
|
const auto& outs = ctx_.outputs;
|
|
|
|
return false;
|
|
|
|
auto it = outs.find(name);
|
|
|
|
}
|
|
|
|
if (it == outs.end()) {
|
|
|
|
auto outputs = op_.Outputs(name);
|
|
|
|
|
|
|
|
if (outputs.empty()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto& output : outputs) {
|
|
|
|
for (auto& output : it->second) {
|
|
|
|
if (scope_.FindVar(output) == nullptr) {
|
|
|
|
if (output == nullptr) {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -864,8 +862,7 @@ Scope* OperatorWithKernel::PrepareData(
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
|
|
|
|
auto& var_name = var_name_item.second[i];
|
|
|
|
auto& var_name = var_name_item.second[i];
|
|
|
|
auto* var = scope.FindVar(var_name);
|
|
|
|
auto* var = input_vars[i];
|
|
|
|
input_vars[i] = var;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Only tensor can be tranfer to another device.
|
|
|
|
// Only tensor can be tranfer to another device.
|
|
|
|
if (var == nullptr || !VarIsTensor(*var)) {
|
|
|
|
if (var == nullptr || !VarIsTensor(*var)) {
|
|
|
|