[DygraphToStatic]Add cast transform for dygraph_to_static. (#25325)
* add cast transform and its UT for dygraph_to_static (branch: fix_copy_if_different)
parent
bdad383c1d
commit
2989c012f2
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
import gast
|
||||||
|
|
||||||
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
|
||||||
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
||||||
|
|
||||||
|
|
||||||
|
class CastTransformer(gast.NodeTransformer):
    """
    Rewrites Python type-cast calls (``bool``/``int``/``float``) in a
    dygraph function's AST into the static-graph conversion helper, so the
    casts also work when the argument is a Variable.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of CastTransformer."
        self._root = wrapper_root.node
        # Builtin cast names this transformer knows how to convert.
        self._castable_type = {'bool', 'int', 'float'}

    def transform(self):
        # Entry point: walk the whole stored AST root.
        self.visit(self._root)

    def visit_Call(self, node):
        # Rewrite children first so nested casts are handled bottom-up.
        self.generic_visit(node)
        callee = ast_to_source_code(node.func).strip()
        if callee not in self._castable_type or not node.args:
            return node
        # NOTE(review): only the first argument is preserved — a call such
        # as int(s, 16) would silently lose its base argument; confirm this
        # is intended for the supported cast forms.
        first_arg = ast_to_source_code(node.args[0]).strip()
        converted = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_dtype({}, '{}')".format(
            first_arg, callee)
        # Parse the replacement source and return its expression node.
        return gast.parse(converted).body[0].value
|
||||||
@ -0,0 +1,173 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid.dygraph import declarative
|
||||||
|
|
||||||
|
# Fixed RNG seed so the randomly generated test inputs are reproducible.
SEED = 2020
np.random.seed(SEED)
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
def test_bool_cast(x):
    """Cast a Variable to bool inside a declarative function."""
    var = fluid.dygraph.to_variable(x)
    return bool(var)
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
def test_int_cast(x):
    """Cast a Variable to int inside a declarative function."""
    var = fluid.dygraph.to_variable(x)
    return int(var)
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
def test_float_cast(x):
    """Cast a Variable to float inside a declarative function."""
    var = fluid.dygraph.to_variable(x)
    return float(var)
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
def test_not_var_cast(x):
    """Cast a plain (non-Variable) value to int."""
    return int(x)
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
def test_mix_cast(x):
    """Chain several casts (int -> float -> bool -> float) on a Variable."""
    var = fluid.dygraph.to_variable(x)
    # Nested composition is equivalent to the step-by-step reassignments.
    return float(bool(float(int(var))))
|
||||||
|
|
||||||
|
|
||||||
|
class TestCastBase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||||
|
) else fluid.CPUPlace()
|
||||||
|
self.prepare()
|
||||||
|
self.set_func()
|
||||||
|
|
||||||
|
def prepare(self):
|
||||||
|
self.input_shape = (16, 32)
|
||||||
|
self.input_dtype = 'float32'
|
||||||
|
self.input = np.random.binomial(
|
||||||
|
4, 0.3, size=np.product(self.input_shape)).reshape(
|
||||||
|
self.input_shape).astype(self.input_dtype)
|
||||||
|
self.cast_dtype = 'bool'
|
||||||
|
|
||||||
|
def set_func(self):
|
||||||
|
self.func = test_bool_cast
|
||||||
|
|
||||||
|
def do_test(self):
|
||||||
|
with fluid.dygraph.guard():
|
||||||
|
res = self.func(self.input)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def test_cast_result(self):
|
||||||
|
res = self.do_test().numpy()
|
||||||
|
self.assertTrue(
|
||||||
|
res.dtype == self.cast_dtype,
|
||||||
|
msg='The target dtype is {}, but the casted dtype is {}.'.format(
|
||||||
|
self.cast_dtype, res.dtype))
|
||||||
|
ref_val = self.input.astype(self.cast_dtype)
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(res, ref_val),
|
||||||
|
msg='The casted value is {}.\nThe correct value is {}.'.format(
|
||||||
|
res, ref_val))
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntCast(TestCastBase):
    """Cast a one-element float32 Variable to int32."""

    def prepare(self):
        # A scalar-like input: int() on a Variable is only meaningful for
        # one-element tensors.
        self.input_shape = (1, )
        self.input_dtype = 'float32'
        # np.prod replaces np.product, which is deprecated and removed in
        # NumPy 2.0.
        self.input = np.random.normal(
            loc=6, scale=10, size=np.prod(self.input_shape)).reshape(
                self.input_shape).astype(self.input_dtype)
        self.cast_dtype = 'int32'

    def set_func(self):
        self.func = test_int_cast
|
||||||
|
|
||||||
|
|
||||||
|
class TestFloatCast(TestCastBase):
    """Cast a bool Variable to float32."""

    def prepare(self):
        self.input_shape = (8, 16)
        self.input_dtype = 'bool'
        # np.prod replaces np.product, which is deprecated and removed in
        # NumPy 2.0.
        self.input = np.random.binomial(
            2, 0.5, size=np.prod(self.input_shape)).reshape(
                self.input_shape).astype(self.input_dtype)
        self.cast_dtype = 'float32'

    def set_func(self):
        self.func = test_float_cast
|
||||||
|
|
||||||
|
|
||||||
|
class TestMixCast(TestCastBase):
    """Chain int -> float -> bool -> float casts and compare with NumPy."""

    def prepare(self):
        self.input_shape = (8, 32)
        self.input_dtype = 'float32'
        # np.prod replaces np.product, which is deprecated and removed in
        # NumPy 2.0.
        self.input = np.random.normal(
            loc=6, scale=10, size=np.prod(self.input_shape)).reshape(
                self.input_shape).astype(self.input_dtype)
        # Intermediate dtypes mirroring the cast chain in test_mix_cast.
        self.cast_int = 'int'
        self.cast_float = 'float32'
        self.cast_bool = 'bool'
        self.cast_dtype = 'float32'

    def set_func(self):
        self.func = test_mix_cast

    def test_cast_result(self):
        res = self.do_test().numpy()
        # assertEqual gives a clearer failure report than assertTrue(a == b).
        self.assertEqual(
            res.dtype,
            self.cast_dtype,
            msg='The target dtype is {}, but the casted dtype is {}.'.format(
                self.cast_dtype, res.dtype))
        # The reference value follows the same cast chain as the function.
        ref_val = self.input.astype(self.cast_int).astype(
            self.cast_float).astype(self.cast_bool).astype(self.cast_dtype)
        self.assertTrue(
            np.allclose(res, ref_val),
            msg='The casted value is {}.\nThe correct value is {}.'.format(
                res, ref_val))
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotVarCast(TestCastBase):
    """Cast a plain Python float (not a Variable): falls back to builtin int."""

    def prepare(self):
        # A non-Variable scalar input; no shape/dtype fields are needed.
        self.input = 3.14
        self.cast_dtype = 'int'

    def set_func(self):
        self.func = test_not_var_cast

    def test_cast_result(self):
        res = self.do_test()
        # assertIsInstance is the idiomatic type check (was type(res) == int).
        self.assertIsInstance(res, int, msg='The casted dtype is not int.')
        ref_val = int(self.input)
        self.assertEqual(
            res,
            ref_val,
            msg='The casted value is {}.\nThe correct value is {}.'.format(
                res, ref_val))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()
|
||||||
Loading…
Reference in new issue