|
|
|
|
@ -32,6 +32,17 @@ __global__ void GpuUpdateLossScaling(
|
|
|
|
|
updated_loss_scaling_data, good_out_data, bad_out_data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Conditionally fills data[0, num) with `value`.
//
// The fill is performed only when the flag that `has_inf` points to is true.
// The flag is dereferenced inside the kernel, so it must be readable from
// device code. Each thread starts at its flat 1-D global index and advances by
// blockDim.x * gridDim.x, so any 1-D launch configuration covers the range.
//
//   data    - destination buffer; must hold at least `num` elements.
//   num     - number of elements to write.
//   value   - value stored into every element.
//   has_inf - flag selecting whether the fill happens at all.
template <typename T>
__global__ void FillIf(T* data, const int64_t num, const T value,
                       const bool* has_inf) {
  if (*has_inf) {
    // Index and stride are 64-bit: `num` is int64_t, and a 32-bit counter
    // would overflow once num exceeds INT_MAX.
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    for (; idx < num; idx += stride) {
      data[idx] = value;
    }
  }
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class UpdateLossScalingFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
public:
|
|
|
|
|
@ -50,26 +61,20 @@ class UpdateLossScalingFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// CUDA specialization that zeroes the tensors in `outs` when the flag behind
// `found_inf_data` is true.
//
// The flag is evaluated inside the FillIf kernel, so this functor issues no
// device-to-host copy and no explicit wait; it only enqueues one kernel per
// tensor on dev_ctx.stream().
template <typename T>
class LazyZeros<platform::CUDADeviceContext, T> {
 public:
  // dev_ctx        - provides the place used for mutable_data and the stream
  //                  the kernels are launched on.
  // found_inf_data - flag dereferenced by FillIf in device code.
  // xs             - only its size is used, as the number of tensors to visit.
  // outs           - tensors to be zero-filled.
  //                  NOTE(review): indexed with xs' range, so this assumes
  //                  outs.size() >= xs.size() - confirm against the caller.
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const bool* found_inf_data,
                  const std::vector<const framework::Tensor*>& xs,
                  const std::vector<framework::Tensor*>& outs) const {
    constexpr int kBlock = 1024;
    for (size_t i = 0; i < xs.size(); ++i) {
      auto* out = outs[i];
      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
      const int64_t num = out->numel();
      // A grid dimension of 0 is an invalid launch configuration, so an empty
      // tensor is skipped after its buffer has been obtained.
      if (num <= 0) {
        continue;
      }
      // Ceil-div in 64-bit, then narrow explicitly to the launch type.
      const int grid = static_cast<int>((num + kBlock - 1) / kBlock);
      FillIf<<<grid, kBlock, 0, dev_ctx.stream()>>>(
          out_data, num, static_cast<T>(0), found_inf_data);
    }
  }
};
|
|
|
|
|
|