Enable seq_pool op to accept len 0 input (#17284)

* Enable seq_pool op to accept len 0 input

test=develop

* Update sequence_pool's api

test=develop

* Add more unittest cases for seq_pool op

test=develop

* Remove legacy comments

test=develop

* Don't use template in op maker

test=develop
Yibing Liu 6 years ago committed by GitHub
parent 90ebce9ead
commit 33d1e56506

@@ -87,7 +87,7 @@ paddle.fluid.layers.chunk_eval (ArgSpec(args=['input', 'label', 'chunk_scheme',
paddle.fluid.layers.sequence_conv (ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)), ('document', '3d8e8f3e0e1cf520156be37605e83ccd'))
paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '8ca6121acd6d23cd8806a93f493c2e17'))
paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '37042620f9bd3a2da6e5d3138b2f724b'))
paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,)), ('document', 'a194fb80614023f543df3949fbd0d0b8'))
paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test', 'pad_value'], varargs=None, keywords=None, defaults=(False, 0.0)), ('document', 'e90a93251c52dc4e6fb34fb3991b3f82'))
paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '19ef6f9cdd27feac8a1ae060f19c10b4'))
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', 'cee673c79e3ff4582656a24e04f841e5'))
paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', 'bbd84e855e660cd1084bb71a2fd0cdaa'))
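The only user-visible API change is the new trailing pad_value argument (default 0.0) on paddle.fluid.layers.sequence_pool. A minimal usage sketch of the new signature, reusing the layer from the docstring example further down:

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32', lod_level=1)
# Empty (length-0) sequences in x now pool to pad_value instead of being rejected.
avg_x = fluid.layers.sequence_pool(input=x, pool_type='average', pad_value=0.0)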

@@ -36,8 +36,8 @@ template <typename T, bool is_test>
class MaxSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index) {
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output, framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
auto idx_dims = index->dims();
@@ -56,6 +56,13 @@ class MaxSeqPoolFunctor {
int64_t num_seq = out_dims[0];
int64_t dim = output->numel() / num_seq;
for (int64_t i = 0; i < num_seq; ++i) {
if (starts[i] == starts[i + 1]) {
for (int64_t k = 0; k < dim; ++k) {
out_data[i * dim + k] = pad_value;
max_index[i * dim + k] = -1;
}
continue;
}
for (int64_t k = 0; k < dim; ++k) {
out_data[i * dim + k] = in_data[starts[i] * dim + k];
max_index[i * dim + k] = starts[i];
@@ -77,8 +84,8 @@ template <typename T>
class MaxSeqPoolFunctor<T, true> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index) {
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output, framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1);
@@ -94,6 +101,12 @@ class MaxSeqPoolFunctor<T, true> {
int64_t num_seq = out_dims[0];
int64_t dim = output->numel() / num_seq;
for (int64_t i = 0; i < num_seq; ++i) {
if (starts[i] == starts[i + 1]) {
for (int64_t k = 0; k < dim; ++k) {
out_data[i * dim + k] = pad_value;
}
continue;
}
std::memcpy(&out_data[i * dim], &in_data[starts[i] * dim],
dim * sizeof(T));
for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) {
@@ -134,6 +147,7 @@ class MaxSeqPoolGradFunctor {
for (int64_t i = 0; i < num_seq; ++i) {
for (int64_t j = 0; j < dim; ++j) {
int step_id = max_index[i * dim + j];
if (step_id == -1) continue;
ig_data[step_id * dim + j] = og_data[i * dim + j];
}
}
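To make the new control flow above concrete, here is a minimal NumPy sketch (an illustration with made-up names, not the Paddle implementation): offsets plays the role of the LoD offsets, an empty sequence (offsets[i] == offsets[i + 1]) is filled with pad_value and its max index is set to -1, which the gradient functor then skips.

import numpy as np

def max_seq_pool(x, offsets, pad_value=0.0):
    # x: [total_time_steps, dim]; offsets: LoD offsets, len(offsets) == num_seq + 1
    num_seq, dim = len(offsets) - 1, x.shape[1]
    out = np.empty((num_seq, dim), dtype=x.dtype)
    index = np.empty((num_seq, dim), dtype=np.int64)
    for i in range(num_seq):
        start, end = offsets[i], offsets[i + 1]
        if start == end:                 # empty sequence -> pad, no max index
            out[i, :] = pad_value
            index[i, :] = -1
            continue
        seg = x[start:end]
        out[i, :] = seg.max(axis=0)
        index[i, :] = start + seg.argmax(axis=0)
    return out, index

x = np.array([[1.], [3.], [2.], [4.], [6.], [5.], [1.]], dtype=np.float32)
out, idx = max_seq_pool(x, offsets=[0, 2, 5, 7, 7])  # last sequence is empty
# out -> [[3.], [6.], [5.], [0.]]; idx -> [[1], [4], [5], [-1]]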
@@ -144,7 +158,7 @@ template <typename T>
class LastSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input,
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output) {
// Create pointers to input and output data
auto* in_data = input.data<T>();
@@ -157,10 +171,16 @@ class LastSeqPoolFunctor {
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
// Point to the begin of next sequence
in_data += seq_len * item_size;
// Copy the last item of sequence to output
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
if (seq_len == 0) {
for (int j = 0; j < item_size; ++j) {
out_data[j] = pad_value;
}
} else {
// Point to the begin of next sequence
in_data += seq_len * item_size;
// Copy the last item of sequence to output
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
}
out_data += item_size;
}
}
@@ -170,7 +190,7 @@ template <typename T>
class FirstSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input,
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output) {
// Create pointers to input and output data
auto* in_data = input.data<T>();
@@ -183,10 +203,16 @@ class FirstSeqPoolFunctor {
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
// Copy the first item of sequence to output
std::memcpy(out_data, in_data, item_size * sizeof(T));
// Point to the next sequence
in_data += seq_len * item_size;
if (seq_len == 0) {
for (int j = 0; j < item_size; ++j) {
out_data[j] = pad_value;
}
} else {
// Copy the first item of sequence to output
std::memcpy(out_data, in_data, item_size * sizeof(T));
// Point to the next sequence
in_data += seq_len * item_size;
}
out_data += item_size;
}
}
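A matching sketch for the FIRST/LAST paths (again illustrative, not the Paddle code): an empty sequence copies pad_value instead of its first or last time step, and the output still advances by one row per sequence.

import numpy as np

def first_last_pool(x, offsets, pooltype="LAST", pad_value=0.0):
    out = np.empty((len(offsets) - 1, x.shape[1]), dtype=x.dtype)
    for i in range(len(offsets) - 1):
        start, end = offsets[i], offsets[i + 1]
        if start == end:                               # seq_len == 0 -> pad
            out[i, :] = pad_value
        else:
            out[i, :] = x[end - 1] if pooltype == "LAST" else x[start]
    return out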
@@ -207,6 +233,7 @@ class SumSeqPoolGradFunctor {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (h == 0) continue;
int64_t in_offset = lod[i] * in_w;
const T* out_pos = out_g_data + i * out_w;
T* in_pos = in_g_data + in_offset;
@@ -222,27 +249,27 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
public:
/* max pool has index output */
void operator()(const platform::CPUDeviceContext& context,
const std::string pooltype, const framework::LoDTensor& input,
framework::Tensor* output, bool is_test,
framework::Tensor* index = nullptr) {
const std::string pooltype, T pad_value,
const framework::LoDTensor& input, framework::Tensor* output,
bool is_test, framework::Tensor* index = nullptr) {
if (pooltype == "MAX") {
if (is_test) {
math::MaxSeqPoolFunctor<T, true> max_pool;
max_pool(context, input, output, index);
max_pool(context, input, pad_value, output, index);
} else {
math::MaxSeqPoolFunctor<T, false> max_pool;
max_pool(context, input, output, index);
max_pool(context, input, pad_value, output, index);
}
return;
}
if (pooltype == "LAST") {
math::LastSeqPoolFunctor<T> last_pool;
last_pool(context, input, output);
last_pool(context, input, pad_value, output);
return;
}
if (pooltype == "FIRST") {
math::FirstSeqPoolFunctor<T> first_pool;
first_pool(context, input, output);
first_pool(context, input, pad_value, output);
return;
}
@@ -260,7 +287,13 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
.At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr);
if (attr.h == 0) {
for (int j = 0; j < attr.w; ++j) {
dst[j] = pad_value;
}
} else {
seqpool(src, dst, &attr);
}
dst += attr.w;
src += attr.h * attr.w;
}
@@ -268,11 +301,17 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
}
auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
Tensor out_t = output->Slice(i, i + 1);
int64_t w = input.numel() / input.dims()[0];
if (lod[i] == lod[i + 1]) {
for (int j = 0; j < w; ++j) {
out_t.data<T>()[j] = pad_value;
}
continue;
}
Tensor in_t =
input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
Tensor out_t = output->Slice(i, i + 1);
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t w = input.numel() / input.dims()[0];
auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = EigenVector<T>::Flatten(out_t);
if (pooltype == "AVERAGE") {
@@ -316,6 +355,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
if (lod[i] == lod[i + 1]) continue;
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1]));
auto out_g_t = out_grad.Slice(i, i + 1);

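The SUM/AVERAGE/SQRT paths (both the JIT seqpool branch and the Eigen branch above) follow the same pattern: an empty sequence writes pad_value in the forward pass and is skipped in the gradient functors. A compact NumPy sketch of the forward behavior, with illustrative names:

import numpy as np

def seq_pool(x, offsets, pooltype="AVERAGE", pad_value=0.0):
    out = np.empty((len(offsets) - 1, x.shape[1]), dtype=x.dtype)
    for i in range(len(offsets) - 1):
        start, end = offsets[i], offsets[i + 1]
        if start == end:                  # empty sequence -> pad, no gradient
            out[i, :] = pad_value
            continue
        s = x[start:end].sum(axis=0)
        if pooltype == "AVERAGE":
            s = s / (end - start)
        elif pooltype == "SQRT":
            s = s / np.sqrt(end - start)
        out[i, :] = s
    return out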
@@ -24,96 +24,122 @@ namespace math {
template <typename T>
struct MaxPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
HOSTDEVICE void operator()(const T* input, const T pad_value,
const size_t start, const size_t end,
const size_t item_dim, T* output, int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T max_val = static_cast<T>(-FLT_MAX);
int max_index = -1;
for (int i = start; i < end; ++i) {
if (max_val < input[item_dim * i + tid]) {
max_val = input[item_dim * i + tid];
max_index = i;
if (start == end) {
output[tid] = pad_value;
index[tid] = -1;
} else {
for (int i = start; i < end; ++i) {
if (max_val < input[item_dim * i + tid]) {
max_val = input[item_dim * i + tid];
max_index = i;
}
}
output[tid] = max_val;
index[tid] = max_index;
}
output[tid] = max_val;
index[tid] = max_index;
}
}
};
template <typename T>
struct AvgPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
HOSTDEVICE void operator()(const T* input, const T pad_value,
const size_t start, const size_t end,
const size_t item_dim, T* output, int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
if (start == end) {
output[tid] = pad_value;
} else {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
}
// end, start is lod, so end - start != 0
output[tid] = val / static_cast<T>(end - start);
}
// end, start is lod, so end - start != 0
output[tid] = val / static_cast<T>(end - start);
}
}
};
template <typename T>
struct SumPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
HOSTDEVICE void operator()(const T* input, const T pad_value,
const size_t start, const size_t end,
const size_t item_dim, T* output, int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
if (start == end) {
output[tid] = pad_value;
} else {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
}
output[tid] = val;
}
output[tid] = val;
}
}
};
template <typename T>
struct SqrtPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
HOSTDEVICE void operator()(const T* input, const T pad_value,
const size_t start, const size_t end,
const size_t item_dim, T* output, int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
if (start == end) {
output[tid] = pad_value;
} else {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
}
// end, start is lod, so end - start != 0
output[tid] = val / sqrt(end - start);
}
// end, start is lod, so end - start != 0
output[tid] = val / sqrt(end - start);
}
}
};
template <typename T>
struct LastPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
HOSTDEVICE void operator()(const T* input, const T pad_value,
const size_t start, const size_t end,
const size_t item_dim, T* output, int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
output[tid] = input[item_dim * (end - 1) + tid];
if (start == end) {
output[tid] = pad_value;
} else {
output[tid] = input[item_dim * (end - 1) + tid];
}
}
}
};
template <typename T>
struct FirstPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
HOSTDEVICE void operator()(const T* input, const T pad_value,
const size_t start, const size_t end,
const size_t item_dim, T* output, int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
output[tid] = input[item_dim * start + tid];
if (start == end) {
output[tid] = pad_value;
} else {
output[tid] = input[item_dim * start + tid];
}
}
}
};
template <typename T, typename Range_OP>
__global__ void sequence_pool_kernel(Range_OP op, const T* input,
const size_t* lod, const size_t lod_size,
const T pad_value, const size_t* lod,
const size_t lod_size,
const size_t item_dim, T* output,
int* index) {
int bid = blockIdx.x;
@@ -124,16 +150,17 @@ __global__ void sequence_pool_kernel(Range_OP op, const T* input,
if (index != nullptr) {
index_offset = &index[bid * item_dim];
}
op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
op(input, pad_value, start, end, item_dim, &output[bid * item_dim],
index_offset);
}
template <typename T>
class SequencePoolFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::string pooltype, const framework::LoDTensor& input,
framework::Tensor* output, bool is_test,
framework::Tensor* index = nullptr) {
const std::string pooltype, T pad_value,
const framework::LoDTensor& input, framework::Tensor* output,
bool is_test, framework::Tensor* index = nullptr) {
auto& lod = input.lod()[0];
const size_t item_dim = output->numel() / output->dims()[0];
dim3 threads(1024, 1);
@@ -141,37 +168,37 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
if (pooltype == "MAX") {
sequence_pool_kernel<
T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
MaxPoolFunctor<T>(), input.data<T>(),
MaxPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), index->data<int>());
} else if (pooltype == "AVERAGE") {
sequence_pool_kernel<
T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
AvgPoolFunctor<T>(), input.data<T>(),
AvgPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SUM") {
sequence_pool_kernel<
T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SumPoolFunctor<T>(), input.data<T>(),
SumPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SQRT") {
sequence_pool_kernel<
T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SqrtPoolFunctor<T>(), input.data<T>(),
SqrtPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "LAST") {
sequence_pool_kernel<
T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
LastPoolFunctor<T>(), input.data<T>(),
LastPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "FIRST") {
sequence_pool_kernel<
T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
FirstPoolFunctor<T>(), input.data<T>(),
FirstPoolFunctor<T>(), input.data<T>(), pad_value,
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else {

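On the GPU side, sequence_pool_kernel keeps its launch layout: one thread block per sequence (blockIdx.x) with threads striding over item_dim, and every functor now also receives pad_value and takes the start == end branch for empty sequences. A rough CPU-side emulation of that work split, using SUM as the example (names are illustrative, not Paddle API):

import numpy as np

def emulated_sum_pool_kernel(x, lod, pad_value):
    num_seq, item_dim = len(lod) - 1, x.shape[1]
    out = np.empty((num_seq, item_dim), dtype=x.dtype)
    for bid in range(num_seq):              # blockIdx.x: one block per sequence
        start, end = lod[bid], lod[bid + 1]
        for tid in range(item_dim):         # threadIdx.x strides over item_dim
            if start == end:                # empty sequence -> pad_value
                out[bid, tid] = pad_value
            else:
                out[bid, tid] = x[start:end, tid].sum()
    return out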
@@ -27,8 +27,9 @@ class SequencePoolFunctor {
public:
/* max pool has index output */
void operator()(const DeviceContext& context, const std::string pooltype,
const framework::LoDTensor& input, framework::Tensor* output,
bool is_test = false, framework::Tensor* index = nullptr);
T pad_value, const framework::LoDTensor& input,
framework::Tensor* output, bool is_test = false,
framework::Tensor* index = nullptr);
};
template <typename DeviceContext, typename T>

@@ -57,6 +57,9 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.")
.SetDefault("AVERAGE")
.InEnum({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"});
AddAttr<float>("pad_value",
"(float, default 0.0) The value to pad for empty sequence.")
.SetDefault(0.0);
AddComment(R"DOC(
Sequence Pool Operator.
@@ -69,6 +72,8 @@ It supports six pooling types:
5. FIRST: Out[i] = first instance in i-th sequence X[i]
6. MAX: $$Out[i] = max(X_i)$$
and for an empty sequence, Out[i] = attr(pad_value).
The following example explains how this works:
For a mini-batch of 3 variable-length sentences,
containing 2, 3, and 2 time-steps:

@@ -32,6 +32,7 @@ class SequencePoolKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<Tensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype");
T pad_value = static_cast<T>(context.Attr<float>("pad_value"));
auto dims = in->dims();
auto lod = in->lod();
@@ -58,8 +59,8 @@ class SequencePoolKernel : public framework::OpKernel<T> {
index->mutable_data<int>(context.GetPlace());
}
math::SequencePoolFunctor<DeviceContext, T> pool;
pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
is_test, index);
pool(context.template device_context<DeviceContext>(), pooltype, pad_value,
*in, out, is_test, index);
}
};

@@ -2346,7 +2346,7 @@ def conv3d(input,
return helper.append_activation(pre_act)
def sequence_pool(input, pool_type, is_test=False):
def sequence_pool(input, pool_type, is_test=False, pad_value=0.0):
"""
This function adds the operator for sequence pooling.
It pools features of all time-steps of each instance, and is applied
@@ -2361,29 +2361,32 @@ def sequence_pool(input, pool_type, is_test=False):
.. code-block:: text
x is a 1-level LoDTensor:
x.lod = [[2, 3, 2]]
x is a 1-level LoDTensor and **pad_value** = 0.0:
x.lod = [[2, 3, 2, 0]]
x.data = [1, 3, 2, 4, 6, 5, 1]
x.dims = [7, 1]
then output is a Tensor:
out.dim = [3, 1]
out.dim = [4, 1]
with condition len(x.lod[-1]) == out.dims[0]
for different pool_type:
average: out.data = [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
sum : out.data = [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
average: out.data = [2, 4, 3, 0.0], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
sum : out.data = [4, 12, 6, 0.0], where 4=1+3, 12=2+4+6, 6=5+1
sqrt : out.data = [2.82, 6.93, 4.24, 0.0], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
last : out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
first : out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
max : out.data = [3, 6, 5, 0.0], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
last : out.data = [3, 6, 1, 0.0], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
first : out.data = [1, 2, 5, 0.0], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
and all the 0.0 values above equal **pad_value**.
Args:
input(variable): The input variable which is a LoDTensor.
input (variable): The input variable which is a LoDTensor.
pool_type (string): The pooling type of sequence_pool.
It supports average, sum, sqrt, max, last and first.
is_test(bool, Default False): Used distinguish training from scoring mode.
is_test (bool): Used to distinguish training from scoring mode. Default False.
pad_value (float): The value used to pad the pooling result for an empty input sequence.
Returns:
The sequence pooling variable which is a Tensor.
@@ -2392,6 +2395,8 @@ def sequence_pool(input, pool_type, is_test=False):
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
avg_x = fluid.layers.sequence_pool(input=x, pool_type='average')
@@ -2413,8 +2418,11 @@ def sequence_pool(input, pool_type, is_test=False):
inputs={"X": input},
outputs={"Out": pool_out,
"MaxIndex": max_index},
attrs={"pooltype": pool_type.upper(),
"is_test": is_test})
attrs={
"pooltype": pool_type.upper(),
"is_test": is_test,
"pad_value": pad_value
})
# when pool_type is max, variable max_index is initialized,
# so we stop the gradient explicitly here

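Putting the pieces together, here is a hedged end-to-end sketch of the new behavior from Python. It assumes the 1.x fluid API shown in this diff plus standard fluid plumbing (CPUPlace, Executor, create_lod_tensor) that is not part of this commit; a batch containing a length-0 sequence now pools to pad_value instead of producing undefined output.

import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[1], dtype='float32', lod_level=1)
out = fluid.layers.sequence_pool(input=x, pool_type='sum', pad_value=-1.0)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Four sequences with lengths [2, 3, 2, 0]; the last one is empty.
data = np.array([[1], [3], [2], [4], [6], [5], [1]], dtype='float32')
tensor = fluid.create_lod_tensor(data, [[2, 3, 2, 0]], place)
res, = exe.run(feed={'x': tensor}, fetch_list=[out])
# Expected: [[4.], [12.], [6.], [-1.]] -- the empty sequence gets pad_value.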
File diff suppressed because it is too large.