Support pure fp16 training for AMP API. (#29544)

* Add cast ops before and after unsupported fp16 ops.

* Keep part of the net in the FP32 pattern.

* Support check_finite_and_unscale and update_loss_scaling for the FP16 calculation mode.

* Add fp16 support for the adam op.

* Add the multi_precision attr for adam.

* Fix the bug in the test_multi_precision_fp16_train UT.

* Fix the code format for CI.

* Fix the redefinition error of MPTypeTrait on Windows.

* Fix bugs in the _create_accumulators func of Momentum.

* Fix a bug when inserting the post cast op.

* Add the update_loss_scaling op to the allow_set of UnusedVarCheck.

* Update for CI coverage.

* Add some documentation for OptimizerWithMixedPrecision.

* Fix the code style.

* Improve the doc of `amp_init`.

* Change FP16 testing for the case where users define the inference program separately.
Zhen Wang committed via GitHub
parent 789743e190
commit 7f7dfccf20
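
Taken together, these changes enable the pure-FP16 workflow below. This is a minimal sketch distilled from the updated test_multi_precision_fp16_train UT later in this diff; build_model is a hypothetical placeholder for the user network (the UT wraps the FP16 part of its ResNet in paddle.static.amp.fp16_guard()).

import paddle

paddle.enable_static()

train_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
    loss = build_model()  # hypothetical helper; see the UT's resnet_cifar10
    test_program = train_program.clone(for_test=True)

    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.001, momentum=0.9, multi_precision=True)
    optimizer = paddle.static.amp.decorate(
        optimizer,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True)
    optimizer.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_program)
# amp_init casts FP16 parameters (and optionally a test program) after startup.
optimizer.amp_init(place, test_program=test_program, use_fp16_test=True)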

@ -73,7 +73,8 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
"fused_batch_norm_act", // 2
"fused_batch_norm_act_grad", // 2
"data_norm", // 0
"data_norm_grad", // 0);
"data_norm_grad", // 0
"update_loss_scaling", // 0
});
return *allow_set;
}

@ -15,6 +15,8 @@ limitations under the License. */
#include <cuda.h>
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
@ -25,15 +27,16 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
*found_inf = false;
}
template <typename T>
__global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
template <typename T, typename MT>
__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num,
bool* found_inf, T* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
T val = in[idx] * (*scale);
out[idx] = val;
if (!isfinite(val)) {
MT val = static_cast<MT>(in[idx]) * (*scale);
T narrow_val = static_cast<T>(val);
out[idx] = narrow_val;
if (!isfinite(narrow_val)) {
*found_inf = true;
}
}
@ -41,6 +44,8 @@ __global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
template <typename T>
class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
@ -49,14 +54,15 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
const T* scale_data = scale->data<T>();
const MPDType* scale_data = scale->data<MPDType>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
framework::Tensor inverse_scale =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({1}, dev_ctx);
T* inverse_scale_v = inverse_scale.template data<T>();
ctx.AllocateTmpTensor<MPDType, platform::CUDADeviceContext>({1},
dev_ctx);
MPDType* inverse_scale_v = inverse_scale.template data<MPDType>();
InverseAndMemset<T><<<1, 1, 0, dev_ctx.stream()>>>(
InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
scale_data, inverse_scale_v, found_inf_data);
for (size_t i = 0; i < xs.size(); ++i) {
@ -69,7 +75,7 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
int block = 1024;
int grid = (num + block - 1) / block;
VLOG(3) << "launch kernel";
CheckFiniteAndUnscale<T><<<grid, block, 0, dev_ctx.stream()>>>(
CheckFiniteAndUnscale<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, inverse_scale_v, num, found_inf_data, out_data);
VLOG(3) << "finish kernel";
}
@ -79,6 +85,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleGpuKernel<float>,
ops::CheckFiniteAndUnscaleGpuKernel<double>);
ops::CheckFiniteAndUnscaleGpuKernel<double>,
ops::CheckFiniteAndUnscaleGpuKernel<plat::float16>);
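
For orientation, the behavior of the updated kernel can be sketched in NumPy: values stay in T (here float16), the scale and the multiply use the wider MPDType (float32), and the finiteness check runs on the narrowed result. This is an illustrative reference, not the CUDA implementation.

import numpy as np

def check_finite_and_unscale_ref(xs, scale):
    """xs: list of float16 arrays; scale: scalar float32 loss scale."""
    inverse_scale = np.float32(1.0) / np.float32(scale)
    found_inf = False
    outs = []
    for x in xs:
        val = x.astype(np.float32) * inverse_scale  # multiply in the MPDType
        out = val.astype(np.float16)                # narrow back to T
        outs.append(out)
        if not np.isfinite(out).all():              # check the narrowed value
            found_inf = True
    return outs, found_inf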

@ -0,0 +1,37 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
namespace details {
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<platform::float16> {
public:
using Type = float;
};
} // namespace details
} // namespace operators
} // namespace paddle

@ -54,8 +54,7 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "PrevLossScaling"),
ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
@ -107,6 +106,9 @@ class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker {
"the received is %f",
decr_ratio));
});
AddAttr<bool>("stop_update",
"Stop updating loss scaling, and just zero inputs.")
.SetDefault(false);
AddComment(R"DOC(
Update loss scaling according to overall gradients. If all gradients are
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
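
The dynamic loss-scaling policy this op describes (and which the new stop_update attribute lets callers skip while still zeroing the inputs) can be restated in plain Python roughly as follows. This is a hedged sketch: the real arithmetic runs in UpdateLossScalingFunctor, and details such as the lower bound on the shrunken scale are omitted.

def update_loss_scaling_ref(found_inf, loss_scaling, good_steps, bad_steps,
                            incr_every_n_steps, decr_every_n_nan_or_inf,
                            incr_ratio, decr_ratio):
    if found_inf:
        good_steps = 0
        bad_steps += 1
        if bad_steps == decr_every_n_nan_or_inf:
            loss_scaling *= decr_ratio  # too many overflowing steps: shrink
            bad_steps = 0
    else:
        bad_steps = 0
        good_steps += 1
        if good_steps == incr_every_n_steps:
            loss_scaling *= incr_ratio  # enough consecutive finite steps: grow
            good_steps = 0
    return loss_scaling, good_steps, bad_steps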

@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
@ -83,8 +84,10 @@ class LazyZeros<platform::CUDADeviceContext, T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
using GPU = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(update_loss_scaling,
ops::UpdateLossScalingKernel<GPU, float>,
ops::UpdateLossScalingKernel<GPU, double>);
ops::UpdateLossScalingKernel<GPU, double>,
ops::UpdateLossScalingKernel<GPU, plat::float16>);

@ -17,6 +17,7 @@
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
@ -79,30 +80,38 @@ class LazyZeros {
template <typename DeviceContext, typename T>
class UpdateLossScalingKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must have only one element."));
const bool* found_inf_data = found_inf->data<bool>();
LazyZeros<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
const bool stop_update = ctx.Attr<bool>("stop_update");
if (stop_update) {
return;
}
const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
auto* bad_out = ctx.Output<Tensor>("OutBadSteps");
PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must have only one element."));
const bool* found_inf_data = found_inf->data<bool>();
const T* pre_loss_scaling_data = pre_loss_scaling->data<T>();
const MPDType* pre_loss_scaling_data = pre_loss_scaling->data<MPDType>();
const int* good_in_data = good_in->data<int>();
const int* bad_in_data = bad_in->data<int>();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
T* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<T>(dev_ctx.GetPlace());
MPDType* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<MPDType>(dev_ctx.GetPlace());
int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace());
int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace());
@ -111,11 +120,10 @@ class UpdateLossScalingKernel : public framework::OpKernel<T> {
ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio");
UpdateLossScalingFunctor<DeviceContext, T>{}(
UpdateLossScalingFunctor<DeviceContext, MPDType>{}(
dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data,
bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data);
LazyZeros<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
}
};

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
@ -150,12 +151,17 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"as beta2, this has a higher priority than attr(beta2), the "
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment");
AddOutput("Moment2Out", "(Tensor) Output second moment");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shares memory with Input(MasterParam).")
.AsDispensable();
AddAttr<float>("beta1",
"(float, default 0.9) "
@ -183,6 +189,10 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"inner_op_parallelism is larger then 0, sparse update "
"will run in multithread mode")
.SetDefault(1000);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddComment(R"DOC(
Adam Optimizer.
@ -213,3 +223,13 @@ REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
REGISTER_OP_CPU_KERNEL(
adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(adam)
.AddCheckpoint(
R"ROC(
Upgrade adam add 1 attribute [multi_precision].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"multi_precision",
"(bool) Whether to use multi-precision during weight updating.",
false));

File diff suppressed because it is too large.

@ -191,26 +191,28 @@ class AdamFunctor<T, CPUAdam> {
}
};
template <typename T, typename Flavour>
template <typename T, typename Flavour, typename MT = T>
class SparseAdamFunctor;
template <typename T>
class SparseAdamFunctor<T, GPUAdam> {
template <typename T, typename MT>
class SparseAdamFunctor<T, GPUAdam, MT> {
private:
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
MT beta1_;
MT beta2_;
MT epsilon_;
const MT* beta1_pow_;
const MT* beta2_pow_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
MT* moment2_out_;
const MT* lr_;
const T* grad_;
const T* param_;
T* param_out_;
const MT* master_param_;
MT* master_param_out_;
const int64_t* rows_;
int64_t row_numel_;
@ -218,10 +220,11 @@ class SparseAdamFunctor<T, GPUAdam> {
bool lazy_mode_;
public:
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows,
SparseAdamFunctor(MT beta1, MT beta2, MT epsilon, const MT* beta1_pow,
const MT* beta2_pow, const MT* mom1, MT* mom1_out,
const MT* mom2, MT* mom2_out, const MT* lr, const T* grad,
const T* param, T* param_out, const MT* master_param,
MT* master_param_out, const int64_t* rows,
int64_t row_numel, int64_t row_count, bool lazy_mode)
: beta1_(beta1),
beta2_(beta2),
@ -236,31 +239,38 @@ class SparseAdamFunctor<T, GPUAdam> {
grad_(grad),
param_(param),
param_out_(param_out),
master_param_(master_param),
master_param_out_(master_param_out),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count),
lazy_mode_(lazy_mode) {}
inline HOSTDEVICE void adam_update(size_t i, T g) const {
inline HOSTDEVICE void adam_update(size_t i, MT g) const {
// The following code is the same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
MT mom1 = moment1_[i];
MT mom2 = moment2_[i];
MT lr = *lr_;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
// Calculation
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));
mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));
// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = p;
param_out_[i] = static_cast<T>(p);
if (master_param_out_) {
master_param_out_[i] = p;
}
}
inline HOSTDEVICE void operator()(size_t i) const {
@ -269,14 +279,16 @@ class SparseAdamFunctor<T, GPUAdam> {
if (lazy_mode_ && row_idx < 0) {
return;
} else {
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
MT g = row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_])
: static_cast<MT>(0);
adam_update(i, g);
}
}
};
template <typename T>
class SparseAdamFunctor<T, CPUAdam> {
class SparseAdamFunctor<T, CPUAdam, T> {
private:
T beta1_;
T beta2_;
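
Read off adam_update above: every intermediate quantity is now held in the multi-precision type MT (float when T is float16), the parameter is read from master_param_ when it is present, and only the final parameter write is narrowed back to T. In formulas, with g the gradient cast to MT, eta the learning rate, and beta1_pow_/beta2_pow_ written as beta_1^t and beta_2^t:

\[
\hat{\eta} = \eta \,\frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}, \qquad
m_1 \leftarrow \beta_1 m_1 + (1 - \beta_1)\, g, \qquad
m_2 \leftarrow \beta_2 m_2 + (1 - \beta_2)\, g^2,
\]
\[
p \leftarrow p - \hat{\eta}\,\frac{m_1}{\sqrt{m_2} + \epsilon \sqrt{1 - \beta_2^t}}.
\]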

@ -115,7 +115,8 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_VERSION(momentum)
.AddCheckpoint(
R"ROC(
Upgrade momentum add 2 attributes [regularization_method, regularization_coeff].
Upgrade momentum add 4 attributes [regularization_method, regularization_coeff,
multi_precision, rescale_grad].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("MasterParam", "FP32 master weight for AMP.")

@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/float16.h"
@ -32,17 +33,6 @@ struct UseNesterov;
namespace details {
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<platform::float16> {
public:
using Type = float;
};
template <typename T>
struct CPUDenseUpdater {
template <typename G>

@ -15,6 +15,7 @@
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable
from paddle.fluid import core
__all__ = ['check_finite_and_unscale', 'update_loss_scaling']
@ -35,7 +36,7 @@ def check_finite_and_unscale(x, scale, name=None):
"""
check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
for e in x:
check_variable_and_dtype(e, "x", ['float32', 'float64'],
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
helper = LayerHelper("check_finite_and_unscale", **locals())
@ -58,6 +59,7 @@ def update_loss_scaling(x,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
stop_update=False,
name=None):
"""
Update loss scaling according to overall gradients. If all gradients are
@ -90,9 +92,13 @@ def update_loss_scaling(x,
['float32', 'float64'], "update_loss_scaling")
check_type(x, 'x', (tuple, list), 'update_loss_scaling')
for e in x:
check_variable_and_dtype(e, "x", ['float32', 'float64'],
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
if e.dtype == core.VarDesc.VarType.FP16:
assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
helper = LayerHelper("update_loss_scaling", **locals())
@ -116,6 +122,7 @@ def update_loss_scaling(x,
'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf,
'incr_ratio': incr_ratio,
'decr_ratio': decr_ratio,
'stop_update': stop_update
}
helper.append_op(

File diff suppressed because it is too large.

@ -38,6 +38,7 @@ class AutoMixedPrecisionLists(object):
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_fp16_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
@ -64,6 +65,7 @@ class AutoMixedPrecisionLists(object):
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.black_list.add(op_name)
self.unsupported_list.add(op_name)
# The three sets listed below are changed dynamically. They don't contain all
@ -141,10 +143,10 @@ gray_list = {
'cast',
'fused_bn_add_activation',
}
'''
# The set of ops that don't support fp16 calculation
unsupported_fp16_list = {
# from python/paddle/fluid/layers/io.py
# from python/paddle/fluid/layers/io.py
'send',
'send_barrier',
'recv',
@ -153,8 +155,8 @@ unsupported_fp16_list = {
'create_double_buffer_reader',
'read',
'load',
# from python/paddle/fluid/control_flow.py
# from python/paddle/fluid/control_flow.py
'increment',
'less_than',
'less_equal',
@ -174,7 +176,6 @@ unsupported_fp16_list = {
'while',
'ifelse',
'is_empty',
'lstm',
'cudnn_lstm',
'lstmp',
@ -275,7 +276,6 @@ unsupported_fp16_list = {
'pixel_shuffle',
'fsp',
'cvm',
'affine_channel',
'roi_pool',
'roi_align',
@ -283,6 +283,4 @@ unsupported_fp16_list = {
'generate_proposals',
'generate_proposal_labels',
'generate_mask_labels',
}
'''
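
With the triple quotes removed, unsupported_fp16_list is active, and _update_list now also marks custom black-listed ops as unsupported so that cast ops are inserted around them. A hedged illustration follows; the constructor argument custom_black_list and the decorate keyword amp_lists are assumed from the surrounding API rather than shown in this hunk.

from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists

# Force elementwise_add (normally a gray-list op) to stay in FP32.
amp_lists = AutoMixedPrecisionLists(custom_black_list={'elementwise_add'})
assert 'elementwise_add' in amp_lists.black_list
assert 'elementwise_add' in amp_lists.unsupported_list
# Such a list is typically handed to paddle.static.amp.decorate(..., amp_lists=amp_lists).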

File diff suppressed because it is too large.

@ -19,8 +19,7 @@ import paddle.fluid as fluid
import contextlib
import unittest
import numpy as np
from paddle.static.amp import cast_model_to_fp16
from paddle.static.amp import cast_parameters_to_fp16
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
paddle.enable_static()
@ -65,38 +64,19 @@ def resnet_cifar10(input, depth=32):
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
with paddle.static.amp.fp16_guard():
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
return pool
def compile(program, loss_name=None):
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000
build_strategy.fuse_bn_act_ops = True
build_strategy.fuse_elewise_add_act_ops = True
build_strategy.fuse_bn_add_act_ops = True
compiled_program = paddle.static.CompiledProgram(
program).with_data_parallel(
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compiled_program
def train(use_pure_fp16=True, use_nesterov=False):
def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
classdim = 10
data_shape = [3, 32, 32]
BATCH_SIZE = 128
BATCH_SIZE = 32
PASS_NUM = 1
train_program = fluid.Program()
@ -107,28 +87,35 @@ def train(use_pure_fp16=True, use_nesterov=False):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net = resnet_cifar10(images, 32)
net = resnet_cifar10(images)
logits = fluid.layers.fc(input=net, size=classdim, act="softmax")
if use_pure_fp16:
cast_model_to_fp16(fluid.default_main_program())
logits_fp32 = fluid.layers.cast(x=logits, dtype="float32")
else:
logits_fp32 = logits
cost = fluid.layers.softmax_with_cross_entropy(
logits_fp32, label, return_softmax=False)
logits, label, return_softmax=False)
sum_cost = fluid.layers.reduce_sum(cost)
# Test program
test_program = train_program.clone(for_test=True)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
momentum=0.9,
use_nesterov=use_nesterov,
weight_decay=fluid.regularizer.L2Decay(1e-4),
multi_precision=use_pure_fp16,
rescale_grad=1.0 / BATCH_SIZE)
if use_adam:
optimizer = paddle.optimizer.Adam(
learning_rate=0.001,
epsilon=1e-8,
weight_decay=0.0,
multi_precision=True)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
momentum=0.9,
use_nesterov=use_nesterov,
weight_decay=fluid.regularizer.L2Decay(1e-4),
multi_precision=use_pure_fp16)
if use_pure_fp16:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True)
optimizer.minimize(sum_cost)
@ -146,13 +133,13 @@ def train(use_pure_fp16=True, use_nesterov=False):
def train_loop(main_program):
exe.run(startup_prog)
if use_pure_fp16:
cast_parameters_to_fp16(place, train_program, fluid.global_scope())
compiled_program = compile(train_program, sum_cost.name)
optimizer.amp_init(
place, test_program=test_program, use_fp16_test=True)
loss = 0.0
for pass_id in range(PASS_NUM):
train_loss_list = []
for batch_id, data in enumerate(train_reader()):
loss, = exe.run(compiled_program,
loss, = exe.run(train_program,
feed=feeder.feed(data),
fetch_list=[sum_cost])
loss_v = loss[0] if isinstance(loss, np.ndarray) else loss
@ -182,18 +169,25 @@ class TestImageMultiPrecision(unittest.TestCase):
if not fluid.core.is_compiled_with_cuda():
return
def do_test(use_nesterov=False):
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
def do_test(use_nesterov=False, use_adam=False):
if use_adam:
suffix = "use Adam"
else:
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
with self.scope_prog_guard():
print("-----------------FP16 Train {}-----------------".format(
suffix))
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True, use_nesterov=use_nesterov)
use_pure_fp16=True,
use_nesterov=use_nesterov,
use_adam=use_adam)
with self.scope_prog_guard():
print("-----------------FP32 Train {}-----------------".format(
suffix))
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False, use_nesterov=use_nesterov)
use_pure_fp16=False,
use_nesterov=use_nesterov,
use_adam=use_adam)
self.assertTrue(
np.allclose(
@ -214,6 +208,7 @@ class TestImageMultiPrecision(unittest.TestCase):
do_test(use_nesterov=False)
do_test(use_nesterov=True)
do_test(use_adam=True)
@contextlib.contextmanager
def scope_prog_guard(self):
@ -260,7 +255,7 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
op._set_attr('out_dtype', fluid.core.VarDesc.VarType.FP32)
op._set_attr('dtype', fluid.core.VarDesc.VarType.FP32)
cast_model_to_fp16(main_prog)
cast_model_to_fp16(main_prog, use_fp16_guard=False)
def test_non_iterable_dataloader(self):
self.decorate_with_data_loader()

@ -16,6 +16,10 @@ from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable
from ..fluid import layers
from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
import warnings
from ..fluid.dygraph import base as imperative_base
import paddle
@ -79,6 +83,7 @@ class Adam(Optimizer):
gradient in current mini-batch, so it will be much more faster. But this mode has
different semantics with the original Adam algorithm and may lead to different result.
The default value is False.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
@ -135,6 +140,7 @@ class Adam(Optimizer):
weight_decay=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None):
assert learning_rate is not None
assert beta1 is not None
@ -157,28 +163,90 @@ class Adam(Optimizer):
self._beta2 = beta2
self._epsilon = epsilon
self._lazy_mode = lazy_mode
self._multi_precision = multi_precision
self._master_weights = {}
def _create_master_weight(self, param):
assert isinstance(self.helper, LayerHelper)
var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var = layers.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True)
block = self.helper.startup_program.global_block()
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32
})
self._master_weights[param.name] = var
return var
def _get_accumulator(self, name, param):
"""Utility function to fetch an accumulator for a parameter
Args:
name: name of the accumulator
param: parameter variable for which accumulator is to be fetched
Returns:
accumulator variable for the parameter
"""
if self._name is not None:
name = self._name + "_" + name
find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
target_param = self._master_weights[
param.name] if find_master else param
target_name = target_param.name
if (name not in self._accumulators or
target_name not in self._accumulators[name]):
raise Exception("Accumulator {} does not exist for parameter {}".
format(name, target_name))
return self._accumulators[name][target_name]
def _add_moments_pows(self, p):
acc_dtype = p.dtype
if acc_dtype == core.VarDesc.VarType.FP16:
acc_dtype = core.VarDesc.VarType.FP32
self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
# Create accumulator tensors for first and second moments
for p in parameters:
self._add_accumulator(self._moment1_acc_str, p)
self._add_accumulator(self._moment2_acc_str, p)
self._add_accumulator(
name=self._beta1_pow_acc_str,
param=p,
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_moments_pows(master_p)
continue
if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
"Consider using multi_precision=True option of the Adam optimizer."
)
self._add_moments_pows(p)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
@ -191,6 +259,10 @@ class Adam(Optimizer):
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
lr = self._create_param_lr(param_and_grad)
# create the adam optimize op
@ -227,7 +299,8 @@ class Adam(Optimizer):
attrs = {
"epsilon": self._epsilon,
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000
"min_row_size_to_use_multithread": 1000,
"multi_precision": find_master
}
if isinstance(self._beta1, Variable):
@ -239,6 +312,10 @@ class Adam(Optimizer):
else:
attrs['beta2'] = self._beta2
if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
adam_op = block.append_op(
type=self.type,
inputs=inputs,
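
At the user level, only the new flag is needed; the branch added to the UT constructs exactly this optimizer. With multi_precision=True, _create_master_weight above keeps an FP32 copy of each FP16 parameter, the moments are accumulated in FP32, and MasterParam/MasterParamOut are passed to the adam op.

import paddle

# FP16 parameters are updated through FP32 master weights (mirrors the updated UT).
optimizer = paddle.optimizer.Adam(
    learning_rate=0.001,
    epsilon=1e-8,
    weight_decay=0.0,
    multi_precision=True)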

@ -71,6 +71,7 @@ class AdamW(Adam):
gradient in current mini-batch, so it will be much more faster. But this mode has
different semantics with the original Adam algorithm and may lead to different result.
The default value is False.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
@ -111,6 +112,7 @@ class AdamW(Adam):
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None):
assert learning_rate is not None
assert beta1 is not None
@ -138,7 +140,8 @@ class AdamW(Adam):
epsilon=epsilon,
grad_clip=grad_clip,
name=name,
lazy_mode=lazy_mode)
lazy_mode=lazy_mode,
multi_precision=multi_precision)
def _append_decoupled_weight_decay(self, block, param_and_grad):
"""

@ -128,21 +128,6 @@ class Momentum(Optimizer):
self.helper = LayerHelper(self.__class__.__name__)
for p in parameters:
self._add_accumulator(self._velocity_acc_str, p)
else:
all_parameters = fluid.default_main_program().global_block(
).all_parameters()
self.helper = LayerHelper(self.__class__.__name__)
for p in all_parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_accumulator(self._velocity_acc_str, master_p)
continue
if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Momentum optimizer."
)
self._add_accumulator(self._velocity_acc_str, p)
def _create_master_weight(self, param):
assert isinstance(self.helper, LayerHelper)
@ -190,8 +175,21 @@ class Momentum(Optimizer):
return self._accumulators[name][target_name]
def _create_accumulators(self, block, parameters):
if framework.in_dygraph_mode():
return
assert isinstance(block, framework.Block)
# create accumulator in init func, so no implementation here
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_accumulator(self._velocity_acc_str, master_p)
continue
if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Momentum optimizer."
)
self._add_accumulator(self._velocity_acc_str, p)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
