|
|
|
@ -70,11 +70,13 @@ void RenameAndGetOutputs(
|
|
|
|
|
std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
|
|
|
|
|
same_hierarchy_conv2d_num_map;
|
|
|
|
|
|
|
|
|
|
auto set_var_shape = [&](const std::string &arg_value) {
|
|
|
|
|
auto arg_var_node = graph_var_map.find(arg_value);
|
|
|
|
|
auto add_block_var = [&](const std::string &graph_arg,
|
|
|
|
|
const std::string &block_arg) {
|
|
|
|
|
auto arg_var_node = graph_var_map.find(graph_arg);
|
|
|
|
|
PADDLE_ENFORCE(arg_var_node != graph_var_map.end());
|
|
|
|
|
auto *var_t = block_desc->Var(arg_value);
|
|
|
|
|
auto *var_t = block_desc->Var(block_arg);
|
|
|
|
|
var_t->SetShape(arg_var_node->second->Var()->GetShape());
|
|
|
|
|
var_t->SetDataType(arg_var_node->second->Var()->GetDataType());
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (size_t index = 0; index < block_desc->OpSize(); ++index) {
|
|
|
|
@ -99,15 +101,16 @@ void RenameAndGetOutputs(
|
|
|
|
|
const std::string arg_value_with_id =
|
|
|
|
|
arg_value + std::to_string(var2id[arg_value]);
|
|
|
|
|
|
|
|
|
|
bool is_var_in_graph = graph_var_map.count(arg_value);
|
|
|
|
|
|
|
|
|
|
if (input_names_with_id.count(arg_value_with_id)) {
|
|
|
|
|
replaced_names.push_back(arg_value);
|
|
|
|
|
if (graph_var_map.count(arg_value)) {
|
|
|
|
|
add_block_var(arg_value, arg_value);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
replaced_names.push_back(arg_value_with_id);
|
|
|
|
|
if (graph_var_map.count(arg_value)) {
|
|
|
|
|
add_block_var(arg_value, arg_value_with_id);
|
|
|
|
|
}
|
|
|
|
|
if (is_var_in_graph) {
|
|
|
|
|
set_var_shape(arg_value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
in_var->clear_arguments();
|
|
|
|
@ -147,11 +150,9 @@ void RenameAndGetOutputs(
|
|
|
|
|
const std::string arg_value_with_id =
|
|
|
|
|
arg_value + std::to_string(var2id[arg_value]);
|
|
|
|
|
|
|
|
|
|
bool is_var_in_graph = graph_var_map.count(arg_value);
|
|
|
|
|
if (is_var_in_graph) {
|
|
|
|
|
set_var_shape(arg_value);
|
|
|
|
|
if (graph_var_map.count(arg_value)) {
|
|
|
|
|
add_block_var(arg_value, arg_value_with_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (output_names_with_id->count(arg_value_with_id)) {
|
|
|
|
|
(*output_name_map)[arg_value] = arg_value_with_id;
|
|
|
|
|
}
|
|
|
|
|