|
|
|
@ -28,19 +28,25 @@ np.random.seed(SEED)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dyfunc_to_variable(x):
|
|
|
|
|
res = fluid.dygraph.to_variable(x)
|
|
|
|
|
res = fluid.dygraph.to_variable(x, name=None, zero_copy=None)
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dyfunc_to_variable_2(x):
|
|
|
|
|
res = fluid.dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32))
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDygraphBasicApi_ToVariable(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.input = np.ones(5).astype("int32")
|
|
|
|
|
self.dygraph_func = dyfunc_to_variable
|
|
|
|
|
self.test_funcs = [dyfunc_to_variable, dyfunc_to_variable_2]
|
|
|
|
|
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
|
|
|
|
) else fluid.CPUPlace()
|
|
|
|
|
|
|
|
|
|
def get_dygraph_output(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
res = self.dygraph_func(self.input).numpy()
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
def get_static_output(self):
|
|
|
|
@ -49,18 +55,20 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase):
|
|
|
|
|
with fluid.program_guard(main_program):
|
|
|
|
|
static_out = dygraph_to_static_graph(self.dygraph_func)(self.input)
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
|
|
|
exe = fluid.Executor(self.place)
|
|
|
|
|
static_res = exe.run(main_program, fetch_list=static_out)
|
|
|
|
|
|
|
|
|
|
return static_res[0]
|
|
|
|
|
|
|
|
|
|
def test_transformed_static_result(self):
|
|
|
|
|
dygraph_res = self.get_dygraph_output()
|
|
|
|
|
static_res = self.get_static_output()
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(dygraph_res, static_res),
|
|
|
|
|
msg='dygraph is {}\n static_res is {}'.format(dygraph_res,
|
|
|
|
|
static_res))
|
|
|
|
|
for func in self.test_funcs:
|
|
|
|
|
self.dygraph_func = func
|
|
|
|
|
dygraph_res = self.get_dygraph_output()
|
|
|
|
|
static_res = self.get_static_output()
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(dygraph_res, static_res),
|
|
|
|
|
msg='dygraph is {}\n static_res is {}'.format(dygraph_res,
|
|
|
|
|
static_res))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 1. test Apis that inherit from layers.Layer
|
|
|
|
|