Update amp_check_finite_and_scale_op and add an update_loss_scaling op for static graph amp training. (#26240)

* update amp_check_finite_and_scale_op for static_amp.

* use amp_check_finite_and_scale in static graph amp.

* update grads to zero when grads contain infinite values (for the amp_check_finite_and_scale op).

* add update_loss_scaling op in cpp.

* add update_loss_scaling_op unit test.

* update the doc of the check_finite_and_unscale op

* Update the gradient update process to skip the update when the gradients contain infinite values.

* update the way to zero grads.

* update test_update_loss_scaling_op.py

* add log info when infinite grads are found.

* add the unit test for UpdateLossScaling Layer.
Author: Zhen Wang
parent 2b6a5793fe
commit d708b21074
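
As a reading aid, the sketch below (not part of the patch; all names and default values are illustrative NumPy, not the Paddle API) shows what the two new ops do together in one training step: unscale the gradients when they are all finite, otherwise zero them, and adjust the loss scaling accordingly.

import numpy as np

# Illustrative sketch only; not the Paddle API.
def amp_step(grads, loss_scaling, good_steps, bad_steps,
             incr_every_n_steps=1000, decr_every_n_nan_or_inf=2,
             incr_ratio=2.0, decr_ratio=0.5):
    # check_finite_and_unscale: unscale only when every gradient is finite
    found_inf = any(not np.all(np.isfinite(g)) for g in grads)
    if not found_inf:
        grads = [g / loss_scaling for g in grads]

    # update_loss_scaling: zero grads on overflow and adjust the scale
    if found_inf:
        grads = [np.zeros_like(g) for g in grads]  # the parameter update is skipped
        good_steps, bad_steps = 0, bad_steps + 1
        if bad_steps == decr_every_n_nan_or_inf:
            loss_scaling = max(loss_scaling * decr_ratio, 1.0)
            bad_steps = 0
    else:
        good_steps, bad_steps = good_steps + 1, 0
        if good_steps == incr_every_n_steps:
            new_scaling = loss_scaling * incr_ratio
            loss_scaling = new_scaling if np.isfinite(new_scaling) else loss_scaling
            good_steps = 0
    return grads, loss_scaling, good_steps, bad_steps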

@ -1,104 +0,0 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class AmpCheckFiniteAndScaleOp : public framework::OperatorWithKernel {
public:
AmpCheckFiniteAndScaleOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
"amp_check_finite_and_unscale");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"amp_check_finite_and_unscale");
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(amp_check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->SetOutputDim("FoundInfinite", {1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class AmpCheckFiniteAndScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensors) The input tensors of amp_check_finite_and_scale operator.")
.AsDuplicable();
AddInput("Scale",
"(Tensor) 1-dim tensor, the scale of amp_check_finite_and_scale "
"operator.");
AddOutput("Out",
"(Tensors) The scaled output tensor of "
"amp_check_finite_and_unscale operator.")
.AsDuplicable();
AddOutput("FoundInfinite",
"(Tensor) 1-dim tensor, contains a bool scalar, which indicates "
"if there there is infinite or nan item in input X.");
AddComment(R"DOC(
amp_check_finite_and_scale operator.
Check if input X contains all finite data, if yes, scale it by input Scale.
$$Out = X * scale$$
If any tensor in X contains Inf or Nan, the Out will generate a indicator.
FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
Out should not be used, and its data may not be deterministic.
Otherwise, FoundInfinite will be 0 (False).
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
amp_check_finite_and_scale, ops::AmpCheckFiniteAndScaleOp,
ops::AmpCheckFiniteAndScaleOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
amp_check_finite_and_scale,
ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CPUDeviceContext,
float>,
ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CPUDeviceContext,
double>);

@ -1,66 +0,0 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/isfinite_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AmpCheckFiniteAndScaleKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
const auto* scale = ctx.Input<framework::Tensor>("Scale");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
const T* scale_data = scale->data<T>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
*found_inf_data = false;
framework::Tensor is_finite =
ctx.AllocateTmpTensor<bool, DeviceContext>({1}, dev_ctx);
bool* is_finite_data = is_finite.template data<bool>();
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
for (size_t i = 0; i < xs.size(); ++i) {
const auto* x = xs[i];
auto* out = outs[i];
out->mutable_data<T>(dev_ctx.GetPlace());
if (!(*found_inf_data)) {
framework::TensorIsfinite(*x, &is_finite);
if (*is_finite_data) {
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*x);
eigen_out.device(dev) = (*scale_data) * eigen_in;
} else {
*found_inf_data = true;
break;
}
}
}
return;
}
};
} // namespace operators
} // namespace paddle

@ -0,0 +1,141 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace operators {
class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
public:
CheckFiniteAndUnscaleOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
"check_finite_and_unscale");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"check_finite_and_unscale");
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->SetOutputDim("FoundInfinite", {1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class CheckFiniteAndUnscaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensors) The input tensors of check_finite_and_unscale operator.")
.AsDuplicable();
AddInput("Scale",
"(Tensor) 1-dim tensor, the scale of check_finite_and_unscale "
"operator.");
AddOutput("Out",
"(Tensors) The scaled output tensor of "
"check_finite_and_unscale operator.")
.AsDuplicable();
AddOutput("FoundInfinite",
"(Tensor) 1-dim tensor, contains a bool scalar, which indicates "
"if there there is infinite or nan item in input X.");
AddComment(R"DOC(
check_finite_and_unscale operator.
Check if input X contains all finite data; if so, unscale it by input Scale.
$$Out = X / scale$$
If any tensor in X contains Inf or Nan, FoundInfinite will be 1 (True) and Out will not be
unscaled. In this case, the data of Out should not be used, as it may not be deterministic.
Otherwise, FoundInfinite will be 0 (False).
)DOC");
}
};
template <typename T>
class CheckFiniteAndUnscaleCpuKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
const auto* scale = ctx.Input<framework::Tensor>("Scale");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
const T* scale_data = scale->data<T>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
*found_inf_data = false;
framework::Tensor is_finite =
ctx.AllocateTmpTensor<bool, platform::CPUDeviceContext>({1}, dev_ctx);
bool* is_finite_data = is_finite.template data<bool>();
auto& dev = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
T inverse_scale = Inverse<T>(*scale_data);
for (size_t i = 0; i < xs.size(); ++i) {
const auto* x = xs[i];
auto* out = outs[i];
out->mutable_data<T>(dev_ctx.GetPlace());
if (!(*found_inf_data)) {
framework::TensorIsfinite(*x, &is_finite);
*found_inf_data = !(*is_finite_data);
}
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*x);
if (!(*found_inf_data)) {
eigen_out.device(dev) = eigen_in * inverse_scale;
} else {
eigen_out.device(dev) = eigen_in * static_cast<T>(0);
}
}
return;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
check_finite_and_unscale, ops::CheckFiniteAndUnscaleOp,
ops::CheckFiniteAndUnscaleOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleCpuKernel<float>,
ops::CheckFiniteAndUnscaleCpuKernel<double>);
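
For reference, a small NumPy mirror (mine, not part of the patch) of what CheckFiniteAndUnscaleCpuKernel above computes: tensors are multiplied by the reciprocal of Scale until a non-finite value is met; from then on FoundInfinite stays true and the remaining outputs are filled with zeros, so Out must not be relied upon.

import numpy as np

# Hypothetical reference of the CPU kernel above.
def check_finite_and_unscale_ref(xs, scale):
    found_inf = False
    inverse_scale = 1.0 / scale
    outs = []
    for x in xs:
        if not found_inf:
            found_inf = not np.all(np.isfinite(x))
        # still all-finite: unscale; otherwise the output data is meaningless
        outs.append(x * inverse_scale if not found_inf else np.zeros_like(x))
    return outs, found_inf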

@ -14,28 +14,31 @@ limitations under the License. */
#include <cuda.h>
-#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h"
+#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
-__global__ void AmpCheckFiniteAndScale(const T* in, const T* scale, int num,
-                                       bool* found_inf, T* out) {
+__global__ void GpuInverse(const T* s, T* o) {
+  *o = Inverse<T>(*s);
+}
+
+template <typename T>
+__global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
+                                      bool* found_inf, T* out) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < num) {
    if (!isfinite(in[idx])) {
-      *found_inf = 1;
+      *found_inf = true;
    }
-    out[idx] = *found_inf ? in[idx] : in[idx] * scale[0];
+    out[idx] = *found_inf ? in[idx] : in[idx] * (*scale);
  }
}
template <typename T>
-class AmpCheckFiniteAndScaleKernel<platform::CUDADeviceContext, T>
-    : public framework::OpKernel<T> {
+class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
@ -48,6 +51,12 @@ class AmpCheckFiniteAndScaleKernel<platform::CUDADeviceContext, T>
    bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
    cudaMemset(found_inf_data, false, found_inf->numel() * sizeof(bool));
+    framework::Tensor inverse_scale =
+        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({1}, dev_ctx);
+    T* inverse_scale_v = inverse_scale.template data<T>();
+    GpuInverse<T><<<1, 1, 0, dev_ctx.stream()>>>(scale_data, inverse_scale_v);
    for (size_t i = 0; i < xs.size(); ++i) {
      const auto* x = xs[i];
      auto* out = outs[i];
@ -55,11 +64,11 @@ class AmpCheckFiniteAndScaleKernel<platform::CUDADeviceContext, T>
      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
      int num = x->numel();
-      int block = 512;
+      int block = 1024;
      int grid = (num + block - 1) / block;
      VLOG(3) << "launch kernel";
-      AmpCheckFiniteAndScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
-          x_data, scale_data, num, found_inf_data, out_data);
+      CheckFiniteAndUnscale<T><<<grid, block, 0, dev_ctx.stream()>>>(
+          x_data, inverse_scale_v, num, found_inf_data, out_data);
      VLOG(3) << "finish kernel";
    }
  }
@ -68,9 +77,6 @@ class AmpCheckFiniteAndScaleKernel<platform::CUDADeviceContext, T>
}  // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    amp_check_finite_and_scale,
-    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CUDADeviceContext,
-                                      float>,
-    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CUDADeviceContext,
-                                      double>);
+REGISTER_OP_CUDA_KERNEL(check_finite_and_unscale,
+                        ops::CheckFiniteAndUnscaleGpuKernel<float>,
+                        ops::CheckFiniteAndUnscaleGpuKernel<double>);

@ -0,0 +1,31 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
template <typename T>
inline HOSTDEVICE T Inverse(T s) {
return 1.0 / s;
}
} // namespace operators
} // namespace paddle

@ -0,0 +1,170 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class UpdateLossScalingOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("InGoodSteps"), "Input", "InGoodSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps",
"update_loss_scaling");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->SetOutputDim("LossScaling", {1});
ctx->SetOutputDim("OutGoodSteps", {1});
ctx->SetOutputDim("OutBadSteps", {1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "PrevLossScaling"),
ctx.device_context());
}
};
class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensors) The input tensors of update_loss_scaling operator.")
.AsDuplicable();
AddInput("FoundInfinite",
"(Tensor) 1-dim tensor, contains a bool scalar, which indicates "
"whether there is any infinite gradient.");
AddInput("PrevLossScaling",
"(Tensor) 1-dim tensor, previous loss scaling.");
AddInput("InGoodSteps",
"(Tensor) 1-dim tensor, accumulates good steps in which all "
"gradients are finite.");
AddInput("InBadSteps",
"(Tensor) 1-dim tensor, accumulates bad steps in which some "
"gradients are infinite.");
AddOutput("Out",
"(Tensors) The output tensor of update_loss_scaling operator.")
.AsDuplicable();
AddOutput("LossScaling", "(Tensor) 1-dim tensor, updated loss scaling.");
AddOutput("OutGoodSteps", "(Tensor) 1-dim tensor, pdated good steps.");
AddOutput("OutBadSteps", "(Tensor) 1-dim tensor, updated bad steps.");
AddAttr<int>("incr_every_n_steps",
"A value represents increasing loss scaling every n "
"consecutive steps with finite gradients.");
AddAttr<int>("decr_every_n_nan_or_inf",
"A value represents decreasing loss scaling every n "
"accumulated steps with nan or inf gradients.");
AddAttr<float>("incr_ratio",
"The multiplier to use when increasing the loss scaling.")
.AddCustomChecker([](float incr_ratio) {
PADDLE_ENFORCE_EQ(incr_ratio > 1.0f, true,
platform::errors::InvalidArgument(
"'incr_ratio' should be greater than 1, but "
"the received is %f",
incr_ratio));
});
AddAttr<float>(
"decr_ratio",
"The less-than-one-multiplier to use when decreasing loss scaling.")
.AddCustomChecker([](float decr_ratio) {
PADDLE_ENFORCE_EQ(decr_ratio > 0.0f && decr_ratio < 1.0f, true,
platform::errors::InvalidArgument(
"'incr_ratio' should be between 0 and 1, but "
"the received is %f",
decr_ratio));
});
AddComment(R"DOC(
Update loss scaling according to the overall gradients. If all gradients are
finite for incr_every_n_steps consecutive steps, loss scaling will increase
by incr_ratio. Otherwise, after decr_every_n_nan_or_inf accumulated steps with
nan or inf gradients, loss scaling will decrease by decr_ratio.
)DOC");
}
};
template <typename T>
class UpdateLossScalingFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& ctx,
const bool* found_inf_data, const T* pre_loss_scaling_data,
const int* good_in_data, const int* bad_in_data,
const int incr_every_n_steps,
const int decr_every_n_nan_or_inf, const float incr_ratio,
const float decr_ratio, T* updated_loss_scaling_data,
int* good_out_data, int* bad_out_data) const {
Update<T>(found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
decr_ratio, updated_loss_scaling_data, good_out_data,
bad_out_data);
}
};
template <typename T>
class LazyZeroInputs<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& dev_ctx,
const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
if (*found_inf_data) {
VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = out->numel();
std::memset(out_data, 0, num * sizeof(T));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
update_loss_scaling, ops::UpdateLossScalingOp,
ops::UpdateLossScalingOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(update_loss_scaling,
ops::UpdateLossScalingKernel<CPU, float>,
ops::UpdateLossScalingKernel<CPU, double>);

@ -0,0 +1,84 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void GpuUpdateLossScaling(
const bool* found_inf_data, const T* pre_loss_scaling_data,
const int* good_in_data, const int* bad_in_data,
const int incr_every_n_steps, const int decr_every_n_nan_or_inf,
const float incr_ratio, const float decr_ratio,
T* updated_loss_scaling_data, int* good_out_data, int* bad_out_data) {
Update<T>(found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
updated_loss_scaling_data, good_out_data, bad_out_data);
}
template <typename T>
class UpdateLossScalingFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& dev_ctx,
const bool* found_inf_data, const T* pre_loss_scaling_data,
const int* good_in_data, const int* bad_in_data,
const int incr_every_n_steps,
const int decr_every_n_nan_or_inf, const float incr_ratio,
const float decr_ratio, T* updated_loss_scaling_data,
int* good_out_data, int* bad_out_data) const {
GpuUpdateLossScaling<T><<<1, 1, 0, dev_ctx.stream()>>>(
found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
updated_loss_scaling_data, good_out_data, bad_out_data);
}
};
template <typename T>
class LazyZeroInputs<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& dev_ctx,
const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
const auto gpu_place =
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
bool has_inf{false};
memory::Copy(platform::CPUPlace(), &has_inf, gpu_place, found_inf_data,
             sizeof(bool), dev_ctx.stream());
dev_ctx.Wait();  // the copy above is async; synchronize before reading has_inf on the host
if (has_inf) {
VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = out->numel();
cudaMemset(out_data, 0, num * sizeof(T));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using GPU = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(update_loss_scaling,
ops::UpdateLossScalingKernel<GPU, float>,
ops::UpdateLossScalingKernel<GPU, double>);

@ -0,0 +1,123 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE void Update(const bool* found_inf_data,
const T* pre_loss_scaling_data, const int* good_in_data,
const int* bad_in_data, const int incr_every_n_steps,
const int decr_every_n_nan_or_inf,
const float incr_ratio, const float decr_ratio,
T* updated_loss_scaling_data, int* good_out_data,
int* bad_out_data) {
if (*found_inf_data) {
*good_out_data = 0;
*bad_out_data = *bad_in_data + 1;
if (*bad_out_data == decr_every_n_nan_or_inf) {
T new_loss_scaling = *pre_loss_scaling_data * decr_ratio;
*updated_loss_scaling_data = new_loss_scaling < static_cast<T>(1)
? static_cast<T>(1)
: new_loss_scaling;
*bad_out_data = 0;
}
} else {
*bad_out_data = 0;
*good_out_data = *good_in_data + 1;
if (*good_out_data == incr_every_n_steps) {
T new_loss_scaling = *pre_loss_scaling_data * incr_ratio;
*updated_loss_scaling_data = std::isfinite(new_loss_scaling)
? new_loss_scaling
: *pre_loss_scaling_data;
*good_out_data = 0;
}
}
}
template <typename DeviceContext, typename T>
class UpdateLossScalingFunctor {
public:
void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data,
const T* pre_loss_scaling_data, const int* good_in_data,
const int* bad_in_data, const int incr_every_n_steps,
const int decr_every_n_nan_or_inf, const float incr_ratio,
const float decr_ratio, T* updated_loss_scaling_data,
int* good_out_data, int* bad_out_data) const;
};
template <typename DeviceContext, typename T>
class LazyZeroInputs {
public:
void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const;
};
template <typename DeviceContext, typename T>
class UpdateLossScalingKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto xs = ctx.MultiInput<framework::Tensor>("X");
const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
auto* bad_out = ctx.Output<Tensor>("OutBadSteps");
PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must has only one element."));
const bool* found_inf_data = found_inf->data<bool>();
const T* pre_loss_scaling_data = pre_loss_scaling->data<T>();
const int* good_in_data = good_in->data<int>();
const int* bad_in_data = bad_in->data<int>();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
T* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<T>(dev_ctx.GetPlace());
int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace());
int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace());
const int incr_every_n_steps = ctx.Attr<int>("incr_every_n_steps");
const int decr_every_n_nan_or_inf =
ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio");
UpdateLossScalingFunctor<DeviceContext, T>{}(
dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data,
bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data);
LazyZeroInputs<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
}
};
} // namespace operators
} // namespace paddle
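
The Update routine above is the whole dynamic loss-scaling policy. A pure-Python transcription (illustrative only; it also mirrors the in-place semantics, where LossScaling keeps its previous value unless a threshold is hit) may make the state machine easier to follow.

import math

# Illustrative transcription of the HOSTDEVICE Update function above.
def update(found_inf, prev_loss_scaling, good_in, bad_in,
           incr_every_n_steps, decr_every_n_nan_or_inf,
           incr_ratio, decr_ratio):
    loss_scaling = prev_loss_scaling  # unchanged unless a threshold is reached
    if found_inf:
        good_out, bad_out = 0, bad_in + 1
        if bad_out == decr_every_n_nan_or_inf:
            # shrink the scale, but never below 1
            loss_scaling = max(prev_loss_scaling * decr_ratio, 1.0)
            bad_out = 0
    else:
        good_out, bad_out = good_in + 1, 0
        if good_out == incr_every_n_steps:
            new_loss_scaling = prev_loss_scaling * incr_ratio
            # only grow while the scale itself stays finite
            if math.isfinite(new_loss_scaling):
                loss_scaling = new_loss_scaling
            good_out = 0
    return loss_scaling, good_out, bad_out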

@ -111,7 +111,9 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"fake_quantize_dequantize_moving_average_abs_max", {"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}}, {"Out", "OutScale", "OutAccum", "OutState"}},
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}}, {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
{"amp_check_finite_and_scale", {"Out", "FoundInfinite"}}, {"check_finite_and_unscale", {"Out", "FoundInfinite"}},
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
}; };
// clang-format off // clang-format off

@ -0,0 +1,124 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable
__all__ = ['check_finite_and_unscale', 'update_loss_scaling']
def check_finite_and_unscale(x, scale, name=None):
    """
    Check if input X contains all finite data; if so, unscale it by input Scale.

    $$Out = X / scale$$

    If any tensor in X contains Inf or Nan, FoundInfinite will be 1 (True) and Out
    will not be unscaled. In this case, the data of Out should not be used, as it
    may not be deterministic. Otherwise, FoundInfinite will be 0 (False).

    Args:
        x(list|tuple): The input tensors of check_finite_and_unscale operator.
        scale: The scale of check_finite_and_unscale operator.
    """
    check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in x:
        check_variable_and_dtype(e, "x", ['float32', 'float64'],
                                 'check_finite_and_unscale')

    helper = LayerHelper("check_finite_and_unscale", **locals())
    found_inf = helper.create_variable_for_type_inference(dtype='bool')

    inputs = {'X': x, 'Scale': scale}
    outputs = {'Out': x, 'FoundInfinite': found_inf}
    helper.append_op(
        type='check_finite_and_unscale', inputs=inputs, outputs=outputs)

    return x, found_inf


def update_loss_scaling(x,
                        found_inf,
                        prev_loss_scaling,
                        num_good_steps,
                        num_bad_steps,
                        incr_every_n_steps,
                        decr_every_n_nan_or_inf,
                        incr_ratio,
                        decr_ratio,
                        name=None):
    """
    Update loss scaling according to the overall gradients. If all gradients are
    finite for incr_every_n_steps consecutive steps, loss scaling will increase
    by incr_ratio. Otherwise, after decr_every_n_nan_or_inf accumulated steps
    with nan or inf gradients, loss scaling will decrease by decr_ratio.

    Args:
        x(list|tuple): The input tensors of update_loss_scaling operator.
        found_inf (Variable): A boolean variable that indicates whether
                              there is any infinite gradient.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable that accumulates good steps in
                                   which all gradients are finite.
        num_bad_steps (Variable): A variable that accumulates bad steps in
                                  which some gradients are infinite.
        incr_every_n_steps (int): Increase loss scaling every n consecutive
                                  steps with finite gradients.
        decr_every_n_nan_or_inf (int): Decrease loss scaling every n
                                       accumulated steps with nan or inf
                                       gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one multiplier to use when decreasing
                           the loss scaling.
    """
    check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
                             ['float32', 'float64'], "update_loss_scaling")
    check_type(x, 'x', (tuple, list), 'update_loss_scaling')
    for e in x:
        check_variable_and_dtype(e, "x", ['float32', 'float64'],
                                 'update_loss_scaling')
        assert prev_loss_scaling.dtype == e.dtype, \
            "The dtype of prev_loss_scaling should be equal to the dtype of x."

    helper = LayerHelper("update_loss_scaling", **locals())

    inputs = {
        'X': x,
        'FoundInfinite': found_inf,
        'PrevLossScaling': prev_loss_scaling,
        'InGoodSteps': num_good_steps,
        'InBadSteps': num_bad_steps
    }

    outputs = {
        'Out': x,
        'LossScaling': prev_loss_scaling,
        'OutGoodSteps': num_good_steps,
        'OutBadSteps': num_bad_steps
    }

    attrs = {
        'incr_every_n_steps': incr_every_n_steps,
        'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf,
        'incr_ratio': incr_ratio,
        'decr_ratio': decr_ratio,
    }

    helper.append_op(
        type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs)

    return x
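
A minimal static-graph usage sketch of the two layers above (mine, not from the patch). The import path below assumes amp_nn.py sits in paddle.fluid.contrib.mixed_precision, and the scalar values are arbitrary.

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision.amp_nn import (  # assumed path
    check_finite_and_unscale, update_loss_scaling)

grads = [fluid.layers.fill_constant(shape=[2, 2], dtype='float32', value=4.0)]
loss_scaling = fluid.layers.create_global_var(
    shape=[1], value=2.0, dtype='float32', persistable=True)
good_steps = fluid.layers.create_global_var(
    shape=[1], value=0, dtype='int32', persistable=True)
bad_steps = fluid.layers.create_global_var(
    shape=[1], value=0, dtype='int32', persistable=True)

grads, found_inf = check_finite_and_unscale(grads, loss_scaling)
grads = update_loss_scaling(
    grads, found_inf, loss_scaling, good_steps, bad_steps,
    incr_every_n_steps=1000, decr_every_n_nan_or_inf=2,
    incr_ratio=2.0, decr_ratio=0.5)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(fetch_list=[grads[0]])  # grads were finite: 4.0 / 2.0 everywhere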

@ -17,9 +17,11 @@ from ... import default_startup_program
from ... import layers
from ... import unique_name
from . import fp16_utils
-from .fp16_utils import update_loss_scaling, rewrite_program
+from .fp16_utils import rewrite_program
from .fp16_utils import update_role_var_grad
from .fp16_lists import AutoMixedPrecisionLists
+from .amp_nn import check_finite_and_unscale
+from .amp_nn import update_loss_scaling

__all__ = ["decorate"]
@ -67,10 +69,8 @@ class OptimizerWithMixedPrecision(object):
            persistable=True)
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if self._use_dynamic_loss_scaling:
-            self._incr_every_n_steps = layers.fill_constant(
-                shape=[1], dtype='int32', value=incr_every_n_steps)
-            self._decr_every_n_nan_or_inf = layers.fill_constant(
-                shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
+            self._incr_every_n_steps = incr_every_n_steps
+            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._num_good_steps = layers.create_global_var(
@ -139,49 +139,46 @@ class OptimizerWithMixedPrecision(object):
        # Change the op_role_var attr for some ops, so that gradients
        # transferred across GPUs can be FP16.
        update_role_var_grad(self._train_program, self._params_grads)
-        scaled_params_grads = []
-        for p, g in self._params_grads:
-            with self._train_program._optimized_guard([p, g]):
-                scaled_g = g / self._loss_scaling
-                scaled_params_grads.append([p, scaled_g])
-        return scaled_params_grads
+        return self._params_grads

-    def apply_gradients(self, scaled_params_grads):
+    def apply_gradients(self, params_grads):
        """
        Check scaled gradients to determine whether to update loss scaling and update
        parameters by their scaled gradients,

        Args:
-            scaled_params_grads (list): A list of params and scaled grads.
+            params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
        """
-        if self._use_dynamic_loss_scaling:
-            grads = [layers.reduce_sum(g) for [_, g] in scaled_params_grads]
-            all_grads = layers.concat(grads)
-            all_grads_sum = layers.reduce_sum(all_grads)
-            is_overall_finite = layers.isfinite(all_grads_sum)
-            update_loss_scaling(is_overall_finite, self._loss_scaling,
-                                self._num_good_steps, self._num_bad_steps,
-                                self._incr_every_n_steps,
-                                self._decr_every_n_nan_or_inf, self._incr_ratio,
-                                self._decr_ratio)
+        grads = [g for _, g in params_grads]
+        with self._train_program._optimized_guard(grads):
+            grads, found_inf = check_finite_and_unscale(
+                grads, self._loss_scaling, name="find_infinite_scale")
+
+        if self._use_dynamic_loss_scaling:
+            with self._train_program._optimized_guard(grads):
+                grads = update_loss_scaling(
+                    grads,
+                    found_inf,
+                    self._loss_scaling,
+                    self._num_good_steps,
+                    self._num_bad_steps,
+                    self._incr_every_n_steps,
+                    self._decr_every_n_nan_or_inf,
+                    self._incr_ratio,
+                    self._decr_ratio,
+                    name="update_loss_scaling")
+
+        params_unscaled_grads = []
+        for pg, new_g in zip(params_grads, grads):
+            params_unscaled_grads.append((pg[0], new_g))
        # apply_gradient append all ops in global block, thus we shouldn't
        # apply gradient in the switch branch.
-        with layers.Switch() as switch:
-            with switch.case(is_overall_finite):
-                pass
-            with switch.default():
-                for _, g in scaled_params_grads:
-                    layers.assign(layers.zeros_like(g), g)
-        optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
+        optimize_ops = self._optimizer.apply_gradients(params_unscaled_grads)

        return optimize_ops
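
For context, the rewritten apply_gradients path above is typically reached through the decorate() API. A hedged usage sketch follows; the import path and keyword arguments are assumed from the surrounding code and may differ across Paddle versions.

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision import decorate  # assumed import path

def build_amp_program():
    img = fluid.data(name='img', shape=[None, 784], dtype='float32')
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')
    prediction = fluid.layers.fc(input=img, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
    # minimize() on the decorated optimizer now inserts check_finite_and_unscale
    # and update_loss_scaling ops instead of the old isfinite + Switch logic.
    optimizer = decorate(
        fluid.optimizer.SGD(learning_rate=0.01), use_dynamic_loss_scaling=True)
    optimizer.minimize(loss)
    return loss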

@ -328,77 +328,3 @@ def update_role_var_grad(main_prog, params_grads):
raise ValueError("The op {0} is not in program".format(op)) raise ValueError("The op {0} is not in program".format(op))
block.desc._remove_op(op_idx, op_idx + 1) block.desc._remove_op(op_idx, op_idx + 1)
block._sync_with_cpp() block._sync_with_cpp()
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
num_bad_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
"""
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Otherwise, loss scaling will decrease by decr_ratio after
decr_every_n_nan_or_inf steps and each step some gradients are infinite.
Args:
is_overall_finite (Variable): A boolean variable indicates whether
all gradients are finite.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable accumulates good steps in which
all gradients are finite.
num_bad_steps (Variable): A variable accumulates bad steps in which
some gradients are infinite.
incr_every_n_steps (Variable): A variable represents increasing loss
scaling every n consecutive steps with
finite gradients.
decr_every_n_nan_or_inf (Variable): A variable represents decreasing
loss scaling every n accumulated
steps with nan or inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
with layers.Switch() as switch:
with switch.case(is_overall_finite):
should_incr_loss_scaling = layers.less_than(incr_every_n_steps,
num_good_steps + 1)
with layers.Switch() as switch1:
with switch1.case(should_incr_loss_scaling):
new_loss_scaling = prev_loss_scaling * incr_ratio
loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
with layers.Switch() as switch2:
with switch2.case(loss_scaling_is_finite):
layers.assign(new_loss_scaling, prev_loss_scaling)
with switch2.default():
pass
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch1.default():
layers.increment(num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch.default():
should_decr_loss_scaling = layers.less_than(decr_every_n_nan_or_inf,
num_bad_steps + 1)
with layers.Switch() as switch3:
with switch3.case(should_decr_loss_scaling):
new_loss_scaling = prev_loss_scaling * decr_ratio
static_loss_scaling = \
layers.fill_constant(shape=[1],
dtype='float32',
value=1.0)
less_than_one = layers.less_than(new_loss_scaling,
static_loss_scaling)
with layers.Switch() as switch4:
with switch4.case(less_than_one):
layers.assign(static_loss_scaling,
prev_loss_scaling)
with switch4.default():
layers.assign(new_loss_scaling, prev_loss_scaling)
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch3.default():
layers.assign(zero_steps, num_good_steps)
layers.increment(num_bad_steps)

@ -210,12 +210,11 @@ class AmpScaler(object):
    def _unscale(self, optimizer):
        if not self._enable:
            return
-        inv_scale = 1.0 / self._scale
        param_grads = [
            param._grad_ivar() for param in optimizer._parameter_list
            if param._grad_ivar() is not None
        ]
-        core.ops.amp_check_finite_and_scale(param_grads, inv_scale, param_grads,
-                                            self._found_inf)
+        core.ops.check_finite_and_unscale(param_grads, self._scale, param_grads,
+                                          self._found_inf)

    def _update(self):

@ -18,9 +18,9 @@ from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid

-class TestAmpCheckFiniteAndScaleOp(OpTest):
+class TestCheckFiniteAndUnscaleOp(OpTest):
    def setUp(self):
-        self.op_type = "amp_check_finite_and_scale"
+        self.op_type = "check_finite_and_unscale"
        self.init_dtype()
        x = np.random.random((1024, 1024)).astype(self.dtype)
        scale = np.random.random((1)).astype(self.dtype)
@ -28,7 +28,7 @@ class TestAmpCheckFiniteAndScaleOp(OpTest):
        self.inputs = {'X': [('x0', x)], 'Scale': scale}
        self.outputs = {
            'FoundInfinite': np.array([0]),
-            'Out': [('out0', x * scale)],
+            'Out': [('out0', x / scale)],
        }

    def init_dtype(self):
@ -38,9 +38,9 @@ class TestAmpCheckFiniteAndScaleOp(OpTest):
        self.check_output()

-class TestAmpCheckFiniteAndScaleOpWithNan(OpTest):
+class TestCheckFiniteAndUnscaleOpWithNan(OpTest):
    def setUp(self):
-        self.op_type = "amp_check_finite_and_scale"
+        self.op_type = "check_finite_and_unscale"
        self.init_dtype()
        x = np.random.random((1024, 1024)).astype(self.dtype)
        x[128][128] = np.nan
@ -61,9 +61,9 @@ class TestAmpCheckFiniteAndScaleOpWithNan(OpTest):
        self.check_output(no_check_set=['Out'])

-class TestAmpCheckFiniteAndScaleOpWithInf(OpTest):
+class TestCheckFiniteAndUnscaleOpWithInf(OpTest):
    def setUp(self):
-        self.op_type = "amp_check_finite_and_scale"
+        self.op_type = "check_finite_and_unscale"
        self.init_dtype()
        x = np.random.random((1024, 1024)).astype(self.dtype)
        x[128][128] = np.inf

@ -57,7 +57,7 @@ class TestFleetAMPOptimizer(unittest.TestCase):
        ops = [op.type for op in avg_cost.block.ops]
        self.assertIn('cast', ops)
-        self.assertIn('isfinite', ops)
+        self.assertIn('check_finite_and_unscale', ops)

if __name__ == "__main__":

@ -25,6 +25,7 @@ no_check_set_white_list = [
    'unsqueeze2',
    'cross_entropy2',
    'seed',
-    'amp_check_finite_and_scale',
+    'check_finite_and_unscale',
+    'update_loss_scaling',
    'cudnn_lstm',
]
