@ -24,12 +24,16 @@ namespace math {
// Max pooling functor: for each column it handles, writes the largest value
// found in rows start..end-1 of `input` to `output`, and a row number to
// `index`.
//
// NOTE(review): this text reads like a unified diff whose +/- markers were
// stripped. The operator() header appears twice (first without, then with
// `pad_value`) and the "@ ... @" line further down is a hunk header that
// presumably replaces lines not shown here (the braces of the inner loop are
// never closed in the visible text). The body uses `pad_value`, so it matches
// the second header.
template <typename T>
struct MaxPoolFunctor {
// First header: no pad_value parameter.
HOSTDEVICE void operator()(const T* input, const size_t start ,
const size_t end, const size_t item_dim, T* output ,
int* index) {
// Second header: adds pad_value right after input.
HOSTDEVICE void operator()(const T* input, const T pad_value ,
const size_t start, const size_t end ,
const size_t item_dim, T* output, int* index) {
// Columns are split across the threads of a block: a thread begins at column
// threadIdx.x and steps by blockDim.x until item_dim is reached.
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
// Running maximum starts at -FLT_MAX converted to T; -1 means "no row yet".
T max_val = static_cast<T>(-FLT_MAX);
int max_index = -1;
if (start == end) {
// No rows in the range: emit the pad value and -1 as the index.
output[tid] = pad_value;
index[tid] = -1;
} else {
// Scan the rows; the element of row i in this column is at
// input[item_dim * i + tid].
for (int i = start; i < end; ++i) {
if (max_val < input[item_dim * i + tid]) {
max_val = input[item_dim * i + tid];
// Hunk header: the statements between here and the next line are not shown.
@ -40,14 +44,18 @@ struct MaxPoolFunctor {
// Stores max_index for this column; where it is updated is not visible here.
index[tid] = max_index;
}
}
}
};
// Average pooling functor: for each column it handles, sums the values of
// rows start..end-1 of `input` and writes the sum divided by (end - start)
// to `output`.
//
// NOTE(review): this text reads like a unified diff whose +/- markers were
// stripped. The operator() header appears twice (first without, then with
// `pad_value`) and the "@ ... @" line further down is a hunk header that
// presumably replaces lines not shown here. The body uses `pad_value`, so it
// matches the second header.
template <typename T>
struct AvgPoolFunctor {
// First header: no pad_value parameter.
HOSTDEVICE void operator()(const T* input, const size_t start ,
const size_t end, const size_t item_dim, T* output ,
int* index) {
// Second header: adds pad_value right after input.
HOSTDEVICE void operator()(const T* input, const T pad_value ,
const size_t start, const size_t end ,
const size_t item_dim, T* output, int* index) {
// A thread begins at column threadIdx.x and steps by blockDim.x until
// item_dim is reached.
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
if (start == end) {
// No rows in the range: emit the pad value (this also keeps the division
// below from seeing a zero row count).
output[tid] = pad_value;
} else {
// Accumulate this column over the rows; row i is at input[item_dim * i + tid].
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
// Hunk header: the lines between here and the next statement are not shown.
@ -56,14 +64,18 @@ struct AvgPoolFunctor {
// Mean = accumulated value / number of rows, with the count converted to T.
output[tid] = val / static_cast<T>(end - start);
}
}
}
};
// Sum pooling functor: for each column it handles, adds up the values of
// rows start..end-1 of `input` and writes the total to `output`.
//
// NOTE(review): this text reads like a unified diff whose +/- markers were
// stripped. The operator() header appears twice (first without, then with
// `pad_value`) and the "@ ... @" line further down is a hunk header that
// presumably replaces lines not shown here. The body uses `pad_value`, so it
// matches the second header.
template <typename T>
struct SumPoolFunctor {
// First header: no pad_value parameter.
HOSTDEVICE void operator()(const T* input, const size_t start ,
const size_t end, const size_t item_dim, T* output ,
int* index) {
// Second header: adds pad_value right after input.
HOSTDEVICE void operator()(const T* input, const T pad_value ,
const size_t start, const size_t end ,
const size_t item_dim, T* output, int* index) {
// A thread begins at column threadIdx.x and steps by blockDim.x until
// item_dim is reached.
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
if (start == end) {
// No rows in the range: emit the pad value rather than a sum.
output[tid] = pad_value;
} else {
// Accumulate this column over the rows; row i is at input[item_dim * i + tid].
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
// Hunk header: the lines between here and the next statement are not shown.
@ -71,14 +83,18 @@ struct SumPoolFunctor {
// The accumulated total is the result for this column.
output[tid] = val;
}
}
}
};
// Sqrt pooling functor: for each column it handles, adds up the values of
// rows start..end-1 of `input` and writes the total divided by
// sqrt(end - start) to `output`.
//
// NOTE(review): this text reads like a unified diff whose +/- markers were
// stripped. The operator() header appears twice (first without, then with
// `pad_value`) and the "@ ... @" line further down is a hunk header that
// presumably replaces lines not shown here. The body uses `pad_value`, so it
// matches the second header.
template <typename T>
struct SqrtPoolFunctor {
// First header: no pad_value parameter.
HOSTDEVICE void operator()(const T* input, const size_t start ,
const size_t end, const size_t item_dim, T* output ,
int* index) {
// Second header: adds pad_value right after input.
HOSTDEVICE void operator()(const T* input, const T pad_value ,
const size_t start, const size_t end ,
const size_t item_dim, T* output, int* index) {
// A thread begins at column threadIdx.x and steps by blockDim.x until
// item_dim is reached.
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
if (start == end) {
// No rows in the range: emit the pad value (this also keeps the division
// below from seeing sqrt(0)).
output[tid] = pad_value;
} else {
// Accumulate this column over the rows; row i is at input[item_dim * i + tid].
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
// Hunk header: the lines between here and the next statement are not shown.
@ -87,33 +103,43 @@ struct SqrtPoolFunctor {
// Scale the total by the square root of the row count; the argument passed
// to sqrt is the size_t difference end - start.
output[tid] = val / sqrt(end - start);
}
}
}
};
// Pools a range of rows by copying its last row.
//
// For every column handled by the calling thread (threadIdx.x,
// threadIdx.x + blockDim.x, ... while below item_dim) the value
// input[item_dim * (end - 1) + column] is written to output[column].
// When start == end there is no last row, so pad_value is written instead.
//
//   input     - base pointer; row i occupies input[item_dim * i ...]
//   pad_value - value stored when the row range is empty
//   start     - first row of the range; only compared against end here
//   end       - one past the last row of the range
//   item_dim  - number of values per row
//   output    - destination for item_dim values
//   index     - not used by this functor
template <typename T>
struct LastPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    // The column counter is size_t so that the comparison with item_dim is not
    // a signed/unsigned mix and cannot overflow for large item_dim.
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        // Empty range: end - 1 would not name a valid row.
        output[tid] = pad_value;
      } else {
        output[tid] = input[item_dim * (end - 1) + tid];
      }
    }
  }
};
// Pools a range of rows by copying its first row.
//
// For every column handled by the calling thread (threadIdx.x,
// threadIdx.x + blockDim.x, ... while below item_dim) the value
// input[item_dim * start + column] is written to output[column].
// When start == end there is no first row, so pad_value is written instead.
//
//   input     - base pointer; row i occupies input[item_dim * i ...]
//   pad_value - value stored when the row range is empty
//   start     - first row of the range
//   end       - only compared against start here, to detect an empty range
//   item_dim  - number of values per row
//   output    - destination for item_dim values
//   index     - not used by this functor
template <typename T>
struct FirstPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    // The column counter is size_t so that the comparison with item_dim is not
    // a signed/unsigned mix and cannot overflow for large item_dim.
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        // Empty range: row `start` does not belong to it.
        output[tid] = pad_value;
      } else {
        output[tid] = input[item_dim * start + tid];
      }
    }
  }
};
// Kernel that applies a pooling functor `op` once per block: the block number
// (blockIdx.x) selects the output slice starting at output[bid * item_dim],
// and, when `index` is not null, the matching slice of `index`.
//
// NOTE(review): this text reads like a unified diff whose +/- markers were
// stripped. The `lod`/`lod_size` parameter lines and the call to `op` each
// appear twice (without and with `pad_value`), and the "@ ... @" line is a
// hunk header that presumably replaces lines not shown here; `start`, `end`
// and `index_offset` are used below but their definitions are not visible.
template <typename T, typename Range_OP>
__global__ void sequence_pool_kernel(Range_OP op, const T* input,
// First variant of the parameter list: no pad_value.
const size_t* lod, const size_t lod_size,
// Second variant: pad_value inserted before lod.
const T pad_value, const size_t* lod,
const size_t lod_size,
const size_t item_dim, T* output,
int* index) {
// One block per output slice.
int bid = blockIdx.x;
// Hunk header: the lines between here and the next statement are not shown.
@ -124,16 +150,17 @@ __global__ void sequence_pool_kernel(Range_OP op, const T* input,
// Only form a pointer into `index` when the caller supplied one.
if (index != nullptr) {
index_offset = &index[bid * item_dim];
}
// First variant of the call: no pad_value.
op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
// Second variant: pad_value passed right after input.
op(input, pad_value, start, end, item_dim, &output[bid * item_dim],
index_offset);
}
template <typename T>
class SequencePoolFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::string pooltype, const framework::LoDTensor& input ,
framework::Tensor* output, bool is_tes t,
framework::Tensor* index = nullptr) {
const std::string pooltype, T pad_value ,
const framework::LoDTensor& input, framework::Tensor* outpu t,
bool is_test, framework::Tensor* index = nullptr) {
auto& lod = input.lod()[0];
const size_t item_dim = output->numel() / output->dims()[0];
dim3 threads(1024, 1);
@ -141,37 +168,37 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
if (pooltype == "MAX") {
sequence_pool_kernel<
T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
MaxPoolFunctor<T>(), input.data<T>(),
MaxPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), index->data<int>());
} else if (pooltype == "AVERAGE") {
sequence_pool_kernel<
T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
AvgPoolFunctor<T>(), input.data<T>(),
AvgPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SUM") {
sequence_pool_kernel<
T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SumPoolFunctor<T>(), input.data<T>(),
SumPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SQRT") {
sequence_pool_kernel<
T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SqrtPoolFunctor<T>(), input.data<T>(),
SqrtPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "LAST") {
sequence_pool_kernel<
T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
LastPoolFunctor<T>(), input.data<T>(),
LastPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "FIRST") {
sequence_pool_kernel<
T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
FirstPoolFunctor<T>(), input.data<T>(),
FirstPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else {