Refine Model of high level API (#25559)
* Refine Model: 1. Take the network (instance of Layer) as the input of Model. 2. Refine set_dict/load_dict of Layer. 3. Refine the Input interface, and update the code sample about Input. (fix_copy_if_different)
parent
4152d39962
commit
b5f8784cab
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import six
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.dygraph.parallel import ParallelEnv
|
||||
|
||||
__all__ = ['set_device', ]
|
||||
|
||||
# TODO(qingqing01): remove or refine _global_device, set_device and get_device
|
||||
# after core framework supporting these function.
|
||||
_global_device = None
|
||||
|
||||
|
||||
def set_device(device):
    """
    Set the global device used by the high level API.

    Args:
        device (str): specify device type, 'cpu' or 'gpu'.

    Returns:
        fluid.CUDAPlace or fluid.CPUPlace: Created GPU or CPU place.

    Examples:
        .. code-block:: python

            import paddle.incubate.hapi as hapi

            input = hapi.set_device('gpu')
    """

    assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \
        "Expected device in ['cpu', 'gpu'], but got {}".format(device)

    # Fall back to CPU when 'gpu' is requested but CUDA is unavailable.
    use_cuda = device.lower() == 'gpu' and fluid.is_compiled_with_cuda()
    if use_cuda:
        place = fluid.CUDAPlace(ParallelEnv().dev_id)
    else:
        place = fluid.CPUPlace()

    global _global_device
    _global_device = place
    return place
|
||||
|
||||
|
||||
def _get_device():
    """
    Return the global device set by ``set_device``; when none has been
    set, default to CUDAPlace if compiled with CUDA, else CPUPlace.
    """
    if _global_device is not None:
        return _global_device

    if fluid.is_compiled_with_cuda():
        return fluid.CUDAPlace(ParallelEnv().dev_id)
    return fluid.CPUPlace()
|
@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.framework import in_dygraph_mode
|
||||
|
||||
from .device import _get_device
|
||||
|
||||
|
||||
def monkey_patch_layer():
    """Patch ``fluid.dygraph.Layer.load_dict`` with a version that skips
    mismatched parameters and supports both dygraph and static mode."""

    def load_dict(self,
                  stat_dict,
                  include_sublayers=True,
                  use_structured_name=True):
        '''
        Set parameters from stat_dict. All the parameters will be reset by the
        tensor in the stat_dict

        This api will be Deprecated. Please use set_dict

        Parameters:
            stat_dict(dict) : Dict contains all the parameters
            include_sublayers(bool, optional) : If true, also include the
                parameters from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name
                as key, otherwise, use parameter name as key. Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    emb = fluid.dygraph.Embedding([10, 10])

                    state_dict = emb.state_dict()
                    fluid.save_dygraph( state_dict, "paddle_dy")

                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
                    emb.load_dict( para_state_dict )

        '''

        def _check_match(key, param):
            # Raise ValueError when the key is missing from stat_dict or
            # the stored tensor's shape disagrees with the parameter's.
            state = stat_dict.get(key, None)
            if state is None:
                raise ValueError(
                    "{} is not found in the providing file.".format(key))
            if list(state.shape) != list(param.shape):
                raise ValueError(
                    "{} receives a shape {}, but the expected shape is {}.".
                    format(key, list(state.shape), list(param.shape)))
            return param, state

        matched_param_state = []
        for key, param in self.state_dict().items():
            key_name = key if use_structured_name else param.name
            try:
                match_res = _check_match(key_name, param)
            except ValueError as err:
                warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
                # BUG FIX: the original fell through to the append below even
                # when the check failed, appending an unbound (first
                # iteration) or stale (later iterations) match_res.
                continue
            matched_param_state.append(match_res)

        if in_dygraph_mode():
            # Dygraph variables can be assigned directly.
            for param, state in matched_param_state:
                param.set_value(state)
        else:

            def _set_var(var, ndarray):
                # Copy ndarray into the scope tensor, keeping it on the
                # place where the tensor currently lives.
                t = fluid.global_scope().find_var(var.name).get_tensor()
                p = t._place()
                if p.is_cpu_place():
                    place = fluid.CPUPlace()
                elif p.is_cuda_pinned_place():
                    place = fluid.CUDAPinnedPlace()
                else:
                    p = fluid.core.Place()
                    p.set_place(t._place())
                    place = fluid.CUDAPlace(p.gpu_device_id())
                t.set(ndarray, place)

            executor = fluid.Executor(_get_device())._default_executor
            # restore parameter states
            fluid.core._create_loaded_parameter(
                [param for param, state in matched_param_state],
                fluid.global_scope(), executor)
            for param, state in matched_param_state:
                _set_var(param, state)

    setattr(fluid.dygraph.Layer, 'load_dict', load_dict)
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle.incubate.hapi.vision.models as models
|
||||
from paddle.incubate.hapi import Model, Input
|
||||
|
||||
|
||||
# test that the predicted results of static graph and dynamic graph are equal
|
||||
# when used pretrained model
|
||||
class TestPretrainedModel(unittest.TestCase):
    """Verify static-graph and dygraph inference of pretrained vision
    models produce identical predictions."""

    def infer(self, x, arch, dygraph=True):
        # Optionally run this inference pass in imperative (dygraph) mode.
        if dygraph:
            fluid.enable_dygraph()

        network = models.__dict__[arch](
            pretrained=True, classifier_activation=None)
        image_spec = [Input('image', [None, 3, 224, 224], 'float32')]
        model = Model(network=network, inputs=image_spec)
        model.prepare()
        prediction = model.test_batch(x)

        if dygraph:
            fluid.disable_dygraph()
        return prediction

    def test_models(self):
        for arch in ('mobilenet_v1', 'mobilenet_v2', 'resnet18'):
            x = np.array(np.random.random((2, 3, 224, 224)), dtype=np.float32)
            y_dygraph = self.infer(x, arch)
            y_static = self.infer(x, arch, dygraph=False)
            np.testing.assert_allclose(y_dygraph, y_static)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue