Feature/add amp_check_finite_and_scale op (#24875)
* add amp_check_finite_and_scale op, test=develop
* add cpu kernel, test=develop
* use bool, test=develop
* follow comments, test=develop
parent 576d68083e
commit 1e818158f5
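
For context, a minimal NumPy sketch of the contract this op implements (illustrative only; the function name is hypothetical and not part of this PR):

import numpy as np

def check_finite_and_scale_ref(xs, scale):
    # Returns (found_infinite, outs). If any input holds Inf/NaN, the flag
    # is 1 and Out is undefined by the op's contract; here the inputs are
    # returned unscaled purely for illustration.
    if any(not np.all(np.isfinite(x)) for x in xs):
        return 1, [x.copy() for x in xs]
    return 0, [x * scale[0] for x in xs]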
paddle/fluid/operators/amp/CMakeLists.txt
@@ -0,0 +1,2 @@
include(operators)
register_operators()
paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc
@@ -0,0 +1,103 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h"

#include <string>
#include <vector>

namespace paddle {
namespace operators {

class AmpCheckFiniteAndScaleOp : public framework::OperatorWithKernel {
 public:
  AmpCheckFiniteAndScaleOp(const std::string &type,
                           const framework::VariableNameMap &inputs,
                           const framework::VariableNameMap &outputs,
                           const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
                   "amp_check_finite_and_scale");
    OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
                   "amp_check_finite_and_scale");
    PADDLE_ENFORCE_EQ(
        ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
        platform::errors::InvalidArgument(
            "The input(X) and output(Out) should have the same size in "
            "Operator(amp_check_finite_and_scale), but size of input(X) is "
            "%d and size of output(Out) is %d.",
            ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
    auto x_dims = ctx->GetInputsDim("X");
    ctx->SetOutputsDim("Out", x_dims);
    ctx->SetOutputDim("FoundInfinite", {1});
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

class AmpCheckFiniteAndScaleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensors) The input tensors of amp_check_finite_and_scale operator.")
        .AsDuplicable();
    AddInput("Scale",
             "(Tensor) 1-dim tensor, the scale of amp_check_finite_and_scale "
             "operator.");
    AddOutput("Out",
              "(Tensors) The scaled output tensors of "
              "amp_check_finite_and_scale operator.")
        .AsDuplicable();
    AddOutput("FoundInfinite",
              "(Tensor) 1-dim tensor, contains an int scalar which indicates "
              "whether there is any Inf or NaN item in input X.");
    AddComment(R"DOC(
amp_check_finite_and_scale operator.
Check whether input X contains only finite data; if so, scale it by input Scale.

$$Out = X * scale$$

If any tensor in X contains Inf or NaN, FoundInfinite will be 1 (True) and
Out will not be scaled. In this case the data of Out should not be used, as
it may not be deterministic. Otherwise, FoundInfinite will be 0 (False).

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
    amp_check_finite_and_scale, ops::AmpCheckFiniteAndScaleOp,
    ops::AmpCheckFiniteAndScaleOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(
    amp_check_finite_and_scale,
    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CPUDeviceContext,
                                      float>,
    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CPUDeviceContext,
                                      double>);
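
Since this PR registers the raw operator without adding a Python wrapper, a usage sketch would wire it in through Block.append_op (assuming the 1.8-era fluid API; variable names are hypothetical):

import numpy as np
import paddle.fluid as fluid

main = fluid.Program()
with fluid.program_guard(main):
    x = fluid.data(name='x', shape=[3], dtype='float32')
    scale = fluid.data(name='scale', shape=[1], dtype='float32')
    block = main.global_block()
    out = block.create_var(name='out', dtype='float32')
    found_inf = block.create_var(name='found_inf', dtype='bool')
    block.append_op(
        type='amp_check_finite_and_scale',
        inputs={'X': [x], 'Scale': scale},
        outputs={'Out': [out], 'FoundInfinite': found_inf})

exe = fluid.Executor(fluid.CPUPlace())
out_np, flag_np = exe.run(
    main,
    feed={'x': np.array([1., 2., 3.], 'float32'),
          'scale': np.array([4.], 'float32')},
    fetch_list=[out, found_inf])
# expected: out_np == [4., 8., 12.], flag_np holds False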
paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu
@@ -0,0 +1,75 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda.h>
#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void AmpCheckFiniteAndScale(const T* in, const T* scale, int num,
                                       int* found_inf, T* out) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;

  if (idx < num) {
    if (!std::isfinite(in[idx])) {
      *found_inf = 1;
    }
    // *found_inf is read and written by concurrent threads without
    // synchronization, so Out is nondeterministic once a non-finite value
    // is present; the op's contract says Out must then be discarded.
    out[idx] = *found_inf ? in[idx] : in[idx] * scale[0];
  }
}

template <typename T>
class AmpCheckFiniteAndScaleKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    const auto xs = ctx.MultiInput<framework::Tensor>("X");
    const auto* scale = ctx.Input<framework::Tensor>("Scale");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");

    const T* scale_data = scale->data<T>();
    int* found_inf_data = found_inf->mutable_data<int>(dev_ctx.GetPlace());
    // Reset the found-infinite flag to 0 before any kernel launches;
    // sizeof(int) matches the flag's element type.
    cudaMemset(found_inf_data, 0, found_inf->numel() * sizeof(int));

    for (size_t i = 0; i < xs.size(); ++i) {
      const auto* x = xs[i];
      auto* out = outs[i];
      const T* x_data = x->data<T>();
      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());

      int num = x->numel();
      int block = 512;
      int grid = (num + block - 1) / block;
      VLOG(3) << "launch kernel";
      AmpCheckFiniteAndScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, scale_data, num, found_inf_data, out_data);
      VLOG(3) << "finish kernel";
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    amp_check_finite_and_scale,
    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CUDADeviceContext,
                                      float>,
    ops::AmpCheckFiniteAndScaleKernel<paddle::platform::CUDADeviceContext,
                                      double>);
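
As a quick sanity check on the launch configuration above, the grid size is the ceiling of num/block, one thread per element:

num, block = 1024 * 1024, 512
grid = (num + block - 1) // block  # 2048 thread blocks for the test's 1024x1024 input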
paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h
@@ -0,0 +1,66 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/isfinite_op.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class AmpCheckFiniteAndScaleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    const auto xs = ctx.MultiInput<framework::Tensor>("X");
    const auto* scale = ctx.Input<framework::Tensor>("Scale");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");

    const T* scale_data = scale->data<T>();
    bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());

    *found_inf_data = false;
    framework::Tensor is_finite =
        ctx.AllocateTmpTensor<bool, DeviceContext>({1}, dev_ctx);
    bool* is_finite_data = is_finite.template data<bool>();

    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
    for (size_t i = 0; i < xs.size(); ++i) {
      const auto* x = xs[i];
      auto* out = outs[i];
      out->mutable_data<T>(dev_ctx.GetPlace());
      if (!(*found_inf_data)) {
        framework::TensorIsfinite(*x, &is_finite);
        if (*is_finite_data) {
          auto eigen_out = framework::EigenVector<T>::Flatten(*out);
          auto eigen_in = framework::EigenVector<T>::Flatten(*x);
          eigen_out.device(dev) = (*scale_data) * eigen_in;
        } else {
          // Stop at the first non-finite tensor; this and any remaining
          // outputs are left unwritten, matching the contract that Out is
          // undefined when FoundInfinite is set.
          *found_inf_data = true;
          break;
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
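
One subtlety in the CPU kernel above: the loop exits at the first non-finite tensor, so later outputs are never written. A NumPy sketch of that behavior (illustrative; the function name is hypothetical):

import numpy as np

def cpu_kernel_ref(xs, scale):
    outs = [np.empty_like(x) for x in xs]
    found_inf = False
    for x, out in zip(xs, outs):
        if not np.all(np.isfinite(x)):
            found_inf = True
            break  # remaining outs stay uninitialized, as in the kernel
        out[...] = x * scale[0]
    return found_inf, outs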
python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py
@@ -0,0 +1,88 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid


class TestAmpCheckFiniteAndScaleOp(OpTest):
    def setUp(self):
        self.op_type = "amp_check_finite_and_scale"
        self.init_dtype()
        x = np.random.random((1024, 1024)).astype(self.dtype)
        scale = np.random.random((1)).astype(self.dtype)

        self.inputs = {'X': [('x0', x)], 'Scale': scale}
        self.outputs = {
            'FoundInfinite': np.array([0]),
            'Out': [('out0', x * scale)],
        }

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()


class TestAmpCheckFiniteAndScaleOpWithNan(OpTest):
    def setUp(self):
        self.op_type = "amp_check_finite_and_scale"
        self.init_dtype()
        x = np.random.random((1024, 1024)).astype(self.dtype)
        x[128][128] = np.nan
        scale = np.random.random((1)).astype(self.dtype)

        self.inputs = {'X': [('x0', x)], 'Scale': scale}
        self.outputs = {
            'FoundInfinite': np.array([1]),
            'Out': [('out0', x)],
        }

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        # When the input contains nan, do not check the output,
        # since the output may be nondeterministic and will be discarded.
        self.check_output(no_check_set=['Out'])


class TestAmpCheckFiniteAndScaleOpWithInf(OpTest):
    def setUp(self):
        self.op_type = "amp_check_finite_and_scale"
        self.init_dtype()
        x = np.random.random((1024, 1024)).astype(self.dtype)
        x[128][128] = np.inf
        scale = np.random.random((1)).astype(self.dtype)

        self.inputs = {'X': [('x0', x)], 'Scale': scale}
        self.outputs = {
            'FoundInfinite': np.array([1]),
            'Out': [('out0', x)],
        }

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        # When the input contains inf, do not check the output,
        # since the output may be nondeterministic and will be discarded.
        self.check_output(no_check_set=['Out'])


if __name__ == '__main__':
    unittest.main()
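
To exercise the new tests locally, an invocation along these lines is typical (the exact command depends on the Paddle build setup):

python python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py

or, from a CMake build tree,

ctest -R test_amp_check_finite_and_scale_op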