[oneDNN] Initial bf16 amp integration (#31093)
parent a501a7b0ca
commit 7ccf6b6030
@ -0,0 +1,24 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from . import amp_lists
from .amp_lists import *
from . import amp_utils
from .amp_utils import *

__all__ = []
__all__ += amp_lists.__all__
__all__ += amp_utils.__all__
@ -0,0 +1,97 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16,\
    gray_list as gray_list_fp16, unsupported_fp16_list

__all__ = ["AutoMixedPrecisionListsBF16"]


class AutoMixedPrecisionListsBF16(object):
    """
    AutoMixedPrecisionListsBF16 is a class for the fp32/bf16 op type lists. The lists are used
    by an algorithm which determines an op's execution mode (fp32 or bf16). It can update the
    pre-defined fp32 and bf16 lists according to users' custom fp32/bf16 lists.

    Args:
        custom_bf16_list (set): Users' custom bf16 list.
        custom_fp32_list (set): Users' custom fp32 list.
        custom_fp32_varnames (set): Users' custom fp32 variables' names.

    Examples:
        .. code-block:: python

            import paddle
            paddle.enable_static()
            with paddle.static.amp.bf16_guard():
                paddle.static.amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
    """

    def __init__(self,
                 custom_bf16_list=None,
                 custom_fp32_list=None,
                 custom_fp32_varnames=None):
        self._custom_bf16_list = custom_bf16_list
        self._custom_fp32_list = custom_fp32_list
        self.bf16_list = copy.copy(bf16_list)
        self.fp32_list = copy.copy(fp32_list)
        self.gray_list = copy.copy(gray_list)
        self.unsupported_list = copy.copy(unsupported_list)
        self.fp32_varnames = copy.copy(custom_fp32_varnames)
        self._update_list()

    def _update_list(self):
        """
        Update the fp32 and bf16 lists according to users' custom lists.
        """
        if self._custom_bf16_list and self._custom_fp32_list:
            for op_name in self._custom_bf16_list:
                if op_name in self._custom_fp32_list:
                    raise ValueError("Custom bf16 list overlaps "
                                     "custom fp32 list")
        if self._custom_bf16_list:
            for op_name in self._custom_bf16_list:
                if op_name in self.fp32_list:
                    self.fp32_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.bf16_list.add(op_name)
        if self._custom_fp32_list:
            for op_name in self._custom_fp32_list:
                if op_name in self.bf16_list:
                    self.bf16_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.fp32_list.add(op_name)
                self.unsupported_list.add(op_name)


# ops that are always executed in bf16
bf16_list = {'elementwise_add', }

# ops whose execution mode depends on the previous op's type
gray_list = {
    'reshape2',
    'lookup_table',
}

unsupported_list = unsupported_fp16_list.copy()
fp32_list = black_list_fp16.copy()
fp32_list |= white_list_fp16
fp32_list |= gray_list_fp16

fp32_list -= bf16_list
fp32_list -= gray_list
unsupported_list -= bf16_list
unsupported_list -= gray_list
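The custom-list handling above can be checked interactively. A minimal sketch (hypothetical snippet, using the import path the tests in this commit use; the default placement of 'tanh' and 'elementwise_add' follows the lists defined above):

    import paddle.fluid.contrib.mixed_precision as amp

    # 'tanh' is inherited into fp32_list from the fp16 lists; naming it as a
    # custom bf16 op moves it into bf16_list.
    lists = amp.AutoMixedPrecisionListsBF16(custom_bf16_list={'tanh'})
    assert 'tanh' in lists.bf16_list and 'tanh' not in lists.fp32_list

    # A custom fp32 op is also added to unsupported_list, so the rewrite pass
    # never places it in bf16.
    lists = amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'elementwise_add'})
    assert 'elementwise_add' in lists.fp32_list
    assert 'elementwise_add' in lists.unsupported_list

    # Naming the same op in both custom sets raises ValueError.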
File diff suppressed because it is too large
@ -0,0 +1,144 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest
import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision as amp
from paddle.fluid import core
import paddle

paddle.enable_static()


class AMPTest(unittest.TestCase):
    def setUp(self):
        self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
        self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
        self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
        self.amp_lists_ = None

    def tearDown(self):
        self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list)
        self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list)
        self.assertEqual(self.amp_lists_.gray_list, self.gray_list)

    def test_amp_lists(self):
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16()

    def test_amp_lists_1(self):
        # 1. w={'exp'}, b=None
        self.bf16_list.add('exp')
        self.fp32_list.remove('exp')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'})

    def test_amp_lists_2(self):
        # 2. w={'tanh'}, b=None
        self.fp32_list.remove('tanh')
        self.bf16_list.add('tanh')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'})

    def test_amp_lists_3(self):
        # 3. w={'lstm'}, b=None
        self.bf16_list.add('lstm')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})

    def test_amp_lists_4(self):
        # 4. w=None, b={'elementwise_add'}
        self.bf16_list.remove('elementwise_add')
        self.fp32_list.add('elementwise_add')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'elementwise_add'})

    def test_amp_lists_5(self):
        # 5. w=None, b={'elementwise_add'}
        self.fp32_list.add('elementwise_add')
        self.bf16_list.remove('elementwise_add')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'elementwise_add'})

    def test_amp_lists_6(self):
        # 6. w=None, b={'lstm'}
        self.fp32_list.add('lstm')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'lstm'})

    def test_amp_lists_7(self):
        self.fp32_list.add('reshape2')
        self.gray_list.remove('reshape2')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'reshape2'})

    def test_amp_list_8(self):
        self.bf16_list.add('reshape2')
        self.gray_list.remove('reshape2')

        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_bf16_list={'reshape2'})


class AMPTest2(unittest.TestCase):
    def test_amp_lists_(self):
        # 7. w={'lstm'}, b={'lstm'}
        # raise ValueError
        self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
                          {'lstm'}, {'lstm'})

    def test_find_op_index(self):
        block = fluid.default_main_program().global_block()
        op_desc = core.OpDesc()
        idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc)
        assert (idx == -1)

    def test_is_in_fp32_varnames(self):
        block = fluid.default_main_program().global_block()

        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
        op1 = block.append_op(
            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
        op2 = block.append_op(
            type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
        amp_lists_1 = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'X'})
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1)
        amp_lists_2 = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'Y'})
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2)
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2)

    def test_find_true_post_op(self):
        block = fluid.default_main_program().global_block()

        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
        op1 = block.append_op(
            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
        op2 = block.append_op(
            type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
        res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
        assert (res == [op2])


if __name__ == '__main__':
    unittest.main()
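The amp_utils helpers exercised by AMPTest2 live in the diff suppressed above for size. For orientation only, a simplified sketch of the behaviour test_find_true_post_op relies on (not the actual amp_utils implementation; it assumes each op exposes input_arg_names, as fluid Operator objects do):

    def find_true_post_op_sketch(ops, cur_op, var_name):
        # Collect the ops that appear after cur_op and consume var_name as an
        # input, mirroring the result checked by test_find_true_post_op above.
        post_ops = []
        idx = ops.index(cur_op)
        for op in ops[idx + 1:]:
            if var_name in op.input_arg_names:
                post_ops.append(op)
        return post_ops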
@ -0,0 +1,138 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import paddle.fluid as fluid
import contextlib
import unittest
import numpy as np
import paddle.fluid.layers as layers
import paddle.static.amp as amp
from paddle.fluid import core

paddle.enable_static()


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestModelCastBF16(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.seed = 111

    @classmethod
    def tearDownClass(cls):
        pass

    @contextlib.contextmanager
    def static_graph(self):
        with self.scope_prog_guard():
            paddle.seed(self.seed)
            paddle.framework.random._manual_program_seed(self.seed)
            yield

    @contextlib.contextmanager
    def scope_prog_guard(self):
        prog = fluid.Program()
        startup_prog = fluid.Program()
        scope = fluid.core.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(prog, startup_prog):
                yield

    def get_static_graph_result(self, feed, fetch_list, amp_fun,
                                with_lod=False):
        exe = fluid.Executor(core.CPUPlace())
        exe.run(fluid.default_startup_program())
        prog = fluid.default_main_program()
        if amp_fun is not None:
            amp_fun(prog)
        return exe.run(prog,
                       feed=feed,
                       fetch_list=fetch_list,
                       return_numpy=(not with_lod))

    def test_graph_rewrite(self):
        size = 3
        n = np.ones([size, size], dtype='float32') * 3.2
        nn = np.ones([size, size], dtype='float32') * -2.7

        n_bf16 = amp.convert_float_to_uint16(n)
        nn_bf16 = amp.convert_float_to_uint16(nn)

        with self.static_graph():
            t_bf16 = layers.data(
                name='t_bf16', shape=[size, size], dtype=np.uint16)
            tt_bf16 = layers.data(
                name='tt_bf16', shape=[size, size], dtype=np.uint16)
            t = layers.data(name='t', shape=[size, size], dtype='float32')
            tt = layers.data(name='tt', shape=[size, size], dtype='float32')

            ret = layers.elementwise_add(t, tt)
            ret = layers.elementwise_mul(ret, t)
            ret = layers.reshape(ret, [0, 0])

            with amp.bf16_guard():
                ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16)
                ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16)
                ret_bf16 = layers.reshape(ret_bf16, [0, 0])

            with amp.bf16_guard():
                ret_fp32bf16 = layers.elementwise_add(t, tt)
                ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t)
                ret_fp32bf16 = layers.reshape(ret_fp32bf16, [0, 0])

            static_ret_bf16, static_ret, ret_fp32bf16 = self.get_static_graph_result(
                feed={
                    't': n,
                    'tt': nn,
                    't_bf16': n_bf16,
                    'tt_bf16': nn_bf16,
                },
                fetch_list=[ret_bf16, ret, ret_fp32bf16],
                amp_fun=lambda prog: amp.rewrite_program_bf16(prog, use_bf16_guard=True))

        self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
        self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))

        with self.static_graph():
            t = layers.data(name='t', shape=[size, size], dtype='float32')
            tt = layers.data(name='tt', shape=[size, size], dtype='float32')

            with amp.bf16_guard():
                ret = layers.elementwise_add(t, tt)
                ret = layers.reshape(ret, [0, 0], act='elu')
                ret = layers.elementwise_mul(ret, t)
                ret = layers.elementwise_add(ret, tt)

            static_ret_bf16 = \
                self.get_static_graph_result(
                    feed={'t': n, 'tt': nn},
                    fetch_list=[ret],
                    amp_fun=lambda prog: amp.rewrite_program_bf16(
                        prog,
                        amp.AutoMixedPrecisionListsBF16(
                            custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
                        use_bf16_guard=True
                    )
                )
        self.assertTrue(
            np.allclose(static_ret_bf16, np.ones(
                [size, size], dtype='float32') * -1.1, 1e-2))


if __name__ == '__main__':
    unittest.main()
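Distilled from the test above, the end-to-end flow this commit enables looks roughly like the following. This is a sketch, not documentation of a stable API; it assumes a CPU place where core.supports_bfloat16() is true and uses only names that appear in the tests:

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers
    import paddle.static.amp as amp
    from paddle.fluid import core

    paddle.enable_static()

    # Declare the ops that may run in bf16 under bf16_guard.
    x = layers.data(name='x', shape=[3, 3], dtype='float32')
    y = layers.data(name='y', shape=[3, 3], dtype='float32')
    with amp.bf16_guard():
        out = layers.elementwise_add(x, y)

    # Rewrite the program so guarded ops on the bf16 list execute in bf16.
    prog = fluid.default_main_program()
    amp.rewrite_program_bf16(prog, use_bf16_guard=True)

    exe = fluid.Executor(core.CPUPlace())
    exe.run(fluid.default_startup_program())
    res, = exe.run(prog,
                   feed={'x': np.ones([3, 3], dtype='float32'),
                         'y': np.ones([3, 3], dtype='float32')},
                   fetch_list=[out])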