Add Decorator 'dygraph_to_static_program' and ProgramTranslator.save_inference_model (#23227)
	
		
	
				
					
				
			1. Add Decorator 'dygraph_to_static_program' 2. Add corresponding ProgramTranslator.get_program 3. Add ProgramTranslator.save_inference_model 4. Modified some warning information of dy2stat 5. Change program cache to contain startup_program because for users who gets program to run, they may like to initialize startup programrevert-23830-2.0-beta
							parent
							
								
									a647bcd355
								
							
						
					
					
						commit
						e5af90aa28
					
				@ -0,0 +1,75 @@
 | 
				
			||||
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
 | 
				
			||||
from __future__ import print_function
 | 
				
			||||
 | 
				
			||||
import os
 | 
				
			||||
import unittest
 | 
				
			||||
 | 
				
			||||
import numpy as np
 | 
				
			||||
import paddle.fluid as fluid
 | 
				
			||||
 | 
				
			||||
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
 | 
				
			||||
from paddle.fluid.dygraph.jit import dygraph_to_static_output
 | 
				
			||||
 | 
				
			||||
np.random.seed(2020)
 | 
				
			||||
 | 
				
			||||
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class SimpleFcLayer(fluid.dygraph.Layer):
 | 
				
			||||
    def __init__(self, fc_size):
 | 
				
			||||
        super(SimpleFcLayer, self).__init__()
 | 
				
			||||
        self._linear = fluid.dygraph.Linear(fc_size, fc_size)
 | 
				
			||||
 | 
				
			||||
    @dygraph_to_static_output
 | 
				
			||||
    def forward(self, x):
 | 
				
			||||
        x = fluid.dygraph.to_variable(x)
 | 
				
			||||
        y = self._linear(x)
 | 
				
			||||
        z = self._linear(y)
 | 
				
			||||
        out = fluid.layers.mean(z, name='mean')
 | 
				
			||||
        return out
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestDyToStaticSaveInferenceModel(unittest.TestCase):
 | 
				
			||||
    def test_save_inference_model(self):
 | 
				
			||||
        fc_size = 20
 | 
				
			||||
 | 
				
			||||
        x = np.random.random((fc_size, fc_size)).astype('float32')
 | 
				
			||||
        layer = SimpleFcLayer(fc_size)
 | 
				
			||||
 | 
				
			||||
        program_translator = ProgramTranslator.get_instance()
 | 
				
			||||
        program_cache = ProgramTranslator().get_program_cache
 | 
				
			||||
        adam = fluid.optimizer.SGD(learning_rate=0.001)
 | 
				
			||||
        program_translator.set_optimizer(adam, 'mean')
 | 
				
			||||
 | 
				
			||||
        for i in range(5):
 | 
				
			||||
            out = layer(x)
 | 
				
			||||
 | 
				
			||||
        main_program = ProgramTranslator.get_instance().main_program
 | 
				
			||||
        expected_persistable_vars = set(
 | 
				
			||||
            [layer._linear.weight.name, layer._linear.bias.name])
 | 
				
			||||
 | 
				
			||||
        infer_model_dir = "./test_dy2stat_save_inference_model"
 | 
				
			||||
        ProgramTranslator.get_instance().save_inference_model(infer_model_dir)
 | 
				
			||||
        saved_var_names = set([
 | 
				
			||||
            filename for filename in os.listdir(infer_model_dir)
 | 
				
			||||
            if filename != '__model__'
 | 
				
			||||
        ])
 | 
				
			||||
        self.assertEqual(saved_var_names, expected_persistable_vars)
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
if __name__ == '__main__':
 | 
				
			||||
    unittest.main()
 | 
				
			||||
@ -0,0 +1,83 @@
 | 
				
			||||
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
 | 
				
			||||
from __future__ import print_function
 | 
				
			||||
 | 
				
			||||
import unittest
 | 
				
			||||
 | 
				
			||||
import numpy as np
 | 
				
			||||
import paddle.fluid as fluid
 | 
				
			||||
import paddle.fluid.framework as framework
 | 
				
			||||
 | 
				
			||||
from paddle.fluid.dygraph.jit import dygraph_to_static_program
 | 
				
			||||
from paddle.fluid.dygraph.nn import Linear
 | 
				
			||||
 | 
				
			||||
np.random.seed(2020)
 | 
				
			||||
 | 
				
			||||
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
 | 
				
			||||
)
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
def simple_func(x, weight_numpy):
 | 
				
			||||
    weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
 | 
				
			||||
    linear = Linear(32, 64, param_attr=weight_initalizer)
 | 
				
			||||
    x = fluid.dygraph.to_variable(x)
 | 
				
			||||
    y = linear(x)
 | 
				
			||||
    z = linear(x)
 | 
				
			||||
    return z
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
@dygraph_to_static_program
 | 
				
			||||
def decorated_simple_func(x, weight_numpy):
 | 
				
			||||
    weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
 | 
				
			||||
    linear = Linear(32, 64, param_attr=weight_initalizer)
 | 
				
			||||
    x = fluid.dygraph.to_variable(x)
 | 
				
			||||
    y = linear(x)
 | 
				
			||||
    z = linear(x)
 | 
				
			||||
    return z
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestDyToStaticSaveLoad(unittest.TestCase):
 | 
				
			||||
    def test_save_load_same_result(self):
 | 
				
			||||
        x = np.random.randn(30, 10, 32).astype('float32')
 | 
				
			||||
        weight = np.random.randn(32, 64).astype('float32')
 | 
				
			||||
        with fluid.dygraph.guard(place):
 | 
				
			||||
            dygraph_result = simple_func(x, weight)
 | 
				
			||||
 | 
				
			||||
        main_program, startup_program, inputs, outputs = decorated_simple_func(
 | 
				
			||||
            x, weight)
 | 
				
			||||
        exe = fluid.Executor(place)
 | 
				
			||||
        exe.run(startup_program)
 | 
				
			||||
        fluid.save(main_program, "./test_dy2stat_save_load")
 | 
				
			||||
 | 
				
			||||
        # set vars to zero so that we can test load in same file
 | 
				
			||||
        for var in main_program.list_vars():
 | 
				
			||||
            if isinstance(var, framework.Parameter) or var.persistable:
 | 
				
			||||
                tensor = fluid.global_scope().find_var(var.name).get_tensor()
 | 
				
			||||
                tensor.set(np.zeros_like(np.array(tensor)), place)
 | 
				
			||||
 | 
				
			||||
                # make sure all the paramerter or optimizer var have been set to zero
 | 
				
			||||
                tensor_np = np.array(fluid.global_scope().find_var(var.name)
 | 
				
			||||
                                     .get_tensor())
 | 
				
			||||
                self.assertEqual(0, np.sum(np.abs(tensor_np)))
 | 
				
			||||
 | 
				
			||||
        fluid.load(main_program, "./test_dy2stat_save_load")
 | 
				
			||||
        static_result = exe.run(main_program,
 | 
				
			||||
                                feed={inputs[0].name: x},
 | 
				
			||||
                                fetch_list=outputs)
 | 
				
			||||
        self.assertTrue(np.allclose(dygraph_result.numpy(), static_result))
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
if __name__ == '__main__':
 | 
				
			||||
    unittest.main()
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue