|
|
|
@ -18,7 +18,7 @@ import time
|
|
|
|
|
import shutil
|
|
|
|
|
|
|
|
|
|
from paddle.fluid.evaluator import Evaluator
|
|
|
|
|
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
|
|
|
|
|
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
|
|
|
|
|
from . import core
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
@ -1374,3 +1374,101 @@ def get_latest_checkpoint_serial(checkpoint_dir):
|
|
|
|
|
if success_num > current_dir:
|
|
|
|
|
current_dir = success_num
|
|
|
|
|
return current_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_test_program(filelist, program=None, startup_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Transpile current train program to a program to read test dataset
|
|
|
|
|
if the program is using reader ops like "open_files_op".
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _copy_reader_var_(block, var, new_name=None):
|
|
|
|
|
if new_name == None:
|
|
|
|
|
new_name = var.name
|
|
|
|
|
new_var = block.create_var(
|
|
|
|
|
name=str(new_name), type=core.VarDesc.VarType.READER)
|
|
|
|
|
new_var.desc.set_shapes(var.desc.shapes())
|
|
|
|
|
new_var.desc.set_dtypes(var.desc.dtypes())
|
|
|
|
|
new_var.persistable = True
|
|
|
|
|
return new_var
|
|
|
|
|
|
|
|
|
|
def _get_test_reader_name(train_reader_name):
|
|
|
|
|
return train_reader_name + "_test"
|
|
|
|
|
|
|
|
|
|
def _is_reader_op(op):
|
|
|
|
|
block = op.block
|
|
|
|
|
if "Out" in op.output_names:
|
|
|
|
|
reader_out = block.vars[op.output("Out")[0]]
|
|
|
|
|
if reader_out.type == core.VarDesc.VarType.READER:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
if program == None:
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
if startup_program == None:
|
|
|
|
|
startup_program = default_startup_program()
|
|
|
|
|
startup_block = startup_program.global_block()
|
|
|
|
|
|
|
|
|
|
# 1. find out the orignal reader var name
|
|
|
|
|
startup_reader_op_list = []
|
|
|
|
|
|
|
|
|
|
for op in startup_block.ops:
|
|
|
|
|
if _is_reader_op(op):
|
|
|
|
|
startup_reader_op_list.append(op)
|
|
|
|
|
|
|
|
|
|
if len(startup_reader_op_list) == 0:
|
|
|
|
|
return program
|
|
|
|
|
|
|
|
|
|
root_reader_op = startup_reader_op_list[0]
|
|
|
|
|
train_test_reader_map = {}
|
|
|
|
|
# 2. add operators to startup to read open and read test data files
|
|
|
|
|
for op in startup_reader_op_list:
|
|
|
|
|
assert (len(op.output("Out")) == 1)
|
|
|
|
|
train_reader_name = op.output("Out")[0]
|
|
|
|
|
train_reader = startup_block.vars[train_reader_name]
|
|
|
|
|
test_reader = _copy_reader_var_(
|
|
|
|
|
startup_block,
|
|
|
|
|
train_reader,
|
|
|
|
|
new_name=_get_test_reader_name(train_reader_name))
|
|
|
|
|
train_test_reader_map[train_reader.name] = test_reader
|
|
|
|
|
|
|
|
|
|
test_op_inputs = {}
|
|
|
|
|
for name in op.input_names:
|
|
|
|
|
train_arg_names = op.input(name)
|
|
|
|
|
test_arg_vars = []
|
|
|
|
|
for arg_name in train_arg_names:
|
|
|
|
|
arg_var = train_test_reader_map[
|
|
|
|
|
arg_name] if name == "UnderlyingReader" else startup_block.vars[
|
|
|
|
|
arg_name]
|
|
|
|
|
test_arg_vars.append(arg_var)
|
|
|
|
|
test_op_inputs[name] = test_arg_vars
|
|
|
|
|
|
|
|
|
|
test_op = startup_block.append_op(
|
|
|
|
|
type=op.type,
|
|
|
|
|
inputs=test_op_inputs,
|
|
|
|
|
outputs={'Out': [test_reader]},
|
|
|
|
|
attrs=op.attrs)
|
|
|
|
|
# root reader op's filelist attr for read test files
|
|
|
|
|
if op.type == root_reader_op.type:
|
|
|
|
|
test_op.set_attr("file_names", filelist)
|
|
|
|
|
if op.type == "create_multi_pass_reader":
|
|
|
|
|
test_op.set_attr("pass_num", 1)
|
|
|
|
|
|
|
|
|
|
# 3. rename reader vars in inference program to different name
|
|
|
|
|
# to avoid read from train data.
|
|
|
|
|
main_block = program.global_block()
|
|
|
|
|
for var in main_block.vars.values():
|
|
|
|
|
if var.type == core.VarDesc.VarType.READER:
|
|
|
|
|
main_block.rename_var(
|
|
|
|
|
str(var.name), str(_get_test_reader_name(var.name)))
|
|
|
|
|
|
|
|
|
|
for op in main_block.ops:
|
|
|
|
|
if op.type == root_reader_op.type:
|
|
|
|
|
test_op.set_attr("file_names", filelist)
|
|
|
|
|
if op.type == "create_multi_pass_reader":
|
|
|
|
|
test_op.set_attr("pass_num", 1)
|
|
|
|
|
|
|
|
|
|
startup_program.sync_with_cpp()
|
|
|
|
|
program.sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
return program
|
|
|
|
|