|
|
|
@ -555,18 +555,17 @@ Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const {
|
|
|
|
|
template <>
|
|
|
|
|
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
auto names = op().Outputs(name);
|
|
|
|
|
auto it = ctx_.outputs.find(name);
|
|
|
|
|
if (it == ctx_.outputs.end()) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
const std::vector<Variable*>& vars = it->second;
|
|
|
|
|
std::vector<Tensor*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) -> Tensor* {
|
|
|
|
|
auto var = scope_.FindVar(sub_name);
|
|
|
|
|
if (var == nullptr) return nullptr;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
var->IsType<LoDTensor>(),
|
|
|
|
|
"%s should be LoDTensor, but the received type is %s",
|
|
|
|
|
sub_name, ToTypeName(var->Type()));
|
|
|
|
|
return var->GetMutable<LoDTensor>();
|
|
|
|
|
res.reserve(vars.size());
|
|
|
|
|
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
|
|
|
|
|
[&](Variable* var) -> Tensor* {
|
|
|
|
|
return var == nullptr ? nullptr
|
|
|
|
|
: var->GetMutable<LoDTensor>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|