Rewrite check nan inf tools (#21076)
parent
019147eb8b
commit
8a0f611b64
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
// assert false when meets NAN or inf
|
||||
void CheckVarHasNanOrInf(const std::string& op_type,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
const platform::Place& place);
|
||||
|
||||
void CheckOpHasNanOrInf(const framework::OperatorBase& op,
|
||||
const framework::Scope& scope,
|
||||
const platform::Place& place);
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,189 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/nan_inf_utils.h"
|
||||
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
static std::once_flag init_multi_gpu_op_var_map_flag;
|
||||
|
||||
// lazy init
|
||||
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>&
|
||||
multi_op_var2gpu_str() {
|
||||
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>
|
||||
_multi_op_var2gpu_str;
|
||||
return _multi_op_var2gpu_str;
|
||||
}
|
||||
|
||||
static std::vector<std::mutex>& multi_op_var2gpu_str_mutex() {
|
||||
static std::vector<std::mutex> _multi_op_var2gpu_str_mutex;
|
||||
return _multi_op_var2gpu_str_mutex;
|
||||
}
|
||||
|
||||
static void InitMultiGPUOpVarMap() {
|
||||
int dev_count = platform::GetCUDADeviceCount();
|
||||
PADDLE_ENFORCE_GT(dev_count, 0,
|
||||
platform::errors::NotFound(
|
||||
"cuda device must > 0, now dev_count=%d", dev_count));
|
||||
|
||||
// https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex
|
||||
std::vector<std::unordered_map<std::string, memory::AllocationPtr>> tmp_multi(
|
||||
dev_count);
|
||||
std::vector<std::mutex> tmp_multi_mutex(dev_count);
|
||||
|
||||
multi_op_var2gpu_str().swap(tmp_multi);
|
||||
multi_op_var2gpu_str_mutex().swap(tmp_multi_mutex);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void PrintNanInfKernel(const T* value,
|
||||
const size_t numel,
|
||||
int print_num,
|
||||
char* debug_info) {
|
||||
const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
__shared__ unsigned int nan_count, inf_count, num_count;
|
||||
if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
|
||||
__syncthreads;
|
||||
|
||||
for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
|
||||
unsigned int count = 0;
|
||||
if (isnan(value[i])) {
|
||||
count = atomicAdd(&nan_count, 1);
|
||||
} else if (isinf(value[i])) {
|
||||
count = atomicAdd(&inf_count, 1);
|
||||
} else {
|
||||
count = atomicAdd(&num_count, 1);
|
||||
}
|
||||
// for cuda, print in every block
|
||||
if (count < print_num) {
|
||||
printf("numel:%lu idx:%lu value:%f\n", static_cast<uint64_t>(numel),
|
||||
static_cast<uint64_t>(i), static_cast<float>(value[i]));
|
||||
}
|
||||
}
|
||||
__syncthreads;
|
||||
|
||||
if (true && threadIdx.x == 0) {
|
||||
printf("In block %d, there has %u,%u,%u nan,inf,num\n", blockIdx.x,
|
||||
nan_count, inf_count, num_count);
|
||||
PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info);
|
||||
}
|
||||
}
|
||||
|
||||
// Resnet 2gpus speed test, no check 270 images/s, this check 229 images/s
|
||||
template <typename T>
|
||||
__global__ void CheckNanInfKernel(const T* value, const size_t numel,
|
||||
int print_num, char* debug_info) {
|
||||
/// step 1, judge wheater has nan or inf
|
||||
__shared__ volatile int has_nan_inf;
|
||||
if (threadIdx.x == 0) has_nan_inf = false;
|
||||
__syncthreads();
|
||||
|
||||
const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
T sum = static_cast<T>(0.0);
|
||||
// Todo(wangxi). simd speed up
|
||||
for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
|
||||
sum += (value[i] - value[i]);
|
||||
}
|
||||
|
||||
if (isnan(sum) || isinf(sum)) has_nan_inf = true;
|
||||
__syncthreads();
|
||||
|
||||
/// Note. different blocks may behave differently
|
||||
if (!has_nan_inf) return;
|
||||
|
||||
PrintNanInfKernel(value, numel, print_num, debug_info);
|
||||
}
|
||||
|
||||
template <>
|
||||
template <typename T>
|
||||
void TensorCheckerVisitor<platform::CUDADeviceContext>::apply(
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type*) const {
|
||||
int print_num = 3;
|
||||
|
||||
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
|
||||
platform::DeviceContextPool::Instance().Get(tensor_.place()));
|
||||
int dev_id = boost::get<platform::CUDAPlace>(tensor_.place()).device;
|
||||
PADDLE_ENFORCE_EQ(
|
||||
(dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()), true,
|
||||
platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
|
||||
multi_op_var2gpu_str_mutex().size()));
|
||||
|
||||
std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ + "]";
|
||||
char* gpu_str_ptr = NULL;
|
||||
|
||||
{
|
||||
auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id);
|
||||
auto& op_var2gpu_str = multi_op_var2gpu_str().at(dev_id);
|
||||
|
||||
std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);
|
||||
if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert
|
||||
auto gpu_str_tensor =
|
||||
paddle::memory::Alloc(*dev_ctx, op_var.length() + 1);
|
||||
gpu_str_ptr = reinterpret_cast<char*>(gpu_str_tensor->ptr());
|
||||
|
||||
op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor));
|
||||
|
||||
auto iter = op_var2gpu_str.find(op_var);
|
||||
PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true,
|
||||
platform::errors::PreconditionNotMet(
|
||||
"op_var=%s should successed insert into "
|
||||
"op_var2gpu_str, but now failed",
|
||||
op_var));
|
||||
|
||||
PADDLE_ENFORCE_CUDA_SUCCESS(
|
||||
cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1,
|
||||
cudaMemcpyHostToDevice, dev_ctx->stream()),
|
||||
platform::errors::External(
|
||||
"Async cudaMemcpy op_var info to gpu failed."));
|
||||
} else { // get
|
||||
auto iter = op_var2gpu_str.find(op_var);
|
||||
PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true,
|
||||
platform::errors::PreconditionNotMet(
|
||||
"op_var=%s should be in the op_var2gpu_str, but "
|
||||
"now can't find it",
|
||||
op_var));
|
||||
gpu_str_ptr = reinterpret_cast<char*>(iter->second->ptr());
|
||||
}
|
||||
}
|
||||
|
||||
const size_t threads = 1024;
|
||||
size_t blocks = std::min(128ul, (tensor_.numel() + threads - 1) / threads);
|
||||
CheckNanInfKernel<<<blocks, threads, 0, dev_ctx->stream()>>>(
|
||||
tensor_.data<T>(), tensor_.numel(), print_num, gpu_str_ptr);
|
||||
}
|
||||
|
||||
template <>
|
||||
void tensor_check<platform::CUDADeviceContext>(const std::string& op_type,
|
||||
const std::string& var_name,
|
||||
const framework::Tensor& tensor,
|
||||
const platform::Place& place) {
|
||||
std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);
|
||||
|
||||
TensorCheckerVisitor<platform::CUDADeviceContext> vistor(op_type, var_name,
|
||||
tensor, place);
|
||||
VisitDataType(tensor.type(), vistor);
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
template <typename DeviceContext>
|
||||
struct TensorCheckerVisitor {
|
||||
TensorCheckerVisitor(const std::string& op_type, const std::string& var_name,
|
||||
const framework::Tensor& tensor,
|
||||
const platform::Place& place)
|
||||
: op_type_(op_type),
|
||||
var_name_(var_name),
|
||||
tensor_(tensor),
|
||||
place_(place) {}
|
||||
|
||||
template <typename T>
|
||||
void apply(
|
||||
typename std::enable_if<std::is_integral<T>::value>::type* = 0) const {
|
||||
VLOG(10) << var_name_ << " need not to check, it's type is not float point";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void apply(typename std::enable_if<std::is_floating_point<T>::value>::type* =
|
||||
0) const;
|
||||
|
||||
std::string op_type_;
|
||||
std::string var_name_;
|
||||
const framework::Tensor& tensor_;
|
||||
const platform::Place& place_;
|
||||
};
|
||||
|
||||
template <typename DeviceContext>
|
||||
void tensor_check(const std::string& op_type, const std::string& var_name,
|
||||
const framework::Tensor& tensor,
|
||||
const platform::Place& place);
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import unicode_literals
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
os.environ[str("FLAGS_check_nan_inf")] = str("1")
|
||||
os.environ[str("GLOG_vmodule")] = str("nan_inf_utils_detail=10")
|
||||
|
||||
import paddle.fluid.core as core
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import paddle.compat as cpt
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def generator():
|
||||
batch_size = 5
|
||||
for i in range(5):
|
||||
curr_train_x = np.random.randint(
|
||||
batch_size, size=(batch_size, 3)).astype("float32")
|
||||
if i >= 2:
|
||||
curr_train_x[0, :] = np.nan
|
||||
curr_train_x[-1, :] = np.inf
|
||||
res = []
|
||||
for i in range(batch_size):
|
||||
y = i % 3
|
||||
res.append([y])
|
||||
y_label = np.array(res).astype('int64')
|
||||
yield [curr_train_x, y_label]
|
||||
|
||||
|
||||
def net():
|
||||
x = fluid.layers.data(name="x", shape=[3], dtype='float32')
|
||||
y = fluid.layers.data(name="y", shape=[1], dtype='int64')
|
||||
|
||||
# test int64 value
|
||||
zero = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
|
||||
|
||||
# test float16 value
|
||||
fp16_zero = fluid.layers.cast(zero, dtype='float16')
|
||||
|
||||
y = y + zero
|
||||
|
||||
hidden = x
|
||||
|
||||
for i in range(2):
|
||||
hidden = fluid.layers.fc(input=hidden, size=400, act="sigmoid")
|
||||
|
||||
hidden = fluid.layers.fc(input=hidden, size=3, act=None)
|
||||
cost, y_predict = fluid.layers.softmax_with_cross_entropy(
|
||||
hidden, y, return_softmax=True)
|
||||
acc_top1 = fluid.layers.accuracy(input=y_predict, label=y, k=1)
|
||||
avg_cost = fluid.layers.mean(cost)
|
||||
|
||||
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.05)
|
||||
sgd_optimizer.minimize(avg_cost)
|
||||
return y_predict, avg_cost, acc_top1
|
||||
|
||||
|
||||
def check(use_cuda):
|
||||
main = fluid.Program()
|
||||
startup = fluid.Program()
|
||||
scope = fluid.core.Scope()
|
||||
|
||||
with fluid.scope_guard(scope):
|
||||
with fluid.program_guard(main, startup):
|
||||
y_predict, avg_cost, acc_top1 = net()
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup)
|
||||
|
||||
step = 0.0
|
||||
for train_data, y_label in generator():
|
||||
outs = exe.run(
|
||||
main,
|
||||
feed={'x': train_data,
|
||||
'y': y_label},
|
||||
fetch_list=[y_predict.name, avg_cost.name, acc_top1.name])
|
||||
step += 1
|
||||
print('iter={:.0f},cost={},acc1={}'.format(step, outs[1][0],
|
||||
outs[2][0]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if core.is_compiled_with_cuda():
|
||||
try:
|
||||
check(use_cuda=True)
|
||||
assert False
|
||||
except Exception as e:
|
||||
print(e)
|
||||
assert type(e) == core.EnforceNotMet
|
||||
try:
|
||||
check(use_cuda=False)
|
||||
assert False
|
||||
except Exception as e:
|
||||
print(e)
|
||||
assert type(e) == core.EnforceNotMet
|
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import unicode_literals
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
class TestNanInf(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._python_interp = sys.executable
|
||||
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
|
||||
self._python_interp += " -m coverage run --branch -p"
|
||||
self._python_interp += " check_nan_inf_base.py"
|
||||
|
||||
self.env = os.environ.copy()
|
||||
|
||||
def test_nan_inf(self):
|
||||
cmd = self._python_interp
|
||||
|
||||
proc = subprocess.Popen(
|
||||
cmd.split(" "),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=self.env)
|
||||
|
||||
out, err = proc.communicate()
|
||||
returncode = proc.returncode
|
||||
|
||||
print(out)
|
||||
print(err)
|
||||
|
||||
assert returncode == 0
|
||||
# in python3, type(out+err) is 'bytes', need use encode
|
||||
assert (out + err).find('find nan or inf'.encode()) != -1
|
||||
|
||||
|
||||
class TestNanInfEnv(TestNanInf):
|
||||
def setUp(self):
|
||||
super(TestNanInfEnv, self).setUp()
|
||||
# windows python have some bug with env, so need use str to pass ci
|
||||
# otherwise, "TypeError: environment can only contain strings"
|
||||
self.env[str("PADDLE_INF_NAN_SKIP_OP")] = str("mul")
|
||||
self.env[str("PADDLE_INF_NAN_SKIP_ROLE")] = str("loss")
|
||||
self.env[str("PADDLE_INF_NAN_SKIP_VAR")] = str(
|
||||
"elementwise_add:fc_0.tmp_1")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue