Add GPU Kernels of Segment Ops, support sum, max, min, mean

Zhong Hui committed
parent c0caf0e45f
commit 4a9d21de49


@@ -0,0 +1,28 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/segment_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
segment_pool,
ops::SegmentPoolKernel<paddle::platform::CUDADeviceContext, float>,
ops::SegmentPoolKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
segment_pool_grad,
ops::SegmentPoolGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SegmentPoolGradKernel<paddle::platform::CUDADeviceContext, double>);
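The file above only registers the float and double CUDA kernels; the element-wise reduction itself is done by the SegmentPoolFunctor. As a rough illustration of what a segment reduction does on the GPU, a minimal, hypothetical segment-sum kernel is sketched below (the names and layout are illustrative only, not the functor shipped in this commit; it assumes sorted, non-negative segment ids and a zero-initialized output):

#include <cstdint>
#include "paddle/fluid/platform/cuda_primitives.h"

// Hypothetical sketch: one thread per input element. Rows that share a
// segment id are accumulated into the same output row via CudaAtomicAdd.
// Assumes segment_ids are sorted and non-negative and that the output
// buffer was zero-initialized beforehand.
template <typename T, typename IndexT>
__global__ void SegmentSumSketch(const T *input, const IndexT *segment_ids,
                                 T *output, int64_t num_rows,
                                 int64_t inner_dim) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (; idx < num_rows * inner_dim; idx += stride) {
    int64_t row = idx / inner_dim;
    int64_t col = idx % inner_dim;
    paddle::platform::CudaAtomicAdd(
        &output[segment_ids[row] * inner_dim + col], input[idx]);
  }
}

MEAN pooling can be built on the same pattern by also accumulating a per-segment row count and dividing afterwards, which is what the SummedIds tensor in the launch helper below is used for.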

@@ -63,6 +63,46 @@ void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, output, static_cast<T>(0));
}
#ifdef PADDLE_WITH_CUDA
if (!cpu_place) {
Tensor length;
length.mutable_data<IndexT>(framework::make_ddim({1}),
platform::CPUPlace());
IndexT* length_data = length.data<IndexT>();
const IndexT* segment_ids = segment->data<IndexT>();
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
cudaMemcpyDeviceToHost));
IndexT length_host = length_data[0];
length_host++;
PADDLE_ENFORCE_GT(
length_host, 0,
platform::errors::InvalidArgument(
"Segment ids must be >= 0, but got last id %d", length_data[0]));
auto dims = input->dims();
dims[0] = static_cast<int64_t>(length_host);
    output->Resize(dims);
output->mutable_data<T>(context.GetPlace());
T init_value = 0;
if (pooltype == "MAX") {
init_value = static_cast<T>(-FLT_MAX);
} else if (pooltype == "MIN") {
init_value = static_cast<T>(FLT_MAX);
}
math::SetConstant<DeviceContext, T> setconst;
auto& dev_ctx = context.template device_context<DeviceContext>();
setconst(dev_ctx, output, static_cast<T>(init_value));
    // The GPU kernel of mean pool records the counts of segment_ids in SummedIds.
if (pooltype == "MEAN") {
summed_ids = context.Output<Tensor>("SummedIds");
summed_ids->Resize({dims[0], 1});
summed_ids->mutable_data<T>(context.GetPlace());
setconst(dev_ctx, summed_ids, static_cast<T>(1e-12));
}
}
#endif
SegmentPoolFunctor<DeviceContext, T, IndexT> pool;
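Because the segment ids are expected to be non-negative and ordered so that the last id is the largest, the launch helper only needs to copy that last id back to the host to know how many output rows to allocate. A small, hypothetical walk-through of the sizing and initialization above (the values are made up):

// Hypothetical example of the logic above:
//   segment_ids = [0, 0, 1, 3, 3]           // sorted, non-negative
//   last id = 3  ->  length_host = 3 + 1 = 4 output rows
//   input dims = [5, 16]  ->  output dims = [4, 16]
// "MAX" pre-fills the output with -FLT_MAX and "MIN" with FLT_MAX, so a
// segment that receives no rows (id 2 here) still holds a defined value.
// "MEAN" additionally allocates SummedIds ([4, 1]) pre-filled with 1e-12,
// presumably to keep the later division well defined for empty segments.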

@@ -128,5 +128,112 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
}
#endif
// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily uint64_t.
USE_CUDA_ATOMIC(Max, unsigned long long int); // NOLINT
CUDA_ATOMIC_WRAPPER(Max, int64_t) {
  // Here we check that long long int has the same size as int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
return CudaAtomicMax(
reinterpret_cast<unsigned long long int *>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
}
// Use an atomicCAS loop on the int bit pattern, since CUDA has no native
// atomicMax for float.
CUDA_ATOMIC_WRAPPER(Max, float) {
  if (*address >= val) {
    return *address;
  }
  int *const address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;
  do {
    assumed = old;
    if (__int_as_float(assumed) >= val) {
      break;
    }
    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);
  return __int_as_float(old);
}
CUDA_ATOMIC_WRAPPER(Max, double) {
  if (*address >= val) {
    return *address;
  }
  unsigned long long int *const address_as_ull =
      reinterpret_cast<unsigned long long int *>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    if (__longlong_as_double(assumed) >= val) {
      break;
    }
    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);
  return __longlong_as_double(old);
}
// For atomicMin
USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC(Min, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily uint64_t.
USE_CUDA_ATOMIC(Min, unsigned long long int); // NOLINT
CUDA_ATOMIC_WRAPPER(Min, int64_t) {
  // Here we check that long long int has the same size as int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
return CudaAtomicMin(
reinterpret_cast<unsigned long long int *>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
}
CUDA_ATOMIC_WRAPPER(Min, float) {
  if (*address <= val) {
    return *address;
  }
  int *const address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;
  do {
    assumed = old;
    if (__int_as_float(assumed) <= val) {
      break;
    }
    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);
  return __int_as_float(old);
}
CUDA_ATOMIC_WRAPPER(Min, double) {
  if (*address <= val) {
    return *address;
  }
  unsigned long long int *const address_as_ull =
      reinterpret_cast<unsigned long long int *>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    if (__longlong_as_double(assumed) <= val) {
      break;
    }
    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);
  return __longlong_as_double(old);
}
} // namespace platform
} // namespace paddle
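Since CUDA has no native atomicMax/atomicMin for float or double, the wrappers above emulate them with an atomicCAS loop on the value's bit pattern. A minimal, hypothetical smoke test (not part of this commit, and assuming cuda_primitives.h can be included on its own) that races 1024 threads through CudaAtomicMax and checks the winner:

#include <cstdio>
#include "paddle/fluid/platform/cuda_primitives.h"

// Hypothetical smoke test: every thread tries to fold its own index into a
// single float slot through the CAS-based CudaAtomicMax defined above.
__global__ void MaxRaceSketch(float *slot) {
  paddle::platform::CudaAtomicMax(slot, static_cast<float>(threadIdx.x));
}

int main() {
  float *slot = nullptr;
  cudaMalloc(&slot, sizeof(float));
  float init = -1.0f;  // any value smaller than all inputs
  cudaMemcpy(slot, &init, sizeof(float), cudaMemcpyHostToDevice);
  MaxRaceSketch<<<1, 1024>>>(slot);
  float result = 0.0f;
  cudaMemcpy(&result, slot, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max = %f\n", result);  // expected: 1023.0
  cudaFree(slot);
  return 0;
}

The same pattern is what lets the MAX/MIN segment kernels update a shared output row from many threads without locking, with the output pre-filled to -FLT_MAX or FLT_MAX as done in the launch helper.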
