You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/contrib/tests/test_bf16_utils.py

145 lines
5.1 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision as amp
from paddle.fluid import core
import paddle
# Switch Paddle to static-graph mode at import time so every test in this
# module runs under it: the tests below build Program blocks directly
# (Block.create_var / Block.append_op) rather than executing eagerly.
paddle.enable_static()
class AMPTest(unittest.TestCase):
    """Check how custom op sets move ops between the default BF16 AMP lists.

    Each test edits private copies of the default ``bf16_list``,
    ``fp32_list`` and ``gray_list`` sets from ``amp.bf16.amp_lists`` to
    describe the expected outcome. It then stores the
    ``AutoMixedPrecisionListsBF16`` instance under test in
    ``self.amp_lists_``, and ``tearDown`` compares the two.

    In the case comments, ``w`` is the custom bf16 op set passed as the
    first argument and ``b`` is the ``custom_fp32_list`` argument.
    """

    def setUp(self):
        # Copy the module-level defaults so tests can edit their expectations
        # without mutating the shared sets in amp.bf16.amp_lists.
        self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
        self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
        self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
        # Assigned by every test; verified against the expectations above.
        self.amp_lists_ = None

    def tearDown(self):
        # Report a clear failure, instead of an AttributeError on None, when
        # a test errored before constructing the lists under test.
        self.assertIsNotNone(
            self.amp_lists_,
            "test did not construct an AutoMixedPrecisionListsBF16")
        self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list)
        self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list)
        self.assertEqual(self.amp_lists_.gray_list, self.gray_list)

    def test_amp_lists(self):
        # No customisation: the instance must match the defaults exactly.
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16()

    def test_amp_lists_1(self):
        # 1. w={'exp'}, b=None: 'exp' moves from the fp32 list to bf16.
        self.bf16_list.add('exp')
        self.fp32_list.remove('exp')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'})

    def test_amp_lists_2(self):
        # 2. w={'tanh'}, b=None: 'tanh' moves from the fp32 list to bf16.
        self.fp32_list.remove('tanh')
        self.bf16_list.add('tanh')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'})

    def test_amp_lists_3(self):
        # 3. w={'lstm'}, b=None: fp32/gray expectations are left untouched,
        # so only bf16 is expected to gain 'lstm'.
        self.bf16_list.add('lstm')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})

    def test_amp_lists_4(self):
        # 4. w=None, b={'elementwise_add'}: moves from bf16 to fp32.
        self.bf16_list.remove('elementwise_add')
        self.fp32_list.add('elementwise_add')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'elementwise_add'})

    def test_amp_lists_5(self):
        # 5. w=None, b={'elementwise_add'}
        # NOTE(review): same scenario as test_amp_lists_4 (only the order of
        # the expectation edits differs); consider covering another case.
        self.fp32_list.add('elementwise_add')
        self.bf16_list.remove('elementwise_add')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'elementwise_add'})

    def test_amp_lists_6(self):
        # 6. w=None, b={'lstm'}: bf16/gray expectations are left untouched,
        # so only fp32 is expected to gain 'lstm'.
        self.fp32_list.add('lstm')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'lstm'})

    def test_amp_lists_7(self):
        # b={'reshape2'}: 'reshape2' moves from the gray list to fp32.
        self.fp32_list.add('reshape2')
        self.gray_list.remove('reshape2')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'reshape2'})

    def test_amp_list_8(self):
        # custom_bf16_list={'reshape2'}: 'reshape2' moves from gray to bf16.
        self.bf16_list.add('reshape2')
        self.gray_list.remove('reshape2')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_bf16_list={'reshape2'})
class AMPTest2(unittest.TestCase):
    """Constructor validation and graph helpers in ``amp.bf16.amp_utils``.

    Every graph is built in a freshly constructed ``fluid.Program`` rather
    than ``fluid.default_main_program()``. This keeps tests independent of
    each other and of any other module sharing the process-global default
    program.
    """

    def test_amp_lists_(self):
        # 7. w={'lstm'} b={'lstm'}: requesting the same op as both bf16 and
        # fp32 must raise ValueError.
        self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
                          {'lstm'}, {'lstm'})

    @staticmethod
    def _build_abs_chain():
        """Build ``X -abs-> Y -abs-> Z`` in a fresh Program.

        Returns:
            tuple: ``(block, op1, op2)``. ``block`` is the new Program's
            global block, ``op1`` reads X and writes Y, and ``op2`` reads Y
            and writes Z.
        """
        block = fluid.Program().global_block()
        var_x = block.create_var(name="X", shape=[3], dtype='float32')
        var_y = block.create_var(name="Y", shape=[3], dtype='float32')
        var_z = block.create_var(name="Z", shape=[3], dtype='float32')
        op1 = block.append_op(
            type="abs", inputs={"X": [var_x]}, outputs={"Out": [var_y]})
        op2 = block.append_op(
            type="abs", inputs={"X": [var_y]}, outputs={"Out": [var_z]})
        return block, op1, op2

    def test_find_op_index(self):
        block = fluid.Program().global_block()
        # A freshly constructed OpDesc was never appended to the block, so
        # the lookup is expected to report "not found" (-1).
        op_desc = core.OpDesc()
        idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc)
        # unittest assertions, unlike bare `assert`, survive `python -O`.
        self.assertEqual(idx, -1)

    def test_is_in_fp32_varnames(self):
        _, op1, op2 = self._build_abs_chain()

        amp_lists_1 = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'X'})
        # op1 reads X, which is listed as an fp32 variable name.
        self.assertTrue(
            amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1))

        amp_lists_2 = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'Y'})
        # Y is op2's input and op1's output; both ops are expected to match.
        self.assertTrue(
            amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2))
        self.assertTrue(
            amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2))

    def test_find_true_post_op(self):
        block, op1, op2 = self._build_abs_chain()
        # op2 is the only op after op1 that consumes op1's output Y.
        res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
        self.assertEqual(res, [op2])
# Allow running this module directly (python test_bf16_utils.py) in addition
# to discovery by a test runner.
if __name__ == '__main__':
    unittest.main()