|
|
|
@ -14,13 +14,11 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/program_desc.h"
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/feed_fetch_type.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
const std::string kFeedOpType = "feed";
|
|
|
|
|
const std::string kFetchOpType = "fetch";
|
|
|
|
|
|
|
|
|
|
BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
|
|
|
|
|
auto *b = desc_.add_blocks();
|
|
|
|
|
b->set_parent_idx(parent.ID());
|
|
|
|
@ -67,26 +65,26 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> ProgramDesc::GetFeedVarNames() {
|
|
|
|
|
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
|
|
|
|
|
BlockDesc *global_block = blocks_[0].get();
|
|
|
|
|
std::vector<std::string> feed_var_names;
|
|
|
|
|
std::vector<std::string> feed_target_names;
|
|
|
|
|
for (auto *op : global_block->AllOps()) {
|
|
|
|
|
if (op->Type() == "feed") {
|
|
|
|
|
feed_var_names.insert(feed_var_names.begin(), op->Output("Out")[0]);
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return feed_var_names;
|
|
|
|
|
return feed_target_names;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> ProgramDesc::GetFetchVarNames() {
|
|
|
|
|
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
|
|
|
|
|
BlockDesc *global_block = blocks_[0].get();
|
|
|
|
|
std::vector<std::string> fetch_var_names;
|
|
|
|
|
std::vector<std::string> fetch_target_names;
|
|
|
|
|
for (auto *op : global_block->AllOps()) {
|
|
|
|
|
if (op->Type() == "fetch") {
|
|
|
|
|
fetch_var_names.push_back(op->Input("X")[0]);
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
fetch_target_names.push_back(op->Input("X")[0]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return fetch_var_names;
|
|
|
|
|
return fetch_target_names;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|