|
|
@ -24,7 +24,8 @@ from paddle.fluid.layer_helper import LayerHelper
|
|
|
|
from paddle.fluid import Conv2D, Pool2D, BatchNorm, FC
|
|
|
|
from paddle.fluid import Conv2D, Pool2D, BatchNorm, FC
|
|
|
|
from paddle.fluid.dygraph.base import to_variable
|
|
|
|
from paddle.fluid.dygraph.base import to_variable
|
|
|
|
from test_imperative_base import new_program_scope
|
|
|
|
from test_imperative_base import new_program_scope
|
|
|
|
from utils import DyGraphProgramDescTracerTestHelper
|
|
|
|
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
|
|
|
|
|
|
|
|
from paddle.fluid.dygraph.jit import TracedLayer
|
|
|
|
|
|
|
|
|
|
|
|
batch_size = 8
|
|
|
|
batch_size = 8
|
|
|
|
train_parameters = {
|
|
|
|
train_parameters = {
|
|
|
@ -227,6 +228,8 @@ class TestDygraphResnet(unittest.TestCase):
|
|
|
|
batch_size = train_parameters["batch_size"]
|
|
|
|
batch_size = train_parameters["batch_size"]
|
|
|
|
batch_num = 10
|
|
|
|
batch_num = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
traced_layer = None
|
|
|
|
|
|
|
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
fluid.default_startup_program().random_seed = seed
|
|
|
|
fluid.default_startup_program().random_seed = seed
|
|
|
|
fluid.default_main_program().random_seed = seed
|
|
|
|
fluid.default_main_program().random_seed = seed
|
|
|
@ -250,7 +253,8 @@ class TestDygraphResnet(unittest.TestCase):
|
|
|
|
for param in resnet.parameters():
|
|
|
|
for param in resnet.parameters():
|
|
|
|
dy_param_init_value[param.name] = param.numpy()
|
|
|
|
dy_param_init_value[param.name] = param.numpy()
|
|
|
|
|
|
|
|
|
|
|
|
helper = DyGraphProgramDescTracerTestHelper(resnet, self)
|
|
|
|
helper = DyGraphProgramDescTracerTestHelper(self)
|
|
|
|
|
|
|
|
program = None
|
|
|
|
|
|
|
|
|
|
|
|
for batch_id, data in enumerate(batch_py_reader()):
|
|
|
|
for batch_id, data in enumerate(batch_py_reader()):
|
|
|
|
if batch_id >= batch_num:
|
|
|
|
if batch_id >= batch_num:
|
|
|
@ -260,14 +264,29 @@ class TestDygraphResnet(unittest.TestCase):
|
|
|
|
label = data[1]
|
|
|
|
label = data[1]
|
|
|
|
label.stop_gradient = True
|
|
|
|
label.stop_gradient = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out = None
|
|
|
|
if batch_id % 5 == 0:
|
|
|
|
if batch_id % 5 == 0:
|
|
|
|
out, out_static = helper.run(img,
|
|
|
|
out, traced_layer = TracedLayer.trace(resnet, img)
|
|
|
|
feed_names=['image'],
|
|
|
|
if program is not None:
|
|
|
|
fetch_names=['logits'])
|
|
|
|
self.assertTrue(
|
|
|
|
helper.assertEachVar(out, out_static)
|
|
|
|
is_equal_program(program, traced_layer.program))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
traced_layer.save_inference_model(
|
|
|
|
|
|
|
|
'./infer_imperative_resnet')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
program = traced_layer.program
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
out = resnet(img)
|
|
|
|
out = resnet(img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if traced_layer is not None:
|
|
|
|
|
|
|
|
resnet.eval()
|
|
|
|
|
|
|
|
traced_layer._switch(is_test=True)
|
|
|
|
|
|
|
|
out_dygraph = resnet([img])
|
|
|
|
|
|
|
|
out_static = traced_layer([img])
|
|
|
|
|
|
|
|
traced_layer._switch(is_test=False)
|
|
|
|
|
|
|
|
helper.assertEachVar(out_dygraph, out_static)
|
|
|
|
|
|
|
|
resnet.train()
|
|
|
|
|
|
|
|
|
|
|
|
loss = fluid.layers.cross_entropy(input=out, label=label)
|
|
|
|
loss = fluid.layers.cross_entropy(input=out, label=label)
|
|
|
|
avg_loss = fluid.layers.mean(x=loss)
|
|
|
|
avg_loss = fluid.layers.mean(x=loss)
|
|
|
|
|
|
|
|
|
|
|
@ -346,6 +365,9 @@ class TestDygraphResnet(unittest.TestCase):
|
|
|
|
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
|
|
|
|
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
|
|
|
|
[batch_size, 1])
|
|
|
|
[batch_size, 1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if traced_layer is not None:
|
|
|
|
|
|
|
|
traced_layer([static_x_data])
|
|
|
|
|
|
|
|
|
|
|
|
fetch_list = [avg_loss.name]
|
|
|
|
fetch_list = [avg_loss.name]
|
|
|
|
fetch_list.extend(static_param_name_list)
|
|
|
|
fetch_list.extend(static_param_name_list)
|
|
|
|
fetch_list.extend(static_grad_name_list)
|
|
|
|
fetch_list.extend(static_grad_name_list)
|
|
|
|