|
|
|
@ -117,10 +117,16 @@ void ProgramDesc::InitFromProto() {
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
|
|
|
|
|
auto &global_block = Block(0);
|
|
|
|
|
// The order of feed_target_names must follow the index specified in `col`.
|
|
|
|
|
// since feed operator's order doesn't necessary follow 'col'.
|
|
|
|
|
std::vector<std::string> feed_target_names;
|
|
|
|
|
for (auto *op : global_block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]);
|
|
|
|
|
int col = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
if (col >= feed_target_names.size()) {
|
|
|
|
|
feed_target_names.resize(col + 1);
|
|
|
|
|
}
|
|
|
|
|
feed_target_names[col] = op->Output("Out")[0];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return feed_target_names;
|
|
|
|
@ -128,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
|
|
|
|
|
auto &global_block = Block(0);
|
|
|
|
|
// The order of fetch_target_names must follow the index specified in `col`.
|
|
|
|
|
// since fetch operator's order doesn't necessary follow 'col'.
|
|
|
|
|
std::vector<std::string> fetch_target_names;
|
|
|
|
|
for (auto *op : global_block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
fetch_target_names.push_back(op->Input("X")[0]);
|
|
|
|
|
int col = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
if (col >= fetch_target_names.size()) {
|
|
|
|
|
fetch_target_names.resize(col + 1);
|
|
|
|
|
}
|
|
|
|
|
fetch_target_names[col] = op->Input("X")[0];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return fetch_target_names;
|
|
|
|
|