You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/dygraph/amp/auto_cast.py

167 lines
5.9 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager, wrap_decorator
from paddle.fluid import core
import contextlib
from paddle.fluid.framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, dygraph_only, set_flags, get_flags
import warnings
import copy
# Public API of this module: only the AMP context manager is exported by
# `from ... import *`.
__all__ = ["amp_guard"]
# Ops that support fp16 calculation and are considered numerically safe and
# performance-critical. These ops are always converted to fp16.
WHITE_LIST = {'conv2d', 'matmul', 'mul'}

# Ops that support fp16 calculation but are considered numerically dangerous,
# and whose effects may also be observed in downstream ops.
BLACK_LIST = {
    'exp', 'square', 'log',
    'mean', 'sum',
    'cos_sim', 'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'cross_entropy', 'cross_entropy2',
}
# Global flags relevant to AMP performance, and the values they would be set
# to. amp_guard does not currently apply them automatically; users are
# expected to set them manually.
AMP_RELATED_FLAGS = [
    'FLAGS_cudnn_exhaustive_search',
    'FLAGS_conv_workspace_size_limit',
    'FLAGS_cudnn_batchnorm_spatial_persistent',
]
AMP_RELATED_FLAGS_SETTING = dict(
    FLAGS_cudnn_exhaustive_search=1,
    FLAGS_conv_workspace_size_limit=1000,
    FLAGS_cudnn_batchnorm_spatial_persistent=1,
)
# NOTE(zhiqiu): similar to
# paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list.
# AutoMixedPrecisionLists is not reused because its custom_black_varnames is
# not suitable for imperative mode.
def _update_list(custom_white_list, custom_black_list):
    """
    Merge the user's custom op lists into copies of the default op lists.

    Args:
        custom_white_list(set|list|None): op names to move into the white list.
        custom_black_list(set|list|None): op names to move into the black list.

    Returns:
        tuple(set, set): the updated (white_list, black_list). The module-level
        WHITE_LIST and BLACK_LIST themselves are left untouched.

    Raises:
        ValueError: if some op name appears in both custom lists.
    """
    white_ops = copy.copy(WHITE_LIST)
    black_ops = copy.copy(BLACK_LIST)

    if custom_white_list and custom_black_list and any(
            name in custom_black_list for name in custom_white_list):
        raise ValueError("Custom white list overlap custom black list")

    # Promoting an op into one list evicts it from the other one.
    for name in custom_white_list or ():
        black_ops.discard(name)
        white_ops.add(name)
    for name in custom_black_list or ():
        white_ops.discard(name)
        black_ops.add(name)

    return white_ops, black_ops
@signature_safe_contextmanager
@dygraph_only
def amp_guard(enable=True, custom_white_list=None, custom_black_list=None):
    """
    :api_attr: imperative

    Create a context which enables auto-mixed-precision(AMP) of operators executed in imperative mode.
    If enabled, the input data type (float32 or float16) of each operator is decided
    by autocast algorithm for better performance.

    Commonly, it is used together with `AmpScaler` to achieve Auto-Mixed-Precision in
    imperative mode.

    If the current tracer's place is not a GPU place, a warning is issued and
    AMP is disabled for this context. On exit, the tracer's previous autocast
    switch and op lists are restored.

    Args:
        enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
        custom_white_list(set|list, optional): The custom white_list. These op
            names are added to the default white list (and removed from the
            default black list if present).
        custom_black_list(set|list, optional): The custom black_list. These op
            names are added to the default black list (and removed from the
            default white list if present).

    Raises:
        ValueError: if there is no current tracer, or if custom_white_list and
            custom_black_list share an op name.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        with fluid.dygraph.guard():
            conv2d = fluid.dygraph.Conv2D(3, 2, 3)
            data = fluid.dygraph.to_variable(data)
            with fluid.dygraph.amp_guard():
                conv = conv2d(data)
                print(conv.dtype) # FP16
            with fluid.dygraph.amp_guard(enable=False):
                conv = conv2d(data)
                print(conv.dtype) # FP32
    """
    tracer = _dygraph_tracer()
    if not tracer:
        raise ValueError(
            "current_tracer is None, maybe it is not in imperative mode.")

    if enable and not tracer._expected_place.is_gpu_place():
        warnings.warn(
            'amp_guard can only be enabled on CUDAPlace, current place is %s, so it makes no effect.'
            % tracer._expected_place)
        enable = False

    # Use the default white_list and black_list if no custom lists are provided.
    _white_list = WHITE_LIST
    _black_list = BLACK_LIST
    if custom_white_list or custom_black_list:
        _white_list, _black_list = _update_list(custom_white_list,
                                                custom_black_list)

    # Snapshot the tracer's AMP state *before* modifying it, and do every
    # modification inside the try block, so the state is restored on exit
    # even when the setup itself fails part-way.
    original_enable = tracer._enable_autocast
    original_white_list, original_black_list = tracer._get_amp_op_list()

    # TODO(zhiqiu): set AMP related flags (AMP_RELATED_FLAGS_SETTING)
    # automatically in this guard. Currently, if
    # FLAGS_cudnn_batchnorm_spatial_persistent is set inside amp_guard,
    # batch_norm can run in fast mode, but batch_norm_grad cannot if backward
    # is not executed inside amp_guard. So, users need to set the related
    # flags manually, e.g.:
    #     original_flags = get_flags(AMP_RELATED_FLAGS)
    #     set_flags(AMP_RELATED_FLAGS_SETTING)
    #     ...
    #     set_flags(original_flags)  # on exit
    try:
        tracer._enable_autocast = enable
        tracer._set_amp_op_list(_white_list, _black_list)
        yield
    finally:
        # Restore the tracer's previous AMP status.
        tracer._enable_autocast = original_enable
        tracer._set_amp_op_list(original_white_list, original_black_list)