[NPU] Support npu kernel for amp_check_finite_and_unscale_npu op (#31457)
* Support npu kernel for amp_check_finite_and_unscale_npu op * support EnforceNotMet exception * fix exception bug * modify python unittest * precommit * update c++ unittest * fix review * fix reviewrevert-31562-mean
parent
d746197398
commit
3bf8a34c69
@ -0,0 +1,119 @@
|
||||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
|
||||
#include "paddle/fluid/operators/npu_op_runner.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const {
|
||||
const auto xs = ctx.MultiInput<framework::Tensor>("X");
|
||||
const auto* scale = ctx.Input<framework::Tensor>("Scale");
|
||||
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
|
||||
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
|
||||
|
||||
found_inf->mutable_data<bool>(ctx.GetPlace());
|
||||
|
||||
bool found_inf_data = false;
|
||||
|
||||
auto stream =
|
||||
ctx.template device_context<paddle::platform::NPUDeviceContext>()
|
||||
.stream();
|
||||
|
||||
// step1: inverse scale(RealDiv)
|
||||
Tensor const_tensor;
|
||||
const_tensor.mutable_data<T>({1}, ctx.GetPlace());
|
||||
TensorFromVector(std::vector<T>{static_cast<T>(1.0)}, ctx.device_context(),
|
||||
&const_tensor);
|
||||
|
||||
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
|
||||
|
||||
// Inverse(1.0/scale)
|
||||
Tensor* tmp_inverse_out = const_cast<Tensor*>(scale);
|
||||
Tensor inverse_out(scale->type());
|
||||
inverse_out.Resize(scale->dims());
|
||||
inverse_out.mutable_data<T>(ctx.GetPlace());
|
||||
auto runner_inverse =
|
||||
NpuOpRunner("Div", {const_tensor, *scale}, {inverse_out}, {});
|
||||
runner_inverse.Run(stream);
|
||||
tmp_inverse_out = &inverse_out;
|
||||
|
||||
size_t x_size = xs.size();
|
||||
for (size_t i = 0; i < x_size; ++i) {
|
||||
found_inf_data = true;
|
||||
const auto* x = xs[i];
|
||||
auto* out = outs[i];
|
||||
out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
// step2: CheckNumerics
|
||||
// CheckNumerics runs on the Ascend AI CPU, which delivers poor
|
||||
// performance.
|
||||
Tensor check_xout(x->type());
|
||||
check_xout.Resize(x->dims());
|
||||
check_xout.mutable_data<T>(ctx.GetPlace());
|
||||
try {
|
||||
auto runner_checknumerics =
|
||||
NpuOpRunner("CheckNumerics", {*x}, {check_xout},
|
||||
{{"message", std::string("check_nan_and_inf")}});
|
||||
runner_checknumerics.Run(stream);
|
||||
} catch (platform::EnforceNotMet& exception) {
|
||||
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
|
||||
found_inf_data = true;
|
||||
} catch (...) {
|
||||
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
|
||||
found_inf_data = true;
|
||||
}
|
||||
|
||||
if (!found_inf_data) {
|
||||
// MatMul
|
||||
auto runner_matmul =
|
||||
NpuOpRunner("Mul", {*x, *tmp_inverse_out}, {*out}, {});
|
||||
runner_matmul.Run(stream);
|
||||
} else {
|
||||
// ZerosLike
|
||||
auto runner_zeroslike = NpuOpRunner("ZerosLike", {*x}, {*out}, {});
|
||||
runner_zeroslike.Run(stream);
|
||||
} // end if
|
||||
} // end for
|
||||
|
||||
// set found_inf to true
|
||||
if (found_inf_data) {
|
||||
Tensor found_inf_tensor;
|
||||
found_inf_tensor.Resize({1});
|
||||
bool* is_found_inf =
|
||||
found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
|
||||
*is_found_inf = true;
|
||||
framework::TensorCopySync(found_inf_tensor, ctx.GetPlace(), found_inf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
namespace plat = paddle::platform;
|
||||
REGISTER_OP_NPU_KERNEL(check_finite_and_unscale,
|
||||
ops::CheckFiniteAndUnscaleNPUKernel<float>,
|
||||
ops::CheckFiniteAndUnscaleNPUKernel<plat::float16>);
|
||||
@ -0,0 +1,131 @@
|
||||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace f = paddle::framework;
|
||||
namespace p = paddle::platform;
|
||||
namespace m = paddle::operators::math;
|
||||
|
||||
using Tensor = paddle::framework::Tensor;
|
||||
|
||||
USE_OP(check_finite_and_unscale);
|
||||
USE_OP_DEVICE_KERNEL(check_finite_and_unscale, NPU);
|
||||
|
||||
struct InputVars {
|
||||
std::string name;
|
||||
f::LoDTensor *tensor;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void Compare(f::Scope *scope, const p::DeviceContext &ctx) {
|
||||
const f::DDim dims = f::make_ddim({2, 2});
|
||||
auto place = ctx.GetPlace();
|
||||
|
||||
// init input
|
||||
std::vector<InputVars> input_names = {
|
||||
{"x", scope->Var("x")->GetMutable<f::LoDTensor>()},
|
||||
{"x1", scope->Var("x1")->GetMutable<f::LoDTensor>()}};
|
||||
|
||||
auto *scale = scope->Var("scale")->GetMutable<f::LoDTensor>();
|
||||
|
||||
// init output
|
||||
auto *out = scope->Var("out")->GetMutable<f::LoDTensor>();
|
||||
auto *out1 = scope->Var("out1")->GetMutable<f::LoDTensor>();
|
||||
auto *found_inf = scope->Var("found_inf")->GetMutable<f::LoDTensor>();
|
||||
|
||||
// Initialize input data
|
||||
const int num_inputs = input_names.size();
|
||||
size_t numel = static_cast<size_t>(f::product(dims));
|
||||
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
std::vector<T> init_xs;
|
||||
for (size_t j = 0; j < numel; ++j) {
|
||||
if (j == 0) {
|
||||
init_xs.push_back(static_cast<T>(NAN));
|
||||
} else {
|
||||
init_xs.push_back(static_cast<T>(j + 1));
|
||||
}
|
||||
}
|
||||
f::TensorFromVector(init_xs, ctx, input_names[i].tensor);
|
||||
input_names[i].tensor->Resize(dims);
|
||||
}
|
||||
|
||||
f::TensorFromVector(std::vector<T>{static_cast<T>(0.5)}, ctx, scale);
|
||||
|
||||
ctx.Wait();
|
||||
|
||||
// run
|
||||
f::AttributeMap attrs;
|
||||
auto op = f::OpRegistry::CreateOp(
|
||||
"check_finite_and_unscale", {{"X", {"x", "x1"}}, {"Scale", {"scale"}}},
|
||||
{{"Out", {"out", "out1"}}, {"FoundInfinite", {"found_inf"}}}, attrs);
|
||||
op->Run(*scope, place);
|
||||
ctx.Wait();
|
||||
|
||||
// out0
|
||||
std::vector<T> out_vec;
|
||||
f::TensorToVector(*out, ctx, &out_vec);
|
||||
EXPECT_EQ(out_vec.size(), static_cast<size_t>(4));
|
||||
for (size_t j = 0; j < out_vec.size(); ++j) {
|
||||
VLOG(3) << "out_vec[" << j << "]:" << out_vec[j];
|
||||
}
|
||||
|
||||
ctx.Wait();
|
||||
|
||||
// out0
|
||||
std::vector<T> out1_vec;
|
||||
f::TensorToVector(*out1, ctx, &out1_vec);
|
||||
EXPECT_EQ(out1_vec.size(), static_cast<size_t>(4));
|
||||
for (size_t j = 0; j < out1_vec.size(); ++j) {
|
||||
VLOG(3) << "out1_vec[" << j << "]:" << out1_vec[j];
|
||||
}
|
||||
|
||||
ctx.Wait();
|
||||
|
||||
// out found_inf
|
||||
Tensor found_inf_tensor;
|
||||
found_inf_tensor.Resize({1});
|
||||
bool *is_finite_data =
|
||||
found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
|
||||
f::TensorCopy(*found_inf, place, &found_inf_tensor);
|
||||
EXPECT_FALSE(*is_finite_data);
|
||||
|
||||
ctx.Wait();
|
||||
}
|
||||
|
||||
TEST(check_finite_and_unscale, NPU_fp32) {
|
||||
f::Scope scope;
|
||||
p::NPUDeviceContext ctx(p::NPUPlace(0));
|
||||
Compare<float>(&scope, ctx);
|
||||
}
|
||||
|
||||
TEST(check_finite_and_unscale, NPU_fp16) {
|
||||
f::Scope scope;
|
||||
p::NPUDeviceContext ctx(p::NPUPlace(0));
|
||||
Compare<p::float16>(&scope, ctx);
|
||||
}
|
||||
@ -0,0 +1,123 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest, skip_check_grad_ci
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
paddle.enable_static()
|
||||
|
||||
|
||||
@unittest.skipIf(not paddle.is_compiled_with_npu(),
|
||||
"core is not compiled with NPU")
|
||||
class TestCheckFiniteAndUnscaleOp(OpTest):
|
||||
def setUp(self):
|
||||
self.set_npu()
|
||||
self.op_type = "check_finite_and_unscale"
|
||||
self.place = paddle.NPUPlace(0)
|
||||
self.init_dtype()
|
||||
x = np.random.random((1024, 1024)).astype(self.dtype)
|
||||
scale = np.random.random((1)).astype(self.dtype)
|
||||
|
||||
self.inputs = {'X': [('x0', x)], 'Scale': scale}
|
||||
self.outputs = {
|
||||
'FoundInfinite': np.array([0]),
|
||||
'Out': [('out0', x / scale)],
|
||||
}
|
||||
|
||||
def set_npu(self):
|
||||
self.__class__.use_npu = True
|
||||
|
||||
def init_kernel_type(self):
|
||||
self.use_mkldnn = False
|
||||
|
||||
def init_dtype(self):
|
||||
self.dtype = np.float32
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output_with_place(self.place, check_dygraph=False)
|
||||
|
||||
|
||||
@unittest.skipIf(not paddle.is_compiled_with_npu(),
|
||||
"core is not compiled with NPU")
|
||||
class TestCheckFiniteAndUnscaleOpWithNan(OpTest):
|
||||
def setUp(self):
|
||||
self.set_npu()
|
||||
self.op_type = "check_finite_and_unscale"
|
||||
self.place = paddle.NPUPlace(0)
|
||||
self.init_dtype()
|
||||
x = np.random.random((1024, 1024)).astype(self.dtype)
|
||||
x[128][128] = np.nan
|
||||
scale = np.random.random((1)).astype(self.dtype)
|
||||
|
||||
self.inputs = {'X': [('x0', x)], 'Scale': scale}
|
||||
self.outputs = {
|
||||
'FoundInfinite': np.array([1]),
|
||||
'Out': [('out0', x)],
|
||||
}
|
||||
|
||||
def set_npu(self):
|
||||
self.__class__.use_npu = True
|
||||
|
||||
def init_kernel_type(self):
|
||||
self.use_mkldnn = False
|
||||
|
||||
def init_dtype(self):
|
||||
self.dtype = np.float32
|
||||
|
||||
def test_check_output(self):
|
||||
# When input contains nan, do not check the output,
|
||||
# since the output may be nondeterministic and will be discarded.
|
||||
self.check_output_with_place(
|
||||
self.place, check_dygraph=False, no_check_set=['Out'])
|
||||
|
||||
|
||||
@unittest.skipIf(not paddle.is_compiled_with_npu(),
|
||||
"core is not compiled with NPU")
|
||||
class TestCheckFiniteAndUnscaleOpWithInf(OpTest):
|
||||
def setUp(self):
|
||||
self.set_npu()
|
||||
self.op_type = "check_finite_and_unscale"
|
||||
self.place = paddle.NPUPlace(0)
|
||||
self.init_dtype()
|
||||
x = np.random.random((1024, 1024)).astype(self.dtype)
|
||||
x[128][128] = np.inf
|
||||
scale = np.random.random((1)).astype(self.dtype)
|
||||
|
||||
self.inputs = {'X': [('x0', x)], 'Scale': scale}
|
||||
self.outputs = {
|
||||
'FoundInfinite': np.array([1]),
|
||||
'Out': [('out0', x)],
|
||||
}
|
||||
|
||||
def set_npu(self):
|
||||
self.__class__.use_npu = True
|
||||
|
||||
def init_kernel_type(self):
|
||||
self.use_mkldnn = False
|
||||
|
||||
def init_dtype(self):
|
||||
self.dtype = np.float32
|
||||
|
||||
def test_check_output(self):
|
||||
# When input contains inf, do not check the output,
|
||||
# since the output may be nondeterministic and will be discarded.
|
||||
self.check_output_with_place(
|
||||
self.place, check_dygraph=False, no_check_set=['Out'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in new issue