Add the CUDA kernel for beam_search op (#15020)

* Refine the beam_search op and test.

* A basic CUDA implementation of beam_search for small batch_size.

* Implement CUDA kernel for beam_search_op.

* Use multiple CUDA threads in the same block to select the top beam.

* Update the python api of beam_search op.

* Enable extend function in CPU kernel of beam_search op.

* Unify the CUDA codes.
test=develop

* Unify the CPU kernel of beam_search op.

* Ensure the seletced items of beam_search_op's CPU kernel sorted by scores.

* Update the description of beam_search in API.spec.

* Enable the use of CUDA kernel in beam_search op.

* Exclude the beam_search's CUDA unittest when there is no CUDA gpu, and delete some debuging statements.
test=develop

* Follow comments.
test=develop

* Call the CPU kernel for beam_search op when batch_size > 4.
test=develop

* Remove the except of is_empty op in PrepareData.
test=develop
inference-pre-release-gpu
Yiqun Liu 6 years ago committed by GitHub
parent ed1726eaaa
commit 3008fa1261
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -122,7 +122,7 @@ paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None,
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name'], varargs=None, keywords=None, defaults=(0, True, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))

@ -54,13 +54,14 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
if (!platform::is_cpu_place(t.place())) {
LoDTensor tt;
framework::TensorCopy(t, platform::CPUPlace(), &tt);
LoDTensor cpu_tensor;
cpu_tensor.set_lod(t.lod());
framework::TensorCopy(t, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(t.place());
dev_ctx.Wait();
os << tt;
os << cpu_tensor;
return os;
}

@ -66,7 +66,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
endif()
@ -86,7 +86,6 @@ set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
beam_search,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, int64_t>);

@ -14,187 +14,12 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/beam_search.h"
namespace paddle {
namespace operators {
/*
* This is an implementation of beam search.
*
* To explain the details, lets take machine translation task for example, in
* this task, one source sentence is translated to multiple target sentences,
* during this period, one sentence will be translated to multiple translation
* prefixes(target sentence that have not ended), in each time step a prefix
* will have some candidates, input the candidate ids and their corresponding
* scores (probabilities), it will sort and select the top beam_size candidates
* for each source sentence, and store the selected candidates's score and their
* corresponding ids to LoDTensors.
*
* A detailed example:
*
* Input
*
* ids:
* LoD (should have 2 levels)
* first level: [0, 1, 4]
* second level: [0, 1, 2, 3, 4]
*
* tensor's data
* [
* [4, 2, 5]
* [2, 1, 3]
* [3, 5, 2]
* [8, 2, 1]
* ]
*
* scores:
* LoD same as `ids`
* tensor's data
* [
* [0.5, 0.3, 0.2]
* [0.6, 0.3, 0.1]
* [0.9, 0.5, 0.1]
* [0.7, 0.5, 0.1]
* ]
*
* the inputs means that there are 2 source sentences to translate, and the
* first source has 1 prefix, the second source has 2 prefix.
*
* lets assume beam size is 2, and the beam search's output should be
* LoD
* first level:
* [0, 1, 2]
* second level:
* [0, 2, 4]
*
* id tensor's data
* [[
* 4,
* 1,
* 3,
* 8,
* ]]
*
* score tensor's data
* [[
* 0.5,
* 0.3,
* 0.9,
* 0.7
* ]]
*
* TODO all the prune operations should be in the beam search, so it is better
* to split the beam search algorithm into a sequence of smaller operators, and
* the prune operators can be inserted in this sequence.
*/
class BeamSearch {
public:
// TODO(superjom) make type customizable
using id_t = size_t;
using score_t = float;
/*
* Input the arguments that needed by this class.
*/
BeamSearch(const framework::LoDTensor& ids,
const framework::LoDTensor& scores, size_t level, size_t beam_size,
int end_id)
: beam_size_(beam_size),
ids_(&ids),
scores_(&scores),
lod_level_(level),
end_id_(end_id) {}
/*
* The main function of beam search.
*
* @selected_ids: a [None, 1]-shaped tensor with LoD.
* In a machine translation model, it might be the candidate term id sets,
* each set stored as a varience-length sequence.
* The format might be described with a two-level LoD
* - [[0 1]
* - [0 1 2]]
* - [[]
* - [0 1]]
* the first level of LoD tells that there are two source sentences. The
* second level describes the details of the candidate id set's offsets in
* the
* source sentences.
*
* @selected_scores: a LoD tensor with the same shape and LoD with
* selected_ids.
* It stores the corresponding scores of candidate ids in selected_ids.
*
* Return false if all the input tensor is empty, in machine translation task
* that means no candidates is provided, and the task will stop running.
*/
void operator()(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores);
/*
* The basic items help to sort.
*/
struct Item {
Item() {}
Item(size_t offset, size_t id, float score)
: offset(offset), id(id), score(score) {}
// offset in the higher lod level.
size_t offset;
// // prefix id in the lower lod level.
// size_t prefix;
// the candidate id
id_t id;
// the corresponding score
score_t score;
};
protected:
/*
* Prune the source sentences all branchs finished, and it is optional.
* Pruning must one step later than finishing (thus pre_ids is needed here),
* since the end tokens must be writed out.
*/
void PruneEndBeams(const framework::LoDTensor& pre_ids,
std::vector<std::vector<Item>>* items);
/*
* Transform the items into a map whose key is offset, value is the items.
* NOTE low performance.
*/
std::vector<std::vector<Item>> ToMap(
const std::vector<std::vector<Item>>& inputs, size_t element_num);
/*
* For each source, select top beam_size records.
*/
std::vector<std::vector<Item>> SelectTopBeamSizeItems(
const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores);
/*
* Get the items of next source sequence, return false if no remaining items.
*/
bool NextItemSet(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
std::vector<Item>* items);
private:
size_t beam_size_;
const framework::LoDTensor* ids_;
const framework::LoDTensor* scores_;
size_t lod_level_{0};
size_t sent_offset_{0};
int end_id_{0};
};
std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
std::string ItemToString(const BeamSearch::Item& item);
template <typename DeviceContext, typename T>
class BeamSearchOpKernel : public framework::OpKernel<T> {
public:
@ -203,7 +28,7 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
auto* scores = context.Input<framework::LoDTensor>("scores");
auto* pre_ids = context.Input<framework::LoDTensor>("pre_ids");
auto* pre_scores = context.Input<framework::LoDTensor>("pre_scores");
PADDLE_ENFORCE_NOT_NULL(ids);
PADDLE_ENFORCE_NOT_NULL(scores);
PADDLE_ENFORCE_NOT_NULL(pre_ids);
PADDLE_ENFORCE_NOT_NULL(pre_scores);
@ -211,14 +36,20 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
size_t level = context.Attr<int>("level");
size_t beam_size = context.Attr<int>("beam_size");
int end_id = context.Attr<int>("end_id");
BeamSearch alg(*ids, *scores, level, beam_size, end_id);
bool is_accumulated = context.Attr<bool>("is_accumulated");
auto selected_ids = context.Output<framework::LoDTensor>("selected_ids");
auto selected_scores =
context.Output<framework::LoDTensor>("selected_scores");
PADDLE_ENFORCE_NOT_NULL(selected_ids);
PADDLE_ENFORCE_NOT_NULL(selected_scores);
alg(*pre_ids, *pre_scores, selected_ids, selected_scores);
math::BeamSearchFunctor<DeviceContext, T> alg;
alg(context.template device_context<DeviceContext>(), pre_ids, pre_scores,
ids, scores, selected_ids, selected_scores, level, beam_size, end_id,
is_accumulated);
}
};
} // namespace operators
} // namespace paddle

@ -1,92 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include <gtest/gtest.h>
#include <vector>
namespace paddle {
namespace test {
using std::vector;
using framework::LoDTensor;
using framework::LoD;
using operators::BeamSearch;
using paddle::platform::CPUPlace;
using std::cout;
using std::endl;
void CreateInput(LoDTensor* ids, LoDTensor* scores) {
LoD lod;
vector<size_t> level0({0, 2, 4});
vector<size_t> level1({0, 1, 2, 3, 4});
lod.push_back(level0);
lod.push_back(level1);
ids->set_lod(lod);
scores->set_lod(lod);
auto dims = framework::make_ddim(vector<int64_t>({4, 3}));
ids->Resize(dims);
scores->Resize(dims);
CPUPlace place;
auto* ids_data = ids->mutable_data<int64_t>(place);
auto* scores_data = scores->mutable_data<float>(place);
vector<int64_t> _ids({4, 2, 5, 2, 1, 3, 3, 5, 2, 8, 2, 1});
vector<float> _scores(
{0.5f, 0.3f, 0.2f, 0.6f, 0.3f, 0.1f, 0.9f, 0.5f, 0.1f, 0.7f, 0.5f, 0.1f});
for (int i = 0; i < 12; i++) {
ids_data[i] = _ids[i];
scores_data[i] = _scores[i];
}
}
// It seems that beam_search_op has bugs.
TEST(DISABLED_beam_search_op, run) {
CPUPlace place;
LoDTensor ids, scores;
CreateInput(&ids, &scores);
LoDTensor pre_ids;
pre_ids.Resize(framework::make_ddim(vector<int64_t>(4, 1)));
for (int i = 0; i < 4; i++) {
pre_ids.mutable_data<int64_t>(place)[i] = i + 1;
}
LoDTensor pre_scores;
pre_scores.Resize(framework::make_ddim(vector<int64_t>(4, 1)));
for (int i = 0; i < 4; i++) {
pre_scores.mutable_data<float>(place)[i] = 0.1 * (i + 1);
}
BeamSearch beamsearch(ids, scores, (size_t)0, (size_t)2, 0);
LoDTensor sids, sscores;
beamsearch(pre_ids, pre_scores, &sids, &sscores);
LOG(INFO) << "score: " << sscores << endl;
ASSERT_EQ(sids.lod(), sscores.lod());
vector<int> tids({4, 2, 3, 8});
vector<float> tscores({0.5f, 0.6f, 0.9f, 0.7f});
for (int i = 0; i < 4; i++) {
ASSERT_EQ(tids[i], sids.data<int64_t>()[i]);
ASSERT_EQ(tscores[i], sscores.data<float>()[i]);
}
}
} // namespace test
} // namespace paddle

