Update amp_check_finite_and_scale_op and add an updating_loss_scaling op for static graph amp training. (#26240)
* update amp_check_finite_and_scale_op for static_amp. * use amp_check_finite_and_scale in static graph amp. * update grads to zero when grads own infinite values (as for amp_check_finite_and_scale op). * add update_loss_scaling op in cpp. * add update_loss_scaling_op unit test. * update the doc of the check_finite_and_unscale op * Update the process of gradients updating skipping if the gradients have infinite values. * update the way to zero grads. * update test_update_loss_scaling_op.py * add log info when find infinite grads. * add the unit test for UpdateLossScaling Layer.
parent
2b6a5793fe
commit
d708b21074
@ -1,104 +0,0 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class AmpCheckFiniteAndScaleOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
AmpCheckFiniteAndScaleOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
|
||||
"amp_check_finite_and_unscale");
|
||||
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
|
||||
"amp_check_finite_and_unscale");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
|
||||
platform::errors::InvalidArgument(
|
||||
"The input(X) and output(Out) should have same size in "
|
||||
"Operator(amp_check_finite_and_unscale), size of input(X) is %d "
|
||||
"and size of output(Out) is %d.",
|
||||
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
|
||||
auto x_dims = ctx->GetInputsDim("X");
|
||||
ctx->SetOutputsDim("Out", x_dims);
|
||||
ctx->SetOutputDim("FoundInfinite", {1});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
// Proto maker: declares the op's inputs, outputs, and user-facing docs.
class AmpCheckFiniteAndScaleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensors) The input tensors of amp_check_finite_and_scale operator.")
        .AsDuplicable();
    AddInput("Scale",
             "(Tensor) 1-dim tensor, the scale of amp_check_finite_and_scale "
             "operator.");
    // Fixed: description previously named the wrong op
    // ("amp_check_finite_and_unscale") for the Out output.
    AddOutput("Out",
              "(Tensors) The scaled output tensor of "
              "amp_check_finite_and_scale operator.")
        .AsDuplicable();
    // Fixed: "if there there is" -> "if there is".
    AddOutput("FoundInfinite",
              "(Tensor) 1-dim tensor, contains a bool scalar, which indicates "
              "if there is infinite or nan item in input X.");
    AddComment(R"DOC(
amp_check_finite_and_scale operator.
Check if input X contains all finite data, if yes, scale it by input Scale.

$$Out = X * scale$$

If any tensor in X contains Inf or Nan, the Out will generate an indicator.
FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
Out should not be used, and its data may not be deterministic.
Otherwise, FoundInfinite will be 0 (False).

)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;

// The op has no gradient: checking/scaling gradients is not differentiable,
// so both static-graph and imperative grad makers are empty.
REGISTER_OPERATOR(
    amp_check_finite_and_scale, ops::AmpCheckFiniteAndScaleOp,
    ops::AmpCheckFiniteAndScaleOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

// CPU kernels for float and double element types.
REGISTER_OP_CPU_KERNEL(
    amp_check_finite_and_scale,
    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CPUDeviceContext,
                                      float>,
    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CPUDeviceContext,
                                      double>);
|
@ -1,66 +0,0 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
|
||||
#include "paddle/fluid/operators/isfinite_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
// Kernel: multiplies every tensor in X by the scalar Scale, unless any input
// contains Inf/NaN, in which case FoundInfinite is set and scaling stops.
template <typename DeviceContext, typename T>
class AmpCheckFiniteAndScaleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    const auto xs = ctx.MultiInput<framework::Tensor>("X");
    const auto* scale = ctx.Input<framework::Tensor>("Scale");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");

    const T* scale_data = scale->data<T>();
    // NOTE(review): scale_data and found_inf_data are dereferenced directly
    // on the host below, so this kernel assumes host-accessible memory
    // (CPU place) — confirm only CPU kernels are registered for it.
    bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());

    *found_inf_data = false;
    // Scratch 1-element flag reused for every per-tensor finiteness check.
    framework::Tensor is_finite =
        ctx.AllocateTmpTensor<bool, DeviceContext>({1}, dev_ctx);
    bool* is_finite_data = is_finite.template data<bool>();

    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
    for (size_t i = 0; i < xs.size(); ++i) {
      const auto* x = xs[i];
      auto* out = outs[i];
      out->mutable_data<T>(dev_ctx.GetPlace());
      if (!(*found_inf_data)) {
        framework::TensorIsfinite(*x, &is_finite);
        if (*is_finite_data) {
          // All values finite: apply the scale factor element-wise.
          auto eigen_out = framework::EigenVector<T>::Flatten(*out);
          auto eigen_in = framework::EigenVector<T>::Flatten(*x);
          eigen_out.device(dev) = (*scale_data) * eigen_in;
        } else {
          *found_inf_data = true;
          // NOTE(review): breaking here leaves the remaining Out tensors
          // unallocated/unwritten; the op doc states Out must not be read
          // when FoundInfinite is set, so this is intentional best-effort.
          break;
        }
      }
    }
    return;
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,141 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
CheckFiniteAndUnscaleOp(const std::string& type,
|
||||
const framework::VariableNameMap& inputs,
|
||||
const framework::VariableNameMap& outputs,
|
||||
const framework::AttributeMap& attrs)
|
||||
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
|
||||
"check_finite_and_unscale");
|
||||
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
|
||||
"check_finite_and_unscale");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
|
||||
platform::errors::InvalidArgument(
|
||||
"The input(X) and output(Out) should have same size in "
|
||||
"Operator(check_finite_and_unscale), size of input(X) is %d "
|
||||
"and size of output(Out) is %d.",
|
||||
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
|
||||
auto x_dims = ctx->GetInputsDim("X");
|
||||
ctx->SetOutputsDim("Out", x_dims);
|
||||
ctx->SetOutputDim("FoundInfinite", {1});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
// Proto maker: declares the op's inputs, outputs, and user-facing docs.
class CheckFiniteAndUnscaleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensors) The input tensors of check_finite_and_unscale operator.")
        .AsDuplicable();
    AddInput("Scale",
             "(Tensor) 1-dim tensor, the scale of check_finite_and_unscale "
             "operator.");
    AddOutput("Out",
              "(Tensors) The scaled output tensor of "
              "check_finite_and_unscale operator.")
        .AsDuplicable();
    // Fixed: "if there there is" -> "if there is".
    AddOutput("FoundInfinite",
              "(Tensor) 1-dim tensor, contains a bool scalar, which indicates "
              "if there is infinite or nan item in input X.");
    AddComment(R"DOC(
check_finite_and_unscale operator.
Check if input X contains all finite data, if yes, scale it by input Scale.

$$Out = X / scale$$

If any tensor in X contains Inf or Nan, the Out will generate an indicator.
FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
Out should not be used, and its data may not be deterministic.
Otherwise, FoundInfinite will be 0 (False).

)DOC");
  }
};
|
||||
|
||||
template <typename T>
|
||||
// CPU kernel: divides every tensor in X by Scale. If any input contains
// Inf/NaN, FoundInfinite is set and ALL outputs from that point on (and
// including the offending tensor) are zeroed instead of unscaled.
template <typename T>
class CheckFiniteAndUnscaleCpuKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
    const auto xs = ctx.MultiInput<framework::Tensor>("X");
    const auto* scale = ctx.Input<framework::Tensor>("Scale");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");

    const T* scale_data = scale->data<T>();
    bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());

    *found_inf_data = false;
    // Scratch 1-element flag reused for every per-tensor finiteness check.
    framework::Tensor is_finite =
        ctx.AllocateTmpTensor<bool, platform::CPUDeviceContext>({1}, dev_ctx);
    bool* is_finite_data = is_finite.template data<bool>();

    auto& dev = *ctx.template device_context<platform::CPUDeviceContext>()
                     .eigen_device();

    // Multiply by 1/scale instead of dividing per element.
    T inverse_scale = Inverse<T>(*scale_data);
    for (size_t i = 0; i < xs.size(); ++i) {
      const auto* x = xs[i];
      auto* out = outs[i];
      out->mutable_data<T>(dev_ctx.GetPlace());
      // Once found_inf_data latches to true it is never re-checked, so all
      // subsequent tensors skip the finiteness scan and are zeroed below.
      if (!(*found_inf_data)) {
        framework::TensorIsfinite(*x, &is_finite);
        *found_inf_data = !(*is_finite_data);
      }
      auto eigen_out = framework::EigenVector<T>::Flatten(*out);
      auto eigen_in = framework::EigenVector<T>::Flatten(*x);
      if (!(*found_inf_data)) {
        // Finite so far: unscale element-wise.
        eigen_out.device(dev) = eigen_in * inverse_scale;
      } else {
        // Infinite gradients found: emit zeros so optimizers can safely
        // skip this step without consuming garbage values.
        eigen_out.device(dev) = eigen_in * static_cast<T>(0);
      }
    }
    return;
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;

