|
|
|
@ -21,7 +21,7 @@ import numpy as np
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
__all__ = ['DyGraphProgramDescTracerTestHelper', ]
|
|
|
|
|
__all__ = ['DyGraphProgramDescTracerTestHelper', 'is_equal_program']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_equal_program(prog1, prog2):
|
|
|
|
@ -107,74 +107,8 @@ def load_dygraph_vars_to_scope(model_path, scope, place):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DyGraphProgramDescTracerTestHelper(object):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
module,
|
|
|
|
|
unittest_obj,
|
|
|
|
|
model_path=None,
|
|
|
|
|
scope=None,
|
|
|
|
|
place=None):
|
|
|
|
|
self.module = module
|
|
|
|
|
def __init__(self, unittest_obj):
|
|
|
|
|
self.unittest_obj = unittest_obj
|
|
|
|
|
self.scope = fluid.Scope() if scope is None else scope
|
|
|
|
|
|
|
|
|
|
self.model_path = model_path
|
|
|
|
|
if model_path is None:
|
|
|
|
|
millis = int(round(time.time() * 1000))
|
|
|
|
|
self.model_path = "id_{}_{}".format(id(module), millis)
|
|
|
|
|
|
|
|
|
|
self.place = place
|
|
|
|
|
if place is None:
|
|
|
|
|
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
|
|
|
|
) else fluid.CPUPlace()
|
|
|
|
|
|
|
|
|
|
self.program = None
|
|
|
|
|
|
|
|
|
|
self.executor = fluid.Executor(self.place)
|
|
|
|
|
|
|
|
|
|
def _remove_model_path(self):
|
|
|
|
|
if os.path.exists(self.model_path + ".pdparams"):
|
|
|
|
|
os.remove(self.model_path + ".pdparams")
|
|
|
|
|
|
|
|
|
|
if os.path.exists(self.model_path + ".pdopt"):
|
|
|
|
|
os.remove(self.model_path + ".pdopt")
|
|
|
|
|
|
|
|
|
|
def _run_static_graph(self, inputs, feed_names, fetch_names):
|
|
|
|
|
var_list = extract_vars(inputs)
|
|
|
|
|
assert len(var_list) == len(feed_names)
|
|
|
|
|
|
|
|
|
|
feed_dict = {}
|
|
|
|
|
for name, var in zip(feed_names, var_list):
|
|
|
|
|
feed_dict[name] = np.array(var.value().get_tensor())
|
|
|
|
|
|
|
|
|
|
with fluid.scope_guard(self.scope):
|
|
|
|
|
with _dygraph_guard(None):
|
|
|
|
|
return self.executor.run(self.program,
|
|
|
|
|
feed=feed_dict,
|
|
|
|
|
fetch_list=fetch_names)
|
|
|
|
|
|
|
|
|
|
def run(self, inputs, feed_names, fetch_names):
|
|
|
|
|
out_dygraph, program = jit.trace(
|
|
|
|
|
self.module, inputs, feed_names=feed_names, fetch_names=fetch_names)
|
|
|
|
|
|
|
|
|
|
if self.program is not None:
|
|
|
|
|
self.unittest_obj.assertTrue(
|
|
|
|
|
is_equal_program(self.program, program))
|
|
|
|
|
|
|
|
|
|
self.program = program
|
|
|
|
|
|
|
|
|
|
fluid.save_dygraph(self.module.state_dict(), self.model_path)
|
|
|
|
|
load_dygraph_vars_to_scope(self.model_path, self.scope, self.place)
|
|
|
|
|
|
|
|
|
|
self._remove_model_path()
|
|
|
|
|
|
|
|
|
|
out_static_graph = self._run_static_graph(inputs, feed_names,
|
|
|
|
|
fetch_names)
|
|
|
|
|
|
|
|
|
|
if not isinstance(out_dygraph, (list, tuple)):
|
|
|
|
|
assert len(out_static_graph) == 1
|
|
|
|
|
out_static_graph = out_static_graph[0]
|
|
|
|
|
|
|
|
|
|
return out_dygraph, out_static_graph
|
|
|
|
|
|
|
|
|
|
def assertEachVar(self, out_dygraph, out_static_graph, func=None):
|
|
|
|
|
if func is None:
|
|
|
|
|