Add friendly dygraph trace API (#21091)

* friendly trace interface, test=develop

* refine TracedLayer, test=develop

* add some docs, test=develop
revert-21172-masked_select_api
Zeng Jinle 5 years ago committed by GitHub
parent 44a0a4adcc
commit 5fdfbe3413
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -528,6 +528,7 @@ PYBIND11_MODULE(core_noavx, m) {
.def("_get_double_element", TensorGetElement<double>) .def("_get_double_element", TensorGetElement<double>)
.def("_place", [](Tensor &self) { return self.place(); }) .def("_place", [](Tensor &self) { return self.place(); })
.def("_dtype", [](Tensor &self) { return self.type(); }) .def("_dtype", [](Tensor &self) { return self.type(); })
.def("_share_data_with", &Tensor::ShareDataWith)
.def("__getitem__", PySliceTensor, py::return_value_policy::reference) .def("__getitem__", PySliceTensor, py::return_value_policy::reference)
.def("__str__", [](const Tensor &self) { .def("__str__", [](const Tensor &self) {
std::stringstream ostr; std::stringstream ostr;

@ -27,6 +27,17 @@ __all__ = [
] ]
def _switch_to_static_graph_(func):
def __impl__(*args, **kwargs):
with framework._dygraph_guard(None):
return func(*args, **kwargs)
return __impl__
switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
@signature_safe_contextmanager @signature_safe_contextmanager
def program_desc_tracing_guard(enable): def program_desc_tracing_guard(enable):
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()

File diff suppressed because it is too large Load Diff

@ -26,7 +26,8 @@ from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.dygraph.jit import TracedLayer
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
@ -119,6 +120,8 @@ class TestImperativeMnist(unittest.TestCase):
batch_size = 128 batch_size = 128
batch_num = 50 batch_num = 50
traced_layer = None
with fluid.dygraph.guard(): with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
@ -137,8 +140,8 @@ class TestImperativeMnist(unittest.TestCase):
mnist.train() mnist.train()
dy_param_init_value = {} dy_param_init_value = {}
helper = DyGraphProgramDescTracerTestHelper(mnist, self) helper = DyGraphProgramDescTracerTestHelper(self)
program = None
for epoch in range(epoch_num): for epoch in range(epoch_num):
for batch_id, data in enumerate(batch_py_reader()): for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= batch_num: if batch_id >= batch_num:
@ -149,13 +152,20 @@ class TestImperativeMnist(unittest.TestCase):
label.stop_gradient = True label.stop_gradient = True
if batch_id % 10 == 0: if batch_id % 10 == 0:
cost, cost_static = helper.run(inputs=img, cost, traced_layer = TracedLayer.trace(
feed_names=['image'], mnist, inputs=img)
fetch_names=['cost']) if program is not None:
helper.assertEachVar(cost, cost_static) self.assertTrue(program, traced_layer.program)
program = traced_layer.program
traced_layer.save_inference_model(
'./infer_imperative_mnist')
else: else:
cost = mnist(img) cost = mnist(img)
if traced_layer is not None:
cost_static = traced_layer([img])
helper.assertEachVar(cost, cost_static)
loss = fluid.layers.cross_entropy(cost, label) loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss) avg_loss = fluid.layers.mean(loss)
@ -220,6 +230,10 @@ class TestImperativeMnist(unittest.TestCase):
fetch_list = [avg_loss.name] fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list) fetch_list.extend(static_param_name_list)
if traced_layer is not None:
traced_layer([static_x_data])
out = exe.run( out = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed={"pixel": static_x_data, feed={"pixel": static_x_data,

@ -21,10 +21,11 @@ from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.jit import TracedLayer
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
import numpy as np import numpy as np
import six import six
from utils import DyGraphProgramDescTracerTestHelper from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
class SimpleLSTMRNN(fluid.Layer): class SimpleLSTMRNN(fluid.Layer):
@ -221,6 +222,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
batch_size = 4 batch_size = 4
batch_num = 200 batch_num = 200
traced_layer = None
with fluid.dygraph.guard(): with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
@ -240,7 +243,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
last_hidden = None last_hidden = None
last_cell = None last_cell = None
helper = DyGraphProgramDescTracerTestHelper(ptb_model, self) helper = DyGraphProgramDescTracerTestHelper(self)
program = None
for i in range(batch_num): for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64') x_data = np.arange(12).reshape(4, 3).astype('int64')
@ -256,11 +260,19 @@ class TestDygraphPtbRnn(unittest.TestCase):
init_hidden = to_variable(init_hidden_data) init_hidden = to_variable(init_hidden_data)
init_cell = to_variable(init_cell_data) init_cell = to_variable(init_cell_data)
if i % 5 == 0: if i % 5 == 0:
outs, outs_static = helper.run( outs, traced_layer = TracedLayer.trace(
[x, y, init_hidden, init_cell], ptb_model, [x, y, init_hidden, init_cell])
feed_names=['x', 'y', 'init_hidden', 'init_cell'], outs_static = traced_layer([x, y, init_hidden, init_cell])
fetch_names=['dy_loss', 'last_hidden', 'last_cell'])
helper.assertEachVar(outs, outs_static) helper.assertEachVar(outs, outs_static)
if program is not None:
self.assertTrue(
is_equal_program(traced_layer.program, program))
program = traced_layer.program
traced_layer.save_inference_model(
'./infe_imperative_ptb_rnn', feed=range(4))
else: else:
outs = ptb_model(x, y, init_hidden, init_cell) outs = ptb_model(x, y, init_hidden, init_cell)

@ -24,7 +24,8 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import Conv2D, Pool2D, BatchNorm, FC from paddle.fluid import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.dygraph.jit import TracedLayer
batch_size = 8 batch_size = 8
train_parameters = { train_parameters = {
@ -227,6 +228,8 @@ class TestDygraphResnet(unittest.TestCase):
batch_size = train_parameters["batch_size"] batch_size = train_parameters["batch_size"]
batch_num = 10 batch_num = 10
traced_layer = None
with fluid.dygraph.guard(): with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
@ -250,7 +253,8 @@ class TestDygraphResnet(unittest.TestCase):
for param in resnet.parameters(): for param in resnet.parameters():
dy_param_init_value[param.name] = param.numpy() dy_param_init_value[param.name] = param.numpy()
helper = DyGraphProgramDescTracerTestHelper(resnet, self) helper = DyGraphProgramDescTracerTestHelper(self)
program = None
for batch_id, data in enumerate(batch_py_reader()): for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= batch_num: if batch_id >= batch_num:
@ -260,14 +264,29 @@ class TestDygraphResnet(unittest.TestCase):
label = data[1] label = data[1]
label.stop_gradient = True label.stop_gradient = True
out = None
if batch_id % 5 == 0: if batch_id % 5 == 0:
out, out_static = helper.run(img, out, traced_layer = TracedLayer.trace(resnet, img)
feed_names=['image'], if program is not None:
fetch_names=['logits']) self.assertTrue(
helper.assertEachVar(out, out_static) is_equal_program(program, traced_layer.program))
traced_layer.save_inference_model(
'./infer_imperative_resnet')
program = traced_layer.program
else: else:
out = resnet(img) out = resnet(img)
if traced_layer is not None:
resnet.eval()
traced_layer._switch(is_test=True)
out_dygraph = resnet([img])
out_static = traced_layer([img])
traced_layer._switch(is_test=False)
helper.assertEachVar(out_dygraph, out_static)
resnet.train()
loss = fluid.layers.cross_entropy(input=out, label=label) loss = fluid.layers.cross_entropy(input=out, label=label)
avg_loss = fluid.layers.mean(x=loss) avg_loss = fluid.layers.mean(x=loss)
@ -346,6 +365,9 @@ class TestDygraphResnet(unittest.TestCase):
y_data = np.array([x[1] for x in data]).astype('int64').reshape( y_data = np.array([x[1] for x in data]).astype('int64').reshape(
[batch_size, 1]) [batch_size, 1])
if traced_layer is not None:
traced_layer([static_x_data])
fetch_list = [avg_loss.name] fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list) fetch_list.extend(static_param_name_list)
fetch_list.extend(static_grad_name_list) fetch_list.extend(static_grad_name_list)

@ -18,13 +18,14 @@ import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Embedding, LayerNorm, FC, Layer from paddle.fluid import Embedding, LayerNorm, FC, Layer
from paddle.fluid.dygraph import to_variable, guard from paddle.fluid.dygraph import to_variable, guard
from paddle.fluid.dygraph.jit import TracedLayer
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from paddle.fluid import core from paddle.fluid import core
import numpy as np import numpy as np
import six import six
np.set_printoptions(suppress=True) np.set_printoptions(suppress=True)
from utils import DyGraphProgramDescTracerTestHelper from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
# Copy from models # Copy from models
@ -979,23 +980,24 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
dy_param_init = dict() dy_param_init = dict()
dy_param_updated = dict() dy_param_updated = dict()
helper = DyGraphProgramDescTracerTestHelper(transformer, self) helper = DyGraphProgramDescTracerTestHelper(self)
program = None
for i in range(batch_num): for i in range(batch_num):
enc_inputs, dec_inputs, label, weights = create_data() enc_inputs, dec_inputs, label, weights = create_data()
if i % 5 == 0: if i % 2 == 0:
outs, outs_static = helper.run( outs, traced_layer = TracedLayer.trace(
inputs=[enc_inputs, dec_inputs, label, weights], transformer, [enc_inputs, dec_inputs, label, weights])
feed_names=[ outs_static = traced_layer(enc_inputs + dec_inputs +
'enc_input_0', 'enc_input_1', 'enc_input_2', [label, weights])
'dec_input_0', 'dec_input_1', 'dec_input_2',
'dec_input_3', 'label', 'weights'
],
fetch_names=[
'dy_sum_cost', 'dy_avg_cost', 'dy_predict',
'dy_token_num'
])
helper.assertEachVar(outs, outs_static) helper.assertEachVar(outs, outs_static)
if program is not None:
self.assertTrue(
is_equal_program(program, traced_layer.program))
program = traced_layer.program
traced_layer.save_inference_model(
'./infer_imperative_transformer')
else: else:
outs = transformer(enc_inputs, dec_inputs, label, weights) outs = transformer(enc_inputs, dec_inputs, label, weights)

@ -21,7 +21,7 @@ import numpy as np
import os import os
import time import time
__all__ = ['DyGraphProgramDescTracerTestHelper', ] __all__ = ['DyGraphProgramDescTracerTestHelper', 'is_equal_program']
def is_equal_program(prog1, prog2): def is_equal_program(prog1, prog2):
@ -107,74 +107,8 @@ def load_dygraph_vars_to_scope(model_path, scope, place):
class DyGraphProgramDescTracerTestHelper(object): class DyGraphProgramDescTracerTestHelper(object):
def __init__(self, def __init__(self, unittest_obj):
module,
unittest_obj,
model_path=None,
scope=None,
place=None):
self.module = module
self.unittest_obj = unittest_obj self.unittest_obj = unittest_obj
self.scope = fluid.Scope() if scope is None else scope
self.model_path = model_path
if model_path is None:
millis = int(round(time.time() * 1000))
self.model_path = "id_{}_{}".format(id(module), millis)
self.place = place
if place is None:
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.program = None
self.executor = fluid.Executor(self.place)
def _remove_model_path(self):
if os.path.exists(self.model_path + ".pdparams"):
os.remove(self.model_path + ".pdparams")
if os.path.exists(self.model_path + ".pdopt"):
os.remove(self.model_path + ".pdopt")
def _run_static_graph(self, inputs, feed_names, fetch_names):
var_list = extract_vars(inputs)
assert len(var_list) == len(feed_names)
feed_dict = {}
for name, var in zip(feed_names, var_list):
feed_dict[name] = np.array(var.value().get_tensor())
with fluid.scope_guard(self.scope):
with _dygraph_guard(None):
return self.executor.run(self.program,
feed=feed_dict,
fetch_list=fetch_names)
def run(self, inputs, feed_names, fetch_names):
out_dygraph, program = jit.trace(
self.module, inputs, feed_names=feed_names, fetch_names=fetch_names)
if self.program is not None:
self.unittest_obj.assertTrue(
is_equal_program(self.program, program))
self.program = program
fluid.save_dygraph(self.module.state_dict(), self.model_path)
load_dygraph_vars_to_scope(self.model_path, self.scope, self.place)
self._remove_model_path()
out_static_graph = self._run_static_graph(inputs, feed_names,
fetch_names)
if not isinstance(out_dygraph, (list, tuple)):
assert len(out_static_graph) == 1
out_static_graph = out_static_graph[0]
return out_dygraph, out_static_graph
def assertEachVar(self, out_dygraph, out_static_graph, func=None): def assertEachVar(self, out_dygraph, out_static_graph, func=None):
if func is None: if func is None:

Loading…
Cancel
Save