@ -87,8 +87,8 @@ class BprLossGradientOpKernel : public framework::OpKernel<T> {
auto* label = ctx.Input<Tensor>("Label");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
const int step_size = x->dims()[0];
const int num_classes = x->dims()[1];
const size_t step_size = static_cast<size_t>(x->dims()[0]);
const size_t num_classes = static_cast<size_t>(x->dims()[1]);
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* dy_data = dy->data<T>();
const T* x_data = x->data<T>();

@ -54,6 +54,7 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
math_library(beam_search DEPS math_function)
math_library(matrix_bit_code)
@ -68,6 +69,7 @@ cc_test(im2col_test SRCS im2col_test.cc DEPS im2col)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
cc_test(beam_search_test SRCS beam_search_test.cc DEPS beam_search)
if(WITH_GPU)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu.cc DEPS selected_rows_functor math_function)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,119 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/*
* This is an implementation of beam search.
*
* To explain the details, lets take machine translation task for example, in
* this task, one source sentence is translated to multiple target sentences,
* during this period, one sentence will be translated to multiple translation
* prefixes(target sentence that have not ended), in each time step a prefix
* will have some candidates, input the candidate ids and their corresponding
* scores (probabilities), it will sort and select the top beam_size candidates
* for each source sentence, and store the selected candidates's score and their
* corresponding ids to LoDTensors.
*
* A detailed example:
*
* Input
*
* ids:
* - LoD (should have 2 levels)
* - first level: [0, 1, 4]
* - second level: [0, 1, 2, 3, 4]
* - tensor's data:
* [[4, 2, 5]
* [2, 1, 3]
* [3, 5, 2]
* [8, 2, 1]]
*
* scores:
* - LoD same as `ids`
* - tensor's data
* [[0.5, 0.3, 0.2]
* [0.6, 0.3, 0.1]
* [0.9, 0.5, 0.1]
* [0.7, 0.5, 0.1]]
*
* The inputs means that there are 2 source sentences to translate, and the
* first source has 1 prefix, the second source has 2 prefix.
*
* Lets assume beam size is 2, and the beam search's output should be
* - LoD
* - first level: [0, 1, 2]
* - second level: [0, 2, 4]
* - id tensor's data
* [[4,
* 1,
* 3,
* 8]]
* - score tensor's data
* [[0.5,
* 0.3,
* 0.9,
* 0.7]]
*
* TODO all the prune operations should be in the beam search, so it is better
* to split the beam search algorithm into a sequence of smaller operators, and
* the prune operators can be inserted in this sequence.
*/
template <typename DeviceContext, typename T>
class BeamSearchFunctor {
public:
/*
* The main function of beam search.
*
* @selected_ids: a [None, 1]-shaped tensor with LoD.
* In a machine translation model, it might be the candidate term id sets,
* each set stored as a varience-length sequence.
* The format might be described with a two-level LoD
* - [[0 1],
* [0 1 2]]
* - [[]
* [0 1]]
* the first level of LoD tells that there are two source sentences. The
* second level describes the details of the candidate id set's offsets in
* the source sentences.
*
* @selected_scores: a LoD tensor with the same shape and LoD with
* selected_ids.
* It stores the corresponding scores of candidate ids in selected_ids.
*
* Return false if all the input tensor is empty, in machine translation task
* that means no candidates is provided, and the task will stop running.
*/
void operator()(const DeviceContext& context,
const framework::LoDTensor* pre_ids,
const framework::LoDTensor* pre_scores,
const framework::LoDTensor* ids,
const framework::LoDTensor* scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores, size_t level,
size_t beam_size, int end_id, bool is_accumulated);
};
} // namespace math
} // namespace operators
} // namespace paddle

@ -0,0 +1,141 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/beam_search.h"
#include <gtest/gtest.h>
#include <vector>
void PrepareCPUTensors(paddle::framework::LoDTensor* ids,
paddle::framework::LoDTensor* scores,
paddle::framework::LoDTensor* pre_ids,
paddle::framework::LoDTensor* pre_scores) {
// lod
paddle::framework::LoD lod;
std::vector<size_t> level0({0, 2, 4});
std::vector<size_t> level1({0, 1, 2, 3, 4});
lod.push_back(level0);
lod.push_back(level1);
ids->set_lod(lod);
scores->set_lod(lod);
auto dims = paddle::framework::make_ddim({4, 3});
ids->Resize(dims);
scores->Resize(dims);
paddle::platform::CPUPlace place;
auto* ids_data = ids->mutable_data<int64_t>(place);
auto* scores_data = scores->mutable_data<float>(place);
std::vector<int64_t> ids_vec_data({4, 2, 5, 2, 1, 3, 3, 5, 2, 8, 2, 1});
std::vector<float> scores_vec_data(
{0.6f, 0.3f, 0.5f, 0.2f, 0.3f, 0.1f, 0.9f, 0.5f, 0.1f, 0.7f, 0.5f, 0.1f});
CHECK_EQ(static_cast<size_t>(ids->numel()), ids_vec_data.size());
CHECK_EQ(static_cast<size_t>(ids->numel()), scores_vec_data.size());
for (int i = 0; i < ids->numel(); i++) {
ids_data[i] = ids_vec_data[i];
scores_data[i] = scores_vec_data[i];
}
// pre_ids
pre_ids->Resize(paddle::framework::make_ddim({4, 1}));
for (int i = 0; i < 4; i++) {
pre_ids->mutable_data<int64_t>(place)[i] = i + 1;
}
// pre_scores
pre_scores->Resize(paddle::framework::make_ddim({4, 1}));
for (int i = 0; i < 4; i++) {
pre_scores->mutable_data<float>(place)[i] = 0.1 * (i + 1);
}
}
template <typename DeviceContext, typename Place>
void TestBeamSearch() {
paddle::framework::LoDTensor ids;
paddle::framework::LoDTensor scores;
paddle::framework::LoDTensor pre_ids;
paddle::framework::LoDTensor pre_scores;
auto* place = new Place();
DeviceContext* context = new DeviceContext(*place);
if (paddle::platform::is_cpu_place(*place)) {
PrepareCPUTensors(&ids, &scores, &pre_ids, &pre_scores);
} else {
paddle::framework::LoDTensor cpu_ids;
paddle::framework::LoDTensor cpu_scores;
paddle::framework::LoDTensor cpu_pre_ids;
paddle::framework::LoDTensor cpu_pre_scores;
PrepareCPUTensors(&cpu_ids, &cpu_scores, &cpu_pre_ids, &cpu_pre_scores);
TensorCopySync(cpu_ids, *place, &ids);
TensorCopySync(cpu_scores, *place, &scores);
TensorCopySync(cpu_pre_ids, *place, &pre_ids);
TensorCopySync(cpu_pre_scores, *place, &pre_scores);
ids.set_lod(cpu_ids.lod());
scores.set_lod(cpu_scores.lod());
pre_ids.set_lod(cpu_pre_ids.lod());
pre_scores.set_lod(cpu_pre_scores.lod());
}
paddle::framework::LoDTensor selected_ids;
paddle::framework::LoDTensor selected_scores;
size_t level = 0;
size_t beam_size = 2;
int end_id = 0;
paddle::operators::math::BeamSearchFunctor<DeviceContext, float> beamsearch;
beamsearch(*context, &pre_ids, &pre_scores, &ids, &scores, &selected_ids,
&selected_scores, level, beam_size, end_id, true);
ASSERT_EQ(selected_ids.lod(), selected_scores.lod());
paddle::framework::LoDTensor cpu_selected_ids;
paddle::framework::LoDTensor cpu_selected_scores;
if (paddle::platform::is_cpu_place(*place)) {
cpu_selected_ids = selected_ids;
cpu_selected_scores = selected_scores;
} else {
TensorCopySync(selected_ids, paddle::platform::CPUPlace(),
&cpu_selected_ids);
TensorCopySync(selected_scores, paddle::platform::CPUPlace(),
&cpu_selected_scores);
cpu_selected_ids.set_lod(selected_ids.lod());
cpu_selected_scores.set_lod(selected_scores.lod());
}
std::vector<int64_t> expected_ids({4, 5, 3, 8});
std::vector<float> expected_scores({0.6f, 0.5f, 0.9f, 0.7f});
for (int i = 0; i < 4; i++) {
ASSERT_EQ(expected_ids[i], cpu_selected_ids.data<int64_t>()[i]);
ASSERT_EQ(expected_scores[i], cpu_selected_scores.data<float>()[i]);
}
delete place;
delete context;
}
TEST(BeamSearch, CPU) {
TestBeamSearch<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace>();
}
#ifdef PADDLE_WITH_CUDA
TEST(BeamSearch, GPU) {
TestBeamSearch<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
}
#endif

@ -354,7 +354,7 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
auto* out_data = output->value().data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
for (size_t j = 0; j < static_cast<size_t>(row_numel); ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
}
}

@ -301,7 +301,7 @@ TEST(selected_rows_functor, gpu_merge_add) {
auto* out_data = output_cpu.data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
for (size_t j = 0; j < static_cast<size_t>(row_numel); ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
}
}

@ -66,7 +66,7 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
cpu_in_grad.set_lod(in_grad.lod());
}
EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim);
EXPECT_EQ(in_grad.numel(), static_cast<int64_t>(lod[0].back() * second_dim));
EXPECT_EQ(in_grad.lod(), lod);
if (paddle::platform::is_cpu_place(*place)) {

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
@ -30,6 +31,34 @@ namespace platform {
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
inline static int RoundToPowerOfTwo(int dim) {
if (dim > 512) {
return 1024;
} else if (dim > 256) {
return 512;
} else if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
}
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \
case (dim): { \
constexpr auto kPowerOfTwoDim = (dim); \
__VA_ARGS__; \
} break
#define CUDA_LAUNCH_KERNEL_HELPER(...) \
CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__);
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
int delta, int width = 32) {

@ -221,13 +221,17 @@ size_t GpuMaxChunkSize() {
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind, cudaStream_t stream) {
PADDLE_ENFORCE(cudaMemcpyAsync(dst, src, count, kind, stream),
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync");
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync "
"(%p -> %p, length: %d)",
src, dst, static_cast<int>(count));
}
void GpuMemcpySync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind) {
PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind),
"cudaMemcpy failed in paddle::platform::GpuMemcpySync");
"cudaMemcpy failed in paddle::platform::GpuMemcpySync (%p -> "
"%p, length: %d)",
src, dst, static_cast<int>(count));
}
void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,

@ -3875,6 +3875,7 @@ def beam_search(pre_ids,
beam_size,
end_id,
level=0,
is_accumulated=True,
name=None):
"""
Beam search is a classical algorithm for selecting candidate words in a
@ -3887,14 +3888,17 @@ def beam_search(pre_ids,
selects the top-K candidate word ids of current step from :attr:`ids`
according to their :attr:`scores` for all source sentences, where K is
:attr:`beam_size` and :attr:`ids, scores` are predicted results from the
computation cell. Additionally, :attr:`pre_ids` and :attr:`pre_scores` are
the output of beam_search at previous step, they are needed for special use
to handle ended candidate translations.
Note that the :attr:`scores` passed in should be accumulated scores, and
length penalty should be done with extra operators before calculating the
accumulated scores if needed, also suggest finding top-K before it and
using the top-K candidates following.
computation cell. If :attr:`ids` is not set, it will be calculated out
according to :attr:`scores`. Additionally, :attr:`pre_ids` and
:attr:`pre_scores` are the output of beam_search at previous step, they
are needed for special use to handle ended candidate translations.
Note that if :attr:`is_accumulated` is :attr:`True`, the :attr:`scores`
passed in should be accumulated scores. Else, the :attr:`scores` are
considered as the straightforward scores and will be transformed to the
log field and accumulated the :attr:`pre_scores` in this operator.
Length penalty should be done with extra operators before calculating the
accumulated scores if needed.
Please see the following demo for a fully beam search usage example:
@ -3924,6 +3928,8 @@ def beam_search(pre_ids,
describes how these candidates belong to the prefix. The paths
linking prefixes and selected candidates are organized and reserved
in lod.
is_accumulated(bool, default True): Whether the input :attr:`score` is
accumulated scores.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
@ -3952,8 +3958,12 @@ def beam_search(pre_ids,
end_id=end_id)
"""
helper = LayerHelper('beam_search', **locals())
score_type = scores.dtype
id_type = ids.dtype
score_type = pre_scores.dtype
id_type = pre_ids.dtype
inputs = {"pre_ids": pre_ids, "pre_scores": pre_scores, "scores": scores}
if ids is not None:
inputs["ids"] = ids
selected_scores = helper.create_variable_for_type_inference(
dtype=score_type)
@ -3961,12 +3971,7 @@ def beam_search(pre_ids,
helper.append_op(
type='beam_search',
inputs={
'pre_ids': pre_ids,
'pre_scores': pre_scores,
'ids': ids,
'scores': scores,
},
inputs=inputs,
outputs={
'selected_ids': selected_ids,
'selected_scores': selected_scores,
@ -3976,6 +3981,7 @@ def beam_search(pre_ids,
'level': level,
'beam_size': beam_size,
'end_id': end_id,
'is_accumulated': is_accumulated,
})
return selected_ids, selected_scores

Loading…
Cancel
Save