Add Decorator 'dygraph_to_static_program' and ProgramTranslator.save_inference_model (#23227)
1. Add Decorator 'dygraph_to_static_program' 2. Add corresponding ProgramTranslator.get_program 3. Add ProgramTranslator.save_inference_model 4. Modified some warning information of dy2stat 5. Change program cache to contain startup_program because for users who gets program to run, they may like to initialize startup programrevert-23830-2.0-beta
parent
a647bcd355
commit
e5af90aa28
@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
|
||||
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
|
||||
from paddle.fluid.dygraph.jit import dygraph_to_static_output
|
||||
|
||||
np.random.seed(2020)
|
||||
|
||||
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
|
||||
)
|
||||
|
||||
|
||||
class SimpleFcLayer(fluid.dygraph.Layer):
|
||||
def __init__(self, fc_size):
|
||||
super(SimpleFcLayer, self).__init__()
|
||||
self._linear = fluid.dygraph.Linear(fc_size, fc_size)
|
||||
|
||||
@dygraph_to_static_output
|
||||
def forward(self, x):
|
||||
x = fluid.dygraph.to_variable(x)
|
||||
y = self._linear(x)
|
||||
z = self._linear(y)
|
||||
out = fluid.layers.mean(z, name='mean')
|
||||
return out
|
||||
|
||||
|
||||
class TestDyToStaticSaveInferenceModel(unittest.TestCase):
|
||||
def test_save_inference_model(self):
|
||||
fc_size = 20
|
||||
|
||||
x = np.random.random((fc_size, fc_size)).astype('float32')
|
||||
layer = SimpleFcLayer(fc_size)
|
||||
|
||||
program_translator = ProgramTranslator.get_instance()
|
||||
program_cache = ProgramTranslator().get_program_cache
|
||||
adam = fluid.optimizer.SGD(learning_rate=0.001)
|
||||
program_translator.set_optimizer(adam, 'mean')
|
||||
|
||||
for i in range(5):
|
||||
out = layer(x)
|
||||
|
||||
main_program = ProgramTranslator.get_instance().main_program
|
||||
expected_persistable_vars = set(
|
||||
[layer._linear.weight.name, layer._linear.bias.name])
|
||||
|
||||
infer_model_dir = "./test_dy2stat_save_inference_model"
|
||||
ProgramTranslator.get_instance().save_inference_model(infer_model_dir)
|
||||
saved_var_names = set([
|
||||
filename for filename in os.listdir(infer_model_dir)
|
||||
if filename != '__model__'
|
||||
])
|
||||
self.assertEqual(saved_var_names, expected_persistable_vars)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.framework as framework
|
||||
|
||||
from paddle.fluid.dygraph.jit import dygraph_to_static_program
|
||||
from paddle.fluid.dygraph.nn import Linear
|
||||
|
||||
np.random.seed(2020)
|
||||
|
||||
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
|
||||
)
|
||||
|
||||
|
||||
def simple_func(x, weight_numpy):
|
||||
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
|
||||
linear = Linear(32, 64, param_attr=weight_initalizer)
|
||||
x = fluid.dygraph.to_variable(x)
|
||||
y = linear(x)
|
||||
z = linear(x)
|
||||
return z
|
||||
|
||||
|
||||
@dygraph_to_static_program
|
||||
def decorated_simple_func(x, weight_numpy):
|
||||
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
|
||||
linear = Linear(32, 64, param_attr=weight_initalizer)
|
||||
x = fluid.dygraph.to_variable(x)
|
||||
y = linear(x)
|
||||
z = linear(x)
|
||||
return z
|
||||
|
||||
|
||||
class TestDyToStaticSaveLoad(unittest.TestCase):
|
||||
def test_save_load_same_result(self):
|
||||
x = np.random.randn(30, 10, 32).astype('float32')
|
||||
weight = np.random.randn(32, 64).astype('float32')
|
||||
with fluid.dygraph.guard(place):
|
||||
dygraph_result = simple_func(x, weight)
|
||||
|
||||
main_program, startup_program, inputs, outputs = decorated_simple_func(
|
||||
x, weight)
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_program)
|
||||
fluid.save(main_program, "./test_dy2stat_save_load")
|
||||
|
||||
# set vars to zero so that we can test load in same file
|
||||
for var in main_program.list_vars():
|
||||
if isinstance(var, framework.Parameter) or var.persistable:
|
||||
tensor = fluid.global_scope().find_var(var.name).get_tensor()
|
||||
tensor.set(np.zeros_like(np.array(tensor)), place)
|
||||
|
||||
# make sure all the paramerter or optimizer var have been set to zero
|
||||
tensor_np = np.array(fluid.global_scope().find_var(var.name)
|
||||
.get_tensor())
|
||||
self.assertEqual(0, np.sum(np.abs(tensor_np)))
|
||||
|
||||
fluid.load(main_program, "./test_dy2stat_save_load")
|
||||
static_result = exe.run(main_program,
|
||||
feed={inputs[0].name: x},
|
||||
fetch_list=outputs)
|
||||
self.assertTrue(np.allclose(dygraph_result.numpy(), static_result))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue