Feature/Enable Auto-Mixed-Precision in dynamic graph (#24903)
* add auto_cast, test=develop
* add loss scaler, test=develop
* add comments, test=develop
* refine code, test=develop
* refine code, test=develop
* do not set flags automatically, test=develop
* fix custom op bug, test=develop
* add more test, test=develop
* refine enable logic, test=develop
* enable amp test with GPU, test=develop
* add unittest
* add test for found_inf
* follow comments
* follow comments
* remove global variable, use singleton
* add some notes
* update comments
* update comments
* update comments
* add use_dynamic_loss_scaling argument
* refine found_inf
* refine found_inf
parent 838e36e9ed
commit 2d95280e1f
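
Taken together, the new Python API is used as in the docstring examples further down. A minimal end-to-end sketch (it assumes a CUDA device, since both amp_guard and AmpScaler warn and fall back to a no-op on non-GPU places):

    import numpy as np
    import paddle.fluid as fluid

    data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
    with fluid.dygraph.guard():
        model = fluid.dygraph.Conv2D(3, 2, 3)
        optimizer = fluid.optimizer.SGDOptimizer(
            learning_rate=0.01, parameter_list=model.parameters())
        # Dynamic loss scaling guards against fp16 gradient underflow/overflow.
        scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
        data = fluid.dygraph.to_variable(data)
        with fluid.dygraph.amp_guard():
            conv = model(data)           # conv2d is white-listed, so it runs in fp16
            loss = fluid.layers.reduce_mean(conv)
            scaled = scaler.scale(loss)  # multiply the loss by the loss-scaling factor
            scaled.backward()            # gradients are scaled as well
            scaler.minimize(optimizer, scaled)  # unscale, skip the step on inf/nan, update the scale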
paddle/fluid/imperative/amp_auto_cast.cc (new file)
@@ -0,0 +1,169 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/amp_auto_cast.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/variable_wrapper.h"

namespace paddle {
namespace imperative {

AmpOperators::AmpOperators()
    : allow_ops_(new std::unordered_set<std::string>()),
      block_ops_(new std::unordered_set<std::string>()) {}

AmpOperators::~AmpOperators() {}

AmpOperators& AmpOperators::Instance() {
  static AmpOperators instance;
  return instance;
}

std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetAllowOps() {
  return allow_ops_;
}

std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetBlockOps() {
  return block_ops_;
}

inline std::string GetDtypeStr(
    const std::shared_ptr<imperative::VarBase>& var) {
  return framework::DataTypeToString(var->DataType());
}

// Only fp32/fp16 variables that live on a GPU place are candidates for casting.
inline bool NeedCast(const std::shared_ptr<VarBase>& var) {
  if (!platform::is_gpu_place(var->Place())) {
    return false;
  }
  if (var->DataType() == framework::proto::VarType::FP32 ||
      var->DataType() == framework::proto::VarType::FP16) {
    return true;
  } else {
    return false;
  }
}

// NOTE: Trace a cast op, so if a var is cast from fp32 to fp16, the grad
// var will be cast back from fp16 to fp32 during the backward phase.
static inline std::shared_ptr<imperative::VarBase> CastToType(
    const std::shared_ptr<VarBase>& var,
    const framework::proto::VarType::Type dst_type) {
  const auto& tracer = imperative::GetCurrentTracer();
  imperative::NameVarBaseMap ins = {{"X", {var}}};
  framework::AttributeMap attrs = {{"in_dtype", var->DataType()},
                                   {"out_dtype", dst_type}};
  auto out = std::shared_ptr<imperative::VarBase>(
      new imperative::VarBase(tracer->GenerateUniqueName()));
  imperative::NameVarBaseMap outs = {{"Out", {out}}};

  {
    // Disable autocast while tracing the cast op itself, so the cast op's
    // inputs are not auto-cast again.
    AutoCastGuard guard(tracer, false);
    tracer->TraceOp("cast", ins, outs, std::move(attrs));
  }

  return out;
}

static inline std::shared_ptr<imperative::VarBase> CastToFP16(
    const std::shared_ptr<VarBase>& var) {
  auto dst_type = framework::proto::VarType::FP16;
  if (NeedCast(var) && (var->DataType() != dst_type)) {
    return CastToType(var, dst_type);
  }
  return var;
}

static inline std::shared_ptr<imperative::VarBase> CastToFP32(
    const std::shared_ptr<VarBase>& var) {
  auto dst_type = framework::proto::VarType::FP32;
  if (NeedCast(var) && (var->DataType() != dst_type)) {
    return CastToType(var, dst_type);
  }
  return var;
}

// If any input is fp32, promote to fp32; otherwise stay in fp16.
static inline framework::proto::VarType::Type GetPromoteType(
    const NameVarBaseMap& ins) {
  auto dst_type = framework::proto::VarType::FP16;
  for (const auto& pair : ins) {
    for (const auto& var : pair.second) {
      if (var->DataType() == framework::proto::VarType::FP32) {
        dst_type = var->DataType();
        break;
      }
    }
  }
  return dst_type;
}

NameVarBaseMap AutoCastInputs(const std::string& op_type,
                              const NameVarBaseMap& ins) {
  NameVarBaseMap new_ins = {};
  if (AmpOperators::Instance().GetAllowOps()->count(op_type)) {
    for (const auto& pair : ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float16";
      for (const auto& var : pair.second) {
        auto new_var = CastToFP16(var);
        new_ins[pair.first].emplace_back(new_var);
      }
    }
    return new_ins;
  } else if (AmpOperators::Instance().GetBlockOps()->count(op_type)) {
    for (const auto& pair : ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float";
      for (const auto& var : pair.second) {
        auto new_var = CastToFP32(var);
        new_ins[pair.first].emplace_back(new_var);
      }
    }
    return new_ins;
  } else {
    auto dst_type = GetPromoteType(ins);

    for (const auto& pair : ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to "
              << framework::DataTypeToString(dst_type);
      for (const auto& var : pair.second) {
        // NOTE(zhiqiu): Conv + BN always occur together; we need not cast the
        // X input of batch_norm back to FP32 when it is produced by conv as FP16.
        if (op_type == "batch_norm" && pair.first == "X" &&
            dst_type == framework::proto::VarType::FP32) {
          new_ins[pair.first].emplace_back(var);
          continue;
        }
        auto new_var = dst_type == framework::proto::VarType::FP32
                           ? CastToFP32(var)
                           : CastToFP16(var);
        new_ins[pair.first].emplace_back(new_var);
      }
    }
    return new_ins;
  }
  return ins;
}

}  // namespace imperative
}  // namespace paddle
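
The dtype decision that AutoCastInputs implements can be summarized as: white-listed ops get fp16 inputs, black-listed ops get fp32 inputs, and every other op is promoted to fp32 as soon as any of its inputs is fp32. A standalone plain-Python sketch of that rule (illustrative only, not Paddle API; the op names below are a subset of the default lists plus a hypothetical unlisted op):

    # Subsets of the default white/black lists defined in auto_cast.py.
    WHITE = {'conv2d', 'matmul', 'mul'}            # inputs always cast to fp16
    BLACK = {'softmax', 'mean', 'cross_entropy'}   # inputs always cast back to fp32

    def decide_dtype(op_type, input_dtypes):
        if op_type in WHITE:
            return 'float16'
        if op_type in BLACK:
            return 'float32'
        # GetPromoteType: promote to fp32 if any input is fp32, else keep fp16.
        return 'float32' if 'float32' in input_dtypes else 'float16'

    print(decide_dtype('conv2d', ['float32']))                       # float16
    print(decide_dtype('softmax', ['float16']))                      # float32
    print(decide_dtype('elementwise_add', ['float16', 'float32']))   # float32 (promoted)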
paddle/fluid/imperative/amp_auto_cast.h (new file)
@@ -0,0 +1,79 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_set>

#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"

namespace paddle {
namespace imperative {

// Singleton implemented with a C++11 local static.
class AmpOperators {
 public:
  ~AmpOperators();
  AmpOperators(const AmpOperators& o) = delete;
  const AmpOperators& operator=(const AmpOperators& o) = delete;

  static AmpOperators& Instance();

  std::shared_ptr<std::unordered_set<std::string>> GetAllowOps();

  std::shared_ptr<std::unordered_set<std::string>> GetBlockOps();

 private:
  AmpOperators();  // forbid calling the default constructor

  // The set of ops that support fp16 calculation and are considered numerically
  // safe and performance critical. These ops are always converted to fp16.
  std::shared_ptr<std::unordered_set<std::string>> allow_ops_;

  // The set of ops that support fp16 calculation but are considered numerically
  // dangerous and whose effects may also be observed in downstream ops.
  std::shared_ptr<std::unordered_set<std::string>> block_ops_;
};

// NOTE(zhiqiu): AutoCastGuard is an RAII guard: it saves the tracer's autocast
// mode on construction and restores it on destruction.
class AutoCastGuard {
 public:
  AutoCastGuard(std::shared_ptr<Tracer> tracer, bool guard_mode)
      : tracer_(tracer) {
    pre_mode_ = tracer_->IsAutoCastEnabled();
    if (pre_mode_ != guard_mode) {
      tracer_->SetEnableAutoCast(guard_mode);
    }
  }

  ~AutoCastGuard() { tracer_->SetEnableAutoCast(pre_mode_); }

  // forbid copy and operator=
  AutoCastGuard(const AutoCastGuard& guard) = delete;
  AutoCastGuard& operator=(const AutoCastGuard& guard) = delete;

 private:
  std::shared_ptr<Tracer> tracer_;
  bool pre_mode_;
};

NameVarBaseMap AutoCastInputs(const std::string& op_type,
                              const NameVarBaseMap& ins);

}  // namespace imperative
}  // namespace paddle
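
AutoCastGuard is the C++ counterpart of the save/flip/restore pattern that amp_guard uses on the Python side. A minimal Python context-manager sketch of the same pattern (illustrative only; FakeTracer is a hypothetical stand-in for the imperative tracer):

    from contextlib import contextmanager

    class FakeTracer:               # hypothetical stand-in for the imperative tracer
        autocast_enabled = False

    @contextmanager
    def autocast_mode(tracer, mode):
        # Save the previous mode, switch only if needed, and always restore it --
        # the same save/flip/restore that AutoCastGuard performs via RAII.
        pre_mode = tracer.autocast_enabled
        if pre_mode != mode:
            tracer.autocast_enabled = mode
        try:
            yield
        finally:
            tracer.autocast_enabled = pre_mode

    tracer = FakeTracer()
    with autocast_mode(tracer, True):
        assert tracer.autocast_enabled
    assert not tracer.autocast_enabled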
__init__.py of the new dygraph AMP package (new file)
@@ -0,0 +1,23 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import auto_cast
from .auto_cast import *

from . import loss_scaler
from .loss_scaler import *

__all__ = []
__all__ += auto_cast.__all__
__all__ += loss_scaler.__all__
auto_cast.py (new file)
@@ -0,0 +1,166 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager, wrap_decorator
from paddle.fluid import core
import contextlib
from paddle.fluid.framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, dygraph_only, set_flags, get_flags
import warnings
import copy

__all__ = ['amp_guard']

# The set of ops that support fp16 calculation and are considered numerically
# safe and performance critical. These ops are always converted to fp16.
WHITE_LIST = {
    'conv2d',
    'matmul',
    'mul',
}

# The set of ops that support fp16 calculation but are considered numerically
# dangerous and whose effects may also be observed in downstream ops.
BLACK_LIST = {
    'exp',
    'square',
    'log',
    'mean',
    'sum',
    'cos_sim',
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'cross_entropy',
    'cross_entropy2',
}

AMP_RELATED_FLAGS = [
    'FLAGS_cudnn_exhaustive_search',
    'FLAGS_conv_workspace_size_limit',
    'FLAGS_cudnn_batchnorm_spatial_persistent',
]

AMP_RELATED_FLAGS_SETTING = {
    'FLAGS_cudnn_exhaustive_search': 1,
    'FLAGS_conv_workspace_size_limit': 1000,
    'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}


# NOTE(zhiqiu): similar to paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list.
# AutoMixedPrecisionLists is not reused here because its custom_black_varnames is not suitable for imperative mode.
def _update_list(custom_white_list, custom_black_list):
    """
    Update the black and white lists according to the user's custom lists.
    """
    _white_list = copy.copy(WHITE_LIST)
    _black_list = copy.copy(BLACK_LIST)
    if custom_white_list and custom_black_list:
        for op_name in custom_white_list:
            if op_name in custom_black_list:
                raise ValueError("Custom white list overlaps "
                                 "with custom black list")
    if custom_white_list:
        for op_name in custom_white_list:
            if op_name in _black_list:
                _black_list.remove(op_name)
            _white_list.add(op_name)
    if custom_black_list:
        for op_name in custom_black_list:
            if op_name in _white_list:
                _white_list.remove(op_name)
            _black_list.add(op_name)
    return _white_list, _black_list


@signature_safe_contextmanager
@dygraph_only
def amp_guard(enable=True, custom_white_list=None, custom_black_list=None):
    """
    :api_attr: imperative

    Create a context which enables auto-mixed-precision (AMP) for operators executed in imperative mode.
    If enabled, the input data type (float32 or float16) of each operator is decided
    by the autocast algorithm for better performance.

    Commonly, it is used together with `AmpScaler` to achieve Auto-Mixed-Precision in
    imperative mode.

    Args:
        enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
        custom_white_list(set|list, optional): The custom white_list.
        custom_black_list(set|list, optional): The custom black_list.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        with fluid.dygraph.guard():
            conv2d = fluid.dygraph.Conv2D(3, 2, 3)
            data = fluid.dygraph.to_variable(data)
            with fluid.dygraph.amp_guard():
                conv = conv2d(data)
                print(conv.dtype) # FP16
            with fluid.dygraph.amp_guard(enable=False):
                conv = conv2d(data)
                print(conv.dtype) # FP32

    """
    tracer = _dygraph_tracer()
    if not tracer:
        raise ValueError(
            "current_tracer is None, maybe it is not in imperative mode.")

    if enable and not tracer._expected_place.is_gpu_place():
        warnings.warn(
            'amp_guard can only be enabled on CUDAPlace, current place is %s, so it has no effect.'
            % tracer._expected_place)
        enable = False

    # use the default white_list and black_list if no custom lists are provided
    _white_list = WHITE_LIST
    _black_list = BLACK_LIST
    if custom_white_list or custom_black_list:
        _white_list, _black_list = _update_list(custom_white_list,
                                                custom_black_list)

    if tracer:
        # enable auto_cast
        original_enable = tracer._enable_autocast
        tracer._enable_autocast = enable
        # set the amp op list
        original_white_list, original_black_list = tracer._get_amp_op_list()
        tracer._set_amp_op_list(_white_list, _black_list)

    # TODO(zhiqiu): set the amp related flags automatically in this guard.
    # Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
    # batch_norm can run in fast mode, but batch_norm_grad cannot if the backward pass
    # is not executed inside amp_guard. So, users need to set the related flags manually.

    # original_flags = get_flags(AMP_RELATED_FLAGS)
    # set_flags(AMP_RELATED_FLAGS_SETTING)

    # restore the previous status on exit
    try:
        yield
    finally:
        if tracer:
            tracer._enable_autocast = original_enable
            tracer._set_amp_op_list(original_white_list, original_black_list)
            # set_flags(original_flags)
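
A sketch of how the custom lists interact with the defaults through _update_list (assumes a CUDA device; moving an op to the custom black list forces its inputs back to fp32 inside amp_guard):

    import numpy as np
    import paddle.fluid as fluid

    data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
    with fluid.dygraph.guard():
        conv2d = fluid.dygraph.Conv2D(3, 2, 3)
        data = fluid.dygraph.to_variable(data)
        # 'conv2d' is removed from the default white list and added to the
        # black list, so it should now run in fp32 even inside amp_guard.
        with fluid.dygraph.amp_guard(custom_black_list={'conv2d'}):
            conv = conv2d(data)
            print(conv.dtype)  # expected FP32 rather than FP16

Note that passing the same op name in both custom_white_list and custom_black_list raises a ValueError, as enforced by _update_list above.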
loss_scaler.py (new file)
@@ -0,0 +1,246 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import _varbase_creator, _dygraph_tracer, dygraph_only
from paddle.fluid.data_feeder import check_type
from ...wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import warnings
import numpy as np

__all__ = ['AmpScaler']


class AmpScaler(object):
    """
    :api_attr: imperative

    AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative
    mode. It controls the scaling of the loss, which helps avoid numerical overflow.
    An object of this class has two methods: `scale()` and `minimize()`.

    `scale()` is used to multiply the loss by a scale ratio.
    `minimize()` is similar to `Optimizer.minimize()`; it performs the parameter update.

    Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
    imperative mode.

    Args:
        enable(bool, optional): Enable loss scaling or not. Default is True.
        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
        incr_ratio(float, optional): The multiplier to use when increasing the loss
                        scaling. Default is 2.0.
        decr_ratio(float, optional): The less-than-one multiplier to use when decreasing
                        the loss scaling. Default is 0.5.
        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
                                steps with finite gradients. Default is 1000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
                                    accumulated steps with nan or inf gradients. Default is 1.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, a fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
    Returns:
        An AmpScaler object.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        with fluid.dygraph.guard():
            model = fluid.dygraph.Conv2D(3, 2, 3)
            optimizer = fluid.optimizer.SGDOptimizer(
                learning_rate=0.01, parameter_list=model.parameters())
            scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
            data = fluid.dygraph.to_variable(data)
            with fluid.dygraph.amp_guard():
                conv = model(data)
                loss = fluid.layers.reduce_mean(conv)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
    """

    @dygraph_only
    def __init__(self,
                 enable=True,
                 init_loss_scaling=2.**15,
                 incr_ratio=2.0,
                 decr_ratio=0.5,
                 incr_every_n_steps=1000,
                 decr_every_n_nan_or_inf=1,
                 use_dynamic_loss_scaling=True):

        tracer = _dygraph_tracer()
        if not tracer:
            raise ValueError(
                "current_tracer is None, maybe it is not in imperative mode.")

        if enable and not tracer._expected_place.is_gpu_place():
            warnings.warn(
                'AmpScaler can only be enabled on CUDAPlace, current place is %s, so it has no effect.'
                % tracer._expected_place)
            enable = False

        self._enable = enable

        if self._enable:
            assert incr_ratio > 1.0, "The incr_ratio must be > 1.0."
            assert decr_ratio < 1.0, "The decr_ratio must be < 1.0."

            self._init_loss_scaling = init_loss_scaling
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._incr_every_n_steps = incr_every_n_steps
            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
            self._incr_count = 0
            self._decr_count = 0
            self._use_dynamic_loss_scaling = use_dynamic_loss_scaling

            self._found_inf = to_variable(np.array([0]).astype(np.bool))
            self._scale = to_variable(
                np.array([self._init_loss_scaling]).astype(np.float32))
            self._cache_founf_inf = None

    def scale(self, var):
        """
        Multiplies a variable (Tensor) by the scale factor and returns the scaled output.
        If this instance of :class:`AmpScaler` is not enabled, the output is returned unmodified.

        Args:
            var (Variable): The variable to scale.
        Returns:
            The scaled variable or the original variable.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                with fluid.dygraph.guard():
                    model = fluid.dygraph.Conv2D(3, 2, 3)
                    optimizer = fluid.optimizer.SGDOptimizer(
                        learning_rate=0.01, parameter_list=model.parameters())
                    scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
                    data = fluid.dygraph.to_variable(data)
                    with fluid.dygraph.amp_guard():
                        conv = model(data)
                        loss = fluid.layers.reduce_mean(conv)
                        scaled = scaler.scale(loss)
                        scaled.backward()
                        scaler.minimize(optimizer, scaled)
        """
        check_type(var, "var", core.VarBase, 'AmpScaler.scale()')

        if not self._enable:
            return var

        return var * self._scale

    def minimize(self, optimizer, *args, **kwargs):
        """
        This function is similar to `Optimizer.minimize()`; it performs the parameter update.

        If the scaled gradients of the parameters contain NAN or INF, the parameter update is skipped.
        Otherwise, it first unscales the scaled gradients of the parameters and then updates the parameters.

        Finally, the loss scaling ratio is updated.

        Args:
            optimizer(Optimizer): The optimizer used to update the parameters.
            args: Arguments, which will be forwarded to `optimizer.minimize()`.
            kwargs: Keyword arguments, which will be forwarded to `Optimizer.minimize()`.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                with fluid.dygraph.guard():
                    model = fluid.dygraph.Conv2D(3, 2, 3)
                    optimizer = fluid.optimizer.SGDOptimizer(
                        learning_rate=0.01, parameter_list=model.parameters())
                    scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
                    data = fluid.dygraph.to_variable(data)
                    with fluid.dygraph.amp_guard():
                        conv = model(data)
                        loss = fluid.layers.reduce_mean(conv)
                        scaled = scaler.scale(loss)
                        scaled.backward()
                        scaler.minimize(optimizer, scaled)
        """
        if not self._enable:
            return optimizer.minimize(*args, **kwargs)

        # unscale the grad
        self._unscale(optimizer)

        optimize_ops, params_grads = (None, None)

        if self._found_inf:
            self._cache_founf_inf = True
        else:
            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
            self._cache_founf_inf = False

        if self._use_dynamic_loss_scaling:
            # update the scale
            self._update()

        return optimize_ops, params_grads

    def _unscale(self, optimizer):
        if not self._enable:
            return
        inv_scale = 1.0 / self._scale
        param_grads = [
            param._grad_ivar() for param in optimizer._parameter_list
            if param._grad_ivar() is not None
        ]
        # Multiplies each gradient by 1/scale in place and sets self._found_inf
        # if any gradient contains inf or nan.
        core.ops.amp_check_finite_and_scale(param_grads, inv_scale, param_grads,
                                            self._found_inf)

    def _update(self):
        """
        Updates the loss_scaling.
        """
        if not self._enable:
            return

        if self._cache_founf_inf:
            self._incr_count = 0
            self._decr_count = self._decr_count + 1
            if self._decr_count == self._decr_every_n_nan_or_inf:
                print(
                    'Found inf or nan, current scale is: {}, decrease to: {}*{}'.
                    format(
                        float(self._scale),
                        float(self._scale), float(self._decr_ratio)))
                self._scale = self._scale * self._decr_ratio
                self._decr_count = 0
        else:
            self._decr_count = 0
            self._incr_count = self._incr_count + 1
            if self._incr_count == self._incr_every_n_steps:
                self._scale = self._scale * self._incr_ratio
                self._incr_count = 0

        return
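
The _update method implements a standard dynamic loss-scaling schedule. A standalone plain-Python sketch of that schedule using the constructor defaults (initial scale 2**15, doubled after 1000 consecutive finite steps, halved after every step with inf/nan gradients):

    scale, incr_count, decr_count = 2.0**15, 0, 0

    def step(found_inf):
        global scale, incr_count, decr_count
        if found_inf:
            incr_count, decr_count = 0, decr_count + 1
            if decr_count == 1:                # decr_every_n_nan_or_inf default
                scale, decr_count = scale * 0.5, 0
        else:
            decr_count, incr_count = 0, incr_count + 1
            if incr_count == 1000:             # incr_every_n_steps default
                scale, incr_count = scale * 2.0, 0

    for _ in range(1000):
        step(found_inf=False)
    print(scale)        # 65536.0 -- doubled after 1000 consecutive finite steps
    step(found_inf=True)
    print(scale)        # 32768.0 -- halved immediately on an inf/nan step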
File diff suppressed because it is too large