// The op has no gradient: unscaling/checking gradients is not
// differentiable, so both grad makers are empty.
REGISTER_OPERATOR(
    check_finite_and_unscale, ops::CheckFiniteAndUnscaleOp,
    ops::CheckFiniteAndUnscaleOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

// CPU kernels for float and double element types.
REGISTER_OP_CPU_KERNEL(check_finite_and_unscale,
                       ops::CheckFiniteAndUnscaleCpuKernel<float>,
                       ops::CheckFiniteAndUnscaleCpuKernel<double>);
|
@ -0,0 +1,31 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/operators/isfinite_op.h"
|
||||
#include "paddle/fluid/platform/hostdevice.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
inline HOSTDEVICE T Inverse(T s) {
|
||||
return 1.0 / s;
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,170 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Operator definition for update_loss_scaling: validates the presence of all
// inputs/outputs and fixes the shapes of the scalar bookkeeping outputs.
class UpdateLossScalingOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // X/Out: the (possibly many) gradient tensors flowing through the op.
    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasInput("InGoodSteps"), "Input", "InGoodSteps",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps",
                   "update_loss_scaling");
    // Outputs mirror the inputs' shapes; all counters/scaling are scalars.
    auto x_dims = ctx->GetInputsDim("X");
    ctx->SetOutputsDim("Out", x_dims);
    ctx->SetOutputDim("LossScaling", {1});
    ctx->SetOutputDim("OutGoodSteps", {1});
    ctx->SetOutputDim("OutBadSteps", {1});
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Kernel dtype follows the loss-scaling value, not the gradients,
    // since X may be empty or heterogeneous.
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "PrevLossScaling"),
        ctx.device_context());
  }
};
|
||||
|
||||
// Proto maker: declares inputs/outputs/attributes of update_loss_scaling and
// validates the incr/decr ratio attributes.
class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensors) The input tensors of update_loss_scaling operator.")
        .AsDuplicable();
    AddInput("FoundInfinite",
             "(Tensor) 1-dim tensor, contains a bool scalar, which indicates "
             "whether there is any infinite gradient.");
    AddInput("PrevLossScaling",
             "(Tensor) 1-dim tensor, previous loss scaling.");
    AddInput("InGoodSteps",
             "(Tensor) 1-dim tensor, accumulates good steps in which all "
             "gradients are finite.");
    AddInput("InBadSteps",
             "(Tensor) 1-dim tensor, accumulates bad steps in which some "
             "gradients are infinite.");
    AddOutput("Out",
              "(Tensors) The output tensor of update_loss_scaling operator.")
        .AsDuplicable();
    AddOutput("LossScaling", "(Tensor) 1-dim tensor, updated loss scaling.");
    // Fixed typo: "pdated good steps" -> "updated good steps".
    AddOutput("OutGoodSteps", "(Tensor) 1-dim tensor, updated good steps.");
    AddOutput("OutBadSteps", "(Tensor) 1-dim tensor, updated bad steps.");
    AddAttr<int>("incr_every_n_steps",
                 "A value represents increasing loss scaling every n "
                 "consecutive steps with finite gradients.");
    AddAttr<int>("decr_every_n_nan_or_inf",
                 "A value represents decreasing loss scaling every n "
                 "accumulated steps with nan or inf gradients.");
    AddAttr<float>("incr_ratio",
                   "The multiplier to use when increasing the loss scaling.")
        .AddCustomChecker([](float incr_ratio) {
          PADDLE_ENFORCE_EQ(incr_ratio > 1.0f, true,
                            platform::errors::InvalidArgument(
                                "'incr_ratio' should be greater than 1, but "
                                "the received is %f",
                                incr_ratio));
        });
    AddAttr<float>(
        "decr_ratio",
        "The less-than-one-multiplier to use when decreasing loss scaling.")
        .AddCustomChecker([](float decr_ratio) {
          // Fixed: error message wrongly referred to 'incr_ratio'.
          PADDLE_ENFORCE_EQ(decr_ratio > 0.0f && decr_ratio < 1.0f, true,
                            platform::errors::InvalidArgument(
                                "'decr_ratio' should be between 0 and 1, but "
                                "the received is %f",
                                decr_ratio));
        });
    AddComment(R"DOC(
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Otherwise, loss scaling will decrease by decr_ratio after
decr_every_n_nan_or_inf steps and each step some gradients are infinite.

)DOC");
  }
};
|
||||
|
||||
// CPU specialization: the update is a scalar computation, so simply forward
// to the shared host/device Update routine (pointers are host-accessible).
template <typename T>
class UpdateLossScalingFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& ctx,
                  const bool* found_inf_data, const T* pre_loss_scaling_data,
                  const int* good_in_data, const int* bad_in_data,
                  const int incr_every_n_steps,
                  const int decr_every_n_nan_or_inf, const float incr_ratio,
                  const float decr_ratio, T* updated_loss_scaling_data,
                  int* good_out_data, int* bad_out_data) const {
    Update<T>(found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data,
              incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
              decr_ratio, updated_loss_scaling_data, good_out_data,
              bad_out_data);
  }
};
|
||||
|
||||
template <typename T>
|
||||
class LazyZeroInputs<platform::CPUDeviceContext, T> {
|
||||
public:
|
||||
void operator()(const platform::CPUDeviceContext& dev_ctx,
|
||||
const bool* found_inf_data,
|
||||
const std::vector<const framework::Tensor*>& xs,
|
||||
const std::vector<framework::Tensor*>& outs) const {
|
||||
if (*found_inf_data) {
|
||||
VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
|
||||
for (size_t i = 0; i < xs.size(); ++i) {
|
||||
auto* out = outs[i];
|
||||
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
|
||||
int num = out->numel();
|
||||
std::memset(out_data, 0, num * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

// update_loss_scaling has no gradient: register with empty grad-op makers.
REGISTER_OPERATOR(
    update_loss_scaling, ops::UpdateLossScalingOp,
    ops::UpdateLossScalingOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

// CPU kernels for float and double loss-scaling dtypes.
REGISTER_OP_CPU_KERNEL(update_loss_scaling,
                       ops::UpdateLossScalingKernel<CPU, float>,
                       ops::UpdateLossScalingKernel<CPU, double>);
|
@ -0,0 +1,84 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// CUDA kernel wrapper: runs the shared scalar Update routine on the device,
// where all the input/output pointers live in GPU memory.
template <typename T>
__global__ void GpuUpdateLossScaling(
    const bool* found_inf_data, const T* pre_loss_scaling_data,
    const int* good_in_data, const int* bad_in_data,
    const int incr_every_n_steps, const int decr_every_n_nan_or_inf,
    const float incr_ratio, const float decr_ratio,
    T* updated_loss_scaling_data, int* good_out_data, int* bad_out_data) {
  Update<T>(found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data,
            incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
            updated_loss_scaling_data, good_out_data, bad_out_data);
}
|
||||
|
||||
// CUDA specialization: the update is a scalar computation, so launch a
// single-thread kernel (<<<1, 1>>>) on the context's stream.
template <typename T>
class UpdateLossScalingFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const bool* found_inf_data, const T* pre_loss_scaling_data,
                  const int* good_in_data, const int* bad_in_data,
                  const int incr_every_n_steps,
                  const int decr_every_n_nan_or_inf, const float incr_ratio,
                  const float decr_ratio, T* updated_loss_scaling_data,
                  int* good_out_data, int* bad_out_data) const {
    GpuUpdateLossScaling<T><<<1, 1, 0, dev_ctx.stream()>>>(
        found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data,
        incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
        updated_loss_scaling_data, good_out_data, bad_out_data);
  }
};
|
||||
|
||||
template <typename T>
|
||||
class LazyZeroInputs<platform::CUDADeviceContext, T> {
|
||||
public:
|
||||
void operator()(const platform::CUDADeviceContext& dev_ctx,
|
||||
const bool* found_inf_data,
|
||||
const std::vector<const framework::Tensor*>& xs,
|
||||
const std::vector<framework::Tensor*>& outs) const {
|
||||
const auto gpu_place =
|
||||
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
|
||||
bool has_inf{false};
|
||||
memory::Copy(platform::CPUPlace(), &has_inf, gpu_place, found_inf_data,
|
||||
sizeof(bool), dev_ctx.stream());
|
||||
if (has_inf) {
|
||||
VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
|
||||
for (size_t i = 0; i < xs.size(); ++i) {
|
||||
auto* out = outs[i];
|
||||
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
|
||||
int num = out->numel();
|
||||
cudaMemset(out_data, 0, num * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
using GPU = paddle::platform::CUDADeviceContext;

// CUDA kernels for float and double loss-scaling dtypes.
REGISTER_OP_CUDA_KERNEL(update_loss_scaling,
                        ops::UpdateLossScalingKernel<GPU, float>,
                        ops::UpdateLossScalingKernel<GPU, double>);
|
@ -0,0 +1,123 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
#include "paddle/fluid/platform/errors.h"
|
||||
#include "paddle/fluid/platform/hostdevice.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
HOSTDEVICE void Update(const bool* found_inf_data,
|
||||
const T* pre_loss_scaling_data, const int* good_in_data,
|
||||
const int* bad_in_data, const int incr_every_n_steps,
|
||||
const int decr_every_n_nan_or_inf,
|
||||
const float incr_ratio, const float decr_ratio,
|
||||
T* updated_loss_scaling_data, int* good_out_data,
|
||||
int* bad_out_data) {
|
||||
if (*found_inf_data) {
|
||||
*good_out_data = 0;
|
||||
*bad_out_data = *bad_in_data + 1;
|
||||
if (*bad_out_data == decr_every_n_nan_or_inf) {
|
||||
T new_loss_scaling = *pre_loss_scaling_data * decr_ratio;
|
||||
*updated_loss_scaling_data = new_loss_scaling < static_cast<T>(1)
|
||||
? static_cast<T>(1)
|
||||
: new_loss_scaling;
|
||||
*bad_out_data = 0;
|
||||
}
|
||||
} else {
|
||||
*bad_out_data = 0;
|
||||
*good_out_data = *good_in_data + 1;
|
||||
if (*good_out_data == incr_every_n_steps) {
|
||||
T new_loss_scaling = *pre_loss_scaling_data * incr_ratio;
|
||||
*updated_loss_scaling_data = std::isfinite(new_loss_scaling)
|
||||
? new_loss_scaling
|
||||
: *pre_loss_scaling_data;
|
||||
*good_out_data = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Device-dispatched functor that applies Update() to the loss-scaling state;
// specialized per DeviceContext in the .cc (CPU) and .cu (CUDA) files.
template <typename DeviceContext, typename T>
class UpdateLossScalingFunctor {
 public:
  void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data,
                  const T* pre_loss_scaling_data, const int* good_in_data,
                  const int* bad_in_data, const int incr_every_n_steps,
                  const int decr_every_n_nan_or_inf, const float incr_ratio,
                  const float decr_ratio, T* updated_loss_scaling_data,
                  int* good_out_data, int* bad_out_data) const;
};
|
||||
|
||||
// Device-dispatched functor that zeroes the output tensors when infinite
// gradients were found ("lazy" because it is a no-op otherwise); specialized
// per DeviceContext in the .cc (CPU) and .cu (CUDA) files.
template <typename DeviceContext, typename T>
class LazyZeroInputs {
 public:
  void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data,
                  const std::vector<const framework::Tensor*>& xs,
                  const std::vector<framework::Tensor*>& outs) const;
};
|
||||
|
||||
// Kernel: unpacks the op's tensors/attributes, runs the loss-scaling state
// update, then zeroes the gradient outputs if infinite values were found.
template <typename DeviceContext, typename T>
class UpdateLossScalingKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto xs = ctx.MultiInput<framework::Tensor>("X");
    const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
    const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
    const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
    const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
    auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
    auto* bad_out = ctx.Output<Tensor>("OutBadSteps");

    PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
                      platform::errors::InvalidArgument(
                          "FoundInfinite must has only one element."));

    // Raw pointers may be device memory; they are only dereferenced inside
    // the device-specific functors below.
    const bool* found_inf_data = found_inf->data<bool>();
    const T* pre_loss_scaling_data = pre_loss_scaling->data<T>();
    const int* good_in_data = good_in->data<int>();
    const int* bad_in_data = bad_in->data<int>();

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    T* updated_loss_scaling_data =
        updated_loss_scaling->mutable_data<T>(dev_ctx.GetPlace());
    int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace());
    int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace());

    const int incr_every_n_steps = ctx.Attr<int>("incr_every_n_steps");
    const int decr_every_n_nan_or_inf =
        ctx.Attr<int>("decr_every_n_nan_or_inf");
    const float incr_ratio = ctx.Attr<float>("incr_ratio");
    const float decr_ratio = ctx.Attr<float>("decr_ratio");
    // Step 1: advance the loss-scaling state machine.
    UpdateLossScalingFunctor<DeviceContext, T>{}(
        dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data,
        bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
        decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data);
    // Step 2: zero the gradient outputs if infinite values were detected.
    LazyZeroInputs<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
|
||||
from paddle.fluid.layer_helper import LayerHelper
|
||||
from paddle.fluid.framework import Variable
|
||||
|
||||
__all__ = ['check_finite_and_unscale', 'update_loss_scaling']
|
||||
|
||||
|
||||
def check_finite_and_unscale(x, scale, name=None):
    """
    Check if input X contains all finite data, if yes, scale it by input Scale.

    $$Out = X / scale$$

    If any tensor in X contains Inf or Nan, the Out will generate a indicator.
    FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
    Out should not be used, and its data may not be deterministic.
    Otherwise, FoundInfinite will be 0 (False).

    Args:
        x(list|tuple): The input tensors of check_finite_and_unscale operator.
        scale(Variable): The loss scaling factor used to unscale the tensors in x.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property.

    Returns:
        tuple: A tuple of (x, found_inf). ``x`` holds the unscaled tensors
        (updated in place) and ``found_inf`` is a boolean Variable indicating
        whether any Inf or Nan value was found in ``x``.
    """
    check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in x:
        check_variable_and_dtype(e, "x", ['float32', 'float64'],
                                 'check_finite_and_unscale')
    # Validate scale up front so a wrong-typed argument fails with a clear
    # message here instead of an opaque error inside append_op.
    check_type(scale, 'scale', Variable, 'check_finite_and_unscale')

    helper = LayerHelper("check_finite_and_unscale", **locals())
    found_inf = helper.create_variable_for_type_inference(dtype='bool')

    # Out reuses the X variables: the unscaling is performed in place.
    inputs = {'X': x, 'Scale': scale}
    outputs = {'Out': x, 'FoundInfinite': found_inf}
    helper.append_op(
        type='check_finite_and_unscale', inputs=inputs, outputs=outputs)

    return x, found_inf
|
||||
|
||||
|
||||
def update_loss_scaling(x,
                        found_inf,
                        prev_loss_scaling,
                        num_good_steps,
                        num_bad_steps,
                        incr_every_n_steps,
                        decr_every_n_nan_or_inf,
                        incr_ratio,
                        decr_ratio,
                        name=None):
    """
    Update loss scaling according to overall gradients. If all gradients is
    finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
    Otherwise, loss scaling will decrease by decr_ratio after
    decr_every_n_nan_or_inf steps and each step some gradients are infinite.

    Args:
        x(list|tuple): The input tensors of update_loss_scaling operator.
        found_inf (Variable): A boolean variable indicates whether
                              there is any infinite gradient.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable accumulates good steps in which
                                   all gradients are finite.
        num_bad_steps (Variable): A variable accumulates bad steps in which
                                  some gradients are infinite.
        incr_every_n_steps (int): A variable represents increasing loss
                                  scaling every n consecutive steps with
                                  finite gradients.
        decr_every_n_nan_or_inf (int): A variable represents decreasing
                                       loss scaling every n accumulated
                                       steps with nan or inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           loss scaling.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property.

    Returns:
        list|tuple: The tensors ``x``, updated in place (zeroed when
        ``found_inf`` is True).
    """
    # Validate every Variable input early, mirroring the checks done on x and
    # prev_loss_scaling, so type errors surface here rather than in append_op.
    check_type(found_inf, 'found_inf', Variable, 'update_loss_scaling')
    check_type(num_good_steps, 'num_good_steps', Variable,
               'update_loss_scaling')
    check_type(num_bad_steps, 'num_bad_steps', Variable, 'update_loss_scaling')
    check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
                             ['float32', 'float64'], "update_loss_scaling")
    check_type(x, 'x', (tuple, list), 'update_loss_scaling')
    for e in x:
        check_variable_and_dtype(e, "x", ['float32', 'float64'],
                                 'update_loss_scaling')
        assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."

    helper = LayerHelper("update_loss_scaling", **locals())

    inputs = {
        'X': x,
        'FoundInfinite': found_inf,
        'PrevLossScaling': prev_loss_scaling,
        'InGoodSteps': num_good_steps,
        'InBadSteps': num_bad_steps
    }

    # Outputs alias the corresponding inputs: the op updates the loss scaling
    # state and (if found_inf) zeroes the gradients in place.
    outputs = {
        'Out': x,
        'LossScaling': prev_loss_scaling,
        'OutGoodSteps': num_good_steps,
        'OutBadSteps': num_bad_steps
    }

    attrs = {
        'incr_every_n_steps': incr_every_n_steps,
        'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf,
        'incr_ratio': incr_ratio,
        'decr_ratio': decr_ratio,
    }

    helper.append_op(
        type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs)

    return x
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue