feature/beam search op (#5052)
parent 7c3ec22081
commit 09866fb75f

@@ -0,0 +1,185 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/beam_search_op.h"

#include <algorithm>
#include <map>
#include <numeric>

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
                            framework::LoDTensor *selected_ids,
                            framework::LoDTensor *selected_scores) {
  auto items = SelectTopBeamSizeItems();
  auto selected_items = ToMap(items);
  PruneEndidCandidates(pre_ids, &selected_items);
  // calculate the output tensor's height
  size_t num_instances = std::accumulate(
      std::begin(items), std::end(items), 0,
      [](size_t a, std::vector<Item> &b) { return a + b.size(); });
  // the output tensor shape should be [num_instances, 1]
  auto dims = framework::make_ddim(
      std::vector<int64_t>({static_cast<int>(num_instances), 1}));
  selected_ids->Resize(dims);
  selected_scores->Resize(dims);

  std::map<size_t /*offset*/, std::vector<Item>> hash;
  framework::LoD new_lod;
  auto *ids_data = selected_ids->mutable_data<int>(platform::CPUPlace());
  auto *scores_data =
      selected_scores->mutable_data<float>(platform::CPUPlace());

  // fill in data
  std::vector<size_t> low_level;
  size_t low_offset = 0;
  for (auto &items : selected_items) {
    low_level.push_back(low_offset);
    for (auto &item : items) {
      ids_data[low_offset] = item.id;
      scores_data[low_offset] = item.score;
      low_offset++;
    }
  }
  // append the final offset so the low level is a complete offset sequence
  low_level.push_back(low_offset);

  // fill lod
  auto abs_lod = framework::ToAbsOffset(ids_->lod());
  auto &high_level = abs_lod[lod_level_];
  framework::LoD lod(2);
  lod[0].assign(high_level.begin(), high_level.end());
  lod[1].assign(low_level.begin(), low_level.end());
  selected_ids->set_lod(lod);
  selected_scores->set_lod(lod);
}

void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
                                      std::vector<std::vector<Item>> *items) {
  auto *pre_ids_data = pre_ids.data<int>();

  for (size_t offset = 0; offset < items->size(); offset++) {
    auto prefix_id = pre_ids_data[offset];
    if (prefix_id == end_id_) {
      items->at(offset).clear();
    }
  }
}

std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
    const std::vector<std::vector<Item>> &items) {
  std::vector<std::vector<Item>> result;
  for (auto &entries : items) {
    for (const auto &item : entries) {
      if (item.offset >= result.size()) {
        result.resize(item.offset + 1);
      }
      result[item.offset].push_back(item);
    }
  }
  return result;
}

std::vector<std::vector<BeamSearch::Item>>
BeamSearch::SelectTopBeamSizeItems() {
  std::vector<std::vector<Item>> result;
  std::vector<Item> items;
  // for each source sentence, select the top beam_size items across all
  // candidate sets.
  while (NextItemSet(&items)) {
    std::nth_element(std::begin(items),
                     std::begin(items) + std::min(beam_size_, items.size()),
                     std::end(items), [](const Item &a, const Item &b) {
                       // TODO(superjom) make score's comparison customizable.
                       // partial sort in descending order
                       return a.score > b.score;
                     });
    // prune the top beam_size items.
    if (items.size() > beam_size_) {
      items.resize(beam_size_);
    }
    result.emplace_back(items);
  }
  return result;
}

// the candidates of a source
bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
  if (sent_offset_ >= ids_->NumElements(lod_level_)) {
    return false;
  }
  // find the current candidates
  auto ids = *ids_;
  auto scores = *scores_;

  auto source_abs_two_level_lod = framework::SliceInLevel(
      ids.lod(), lod_level_, sent_offset_, sent_offset_ + 1);
  source_abs_two_level_lod = framework::ToAbsOffset(source_abs_two_level_lod);
  auto abs_lod = framework::ToAbsOffset(ids.lod());
  PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL);

  auto *ids_data = ids.data<int>();
  auto *scores_data = scores.data<float>();

  size_t instance_dim = 1;
  for (int i = 1; i < ids.dims().size(); i++) {
    instance_dim *= ids.dims()[i];
  }

  items->clear();
  items->reserve(framework::product(ids.dims()));
  for (size_t offset = abs_lod[lod_level_][sent_offset_];
       offset < abs_lod[lod_level_][sent_offset_ + 1]; offset++) {
    for (size_t d = 0; d < instance_dim; d++) {
      const size_t dim_offset = offset * instance_dim + d;
      items->emplace_back(offset, ids_data[dim_offset],
                          scores_data[dim_offset]);
    }
  }

  sent_offset_++;
  return true;
}

class BeamSearchProtoAndCheckerMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  BeamSearchProtoAndCheckerMaker(framework::OpProto *proto,
                                 framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // inputs and outputs stored in proto
    AddInput("pre_ids", "the ids selected in the previous step");
    AddInput("ids", "a LoDTensor of shape [None, k]");
    AddInput("scores",
             "a LoDTensor that has the same shape and LoD as `ids`");
    AddOutput("selected_ids",
              "a LoDTensor that stores the IDs selected by beam search");
    AddOutput(
        "selected_scores",
        "a LoDTensor that has the same shape and LoD as `selected_ids`");

    // Attributes stored in AttributeMap
    AddAttr<int>("level", "the level of LoDTensor");
    AddAttr<int>("beam_size", "beam size for beam search");
    AddAttr<int>("end_id",
                 "the token id which indicates the end of a sequence");

    AddComment(
        "This is a beam search operator that helps to generate sequences.");
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(beam_search, paddle::operators::BeamSearchOp,
                             paddle::operators::BeamSearchProtoAndCheckerMaker);

@@ -0,0 +1,226 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_TESTING
#include "gtest/gtest.h"
#endif

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

/*
 * This is an implementation of beam search.
 *
 * To explain the details, let's take the machine translation task as an
 * example. In this task, one source sentence is translated into multiple
 * target sentences, so during decoding one source sentence corresponds to
 * multiple translation prefixes (target sentences that have not ended). In
 * each time step a prefix has some candidates; given the candidate ids and
 * their corresponding scores (probabilities), this class sorts them and
 * selects the top beam_size candidates for each source sentence, then stores
 * the selected candidates' scores and ids in LoDTensors.
 *
 * A detailed example:
 *
 * Input
 *
 * ids:
 *   LoD (should have 2 levels)
 *     first level: [0, 1, 4]
 *     second level: [0, 1, 2, 3, 4]
 *
 *   tensor's data
 *   [
 *     [4, 2, 5]
 *     [2, 1, 3]
 *     [3, 5, 2]
 *     [8, 2, 1]
 *   ]
 *
 * scores:
 *   LoD same as `ids`
 *   tensor's data
 *   [
 *     [0.5, 0.3, 0.2]
 *     [0.6, 0.3, 0.1]
 *     [0.9, 0.5, 0.1]
 *     [0.7, 0.5, 0.1]
 *   ]
 *
 * The inputs mean that there are 2 source sentences to translate; the first
 * source has 1 prefix and the second source has 3 prefixes.
 *
 * Let's assume the beam size is 2; then the beam search's output should be
 *   LoD
 *     first level:
 *       [0, 1, 2]
 *     second level:
 *       [0, 2, 4]
 *
 *   tensor's data
 *   [[
 *     0.5,
 *     0.3,
 *     0.9,
 *     0.7
 *   ]]
 *
 * TODO: all the prune operations should be in the beam search, so it is better
 * to split the beam search algorithm into a sequence of smaller operators, and
 * the prune operators can be inserted into this sequence.
 */
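// Reading the example above, the output scores [0.5, 0.3, 0.9, 0.7] would
// correspond to candidate ids [4, 2, 3, 8]: ids 4 and 2 are the two best
// candidates of the first source's single prefix, while 3 and 8 carry the
// best scores (0.9 and 0.7) across the second source's three prefixes.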
class BeamSearch {
 public:
  // TODO(superjom) make type customizable
  using id_t = size_t;
  using score_t = float;
  /*
   * Pass in the arguments needed by this class.
   */
  BeamSearch(const framework::LoDTensor& ids,
             const framework::LoDTensor& scores, size_t level, size_t beam_size,
             int end_id)
      : beam_size_(beam_size),
        ids_(&ids),
        scores_(&scores),
        lod_level_(level),
        end_id_(end_id) {}

  /*
   * The main function of beam search.
   *
   * @selected_ids: a [None, 1]-shaped tensor with LoD.
   *   In a machine translation model, it might be the candidate term id sets,
   *   each set stored as a variable-length sequence.
   *   The format might be described with a two-level LoD
   *    - [[0 1]
   *    - [0 1 2]]
   *    - [[]
   *    - [0 1]]
   *   The first level of LoD tells that there are two source sentences. The
   *   second level describes the details of the candidate id sets' offsets in
   *   the source sentences.
   *
   * @selected_scores: a LoD tensor with the same shape and LoD as
   *   selected_ids. It stores the corresponding scores of the candidate ids
   *   in selected_ids.
   *
   * Return false if all the input tensors are empty; in a machine translation
   * task that means no candidates are provided, and the task will stop
   * running.
   */
  void operator()(const framework::LoDTensor& pre_ids,
                  framework::LoDTensor* selected_ids,
                  framework::LoDTensor* selected_scores);

 protected:
  /*
   * The basic item used for sorting.
   */
  struct Item {
    Item() {}
    Item(size_t offset, size_t id, float score)
        : offset(offset), id(id), score(score) {}
    // the offset in LoD level lod_level_ + 1
    size_t offset;
    // the candidate id
    id_t id;
    // the corresponding score
    score_t score;
  };

  void PruneEndidCandidates(const framework::LoDTensor& pre_ids,
                            std::vector<std::vector<Item>>* items);

  /*
   * Transform the items into a map whose key is the offset and whose value is
   * the items.
   * NOTE low performance.
   */
  std::vector<std::vector<Item>> ToMap(
      const std::vector<std::vector<Item>>& inputs);

  /*
   * For each source, select the top beam_size records.
   */
  std::vector<std::vector<Item>> SelectTopBeamSizeItems();

  /*
   * Get the items of the next source sequence; return false if there are no
   * remaining items.
   */
  bool NextItemSet(std::vector<Item>* items);

 private:
  size_t beam_size_;
  const framework::LoDTensor* ids_;
  const framework::LoDTensor* scores_;
  size_t lod_level_{0};
  size_t sent_offset_{0};
  int end_id_{0};
};

class BeamSearchOp : public framework::OperatorBase {
 public:
  BeamSearchOp(const std::string& type,
               const framework::VariableNameMap& inputs,
               const framework::VariableNameMap& outputs,
               const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  BeamSearchOp(const BeamSearchOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    PADDLE_THROW("Not Implemented");
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    LOG(INFO) << "run beam search op";
    auto ids_var = scope.FindVar(Input("ids"));
    auto scores_var = scope.FindVar(Input("scores"));
    auto pre_ids_var = scope.FindVar(Input("pre_ids"));
    PADDLE_ENFORCE_NOT_NULL(ids_var);
    PADDLE_ENFORCE_NOT_NULL(scores_var);
    PADDLE_ENFORCE_NOT_NULL(pre_ids_var);

    auto& ids = ids_var->Get<framework::LoDTensor>();
    auto& scores = scores_var->Get<framework::LoDTensor>();
    auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
    size_t level = Attr<int>("level");
    size_t beam_size = Attr<int>("beam_size");
    int end_id = Attr<int>("end_id");
    LOG(INFO) << "init beam search";
    BeamSearch alg(ids, scores, level, beam_size, end_id);

    LOG(INFO) << "after beam search";
    auto selected_ids_var = scope.FindVar(Output("selected_ids"));
    auto selected_scores_var = scope.FindVar(Output("selected_scores"));
    PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
    PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
    auto& selected_ids_tensor =
        *selected_ids_var->GetMutable<framework::LoDTensor>();
    auto& selected_scores_tensor =
        *selected_scores_var->GetMutable<framework::LoDTensor>();
    LOG(INFO) << "run beam search";
    alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
    LOG(INFO) << "finish beam search";
  }
};

}  // namespace operators
}  // namespace paddle

@@ -0,0 +1,65 @@
from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
import unittest
import numpy as np


def create_tensor(scope, name, np_data):
    tensor = scope.var(name).get_tensor()
    tensor.set(np_data, core.CPUPlace())
    return tensor


class BeamSearchOpTester(unittest.TestCase):
    def setUp(self):
        self.scope = core.Scope()
        self.ctx = core.DeviceContext.create(core.CPUPlace())
        self._create_ids()
        self._create_scores()
        self._create_pre_ids()
        self.scope.var('selected_ids')
        self.scope.var('selected_scores')

    def test_run(self):
        op = Operator(
            'beam_search',
            pre_ids='pre_ids',
            ids='ids',
            scores='scores',
            selected_ids='selected_ids',
            selected_scores='selected_scores',
            level=0,
            beam_size=2,
            end_id=0, )
        op.run(self.scope, self.ctx)
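        # For the ids/scores built in setUp and beam_size=2, the op is expected
        # to keep the two best candidates of each source sentence, i.e. ids
        # [4, 2, 3, 8] with scores [0.5, 0.3, 0.9, 0.7]; for now the result is
        # only printed below, not asserted.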
        selected_ids = self.scope.find_var("selected_ids").get_tensor()
        print 'selected_ids', np.array(selected_ids)
        print 'lod', selected_ids.lod()

    def _create_pre_ids(self):
        np_data = np.array([[1, 2, 3, 4]], dtype='int32')
        tensor = create_tensor(self.scope, "pre_ids", np_data)

    def _create_ids(self):
        self.lod = [[0, 1, 4], [0, 1, 2, 3, 4]]
        np_data = np.array(
            [[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int32')
        tensor = create_tensor(self.scope, "ids", np_data)
        tensor.set_lod(self.lod)

    def _create_scores(self):
        np_data = np.array(
            [
                [0.5, 0.3, 0.2],
                [0.6, 0.3, 0.1],
                [0.9, 0.5, 0.1],
                [0.7, 0.5, 0.1],
            ],
            dtype='float32')
        tensor = create_tensor(self.scope, "scores", np_data)
        tensor.set_lod(self.lod)
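

# A minimal reference sketch of the selection step, assuming plain rows of
# candidate ids and scores plus a first-level lod over the prefixes of each
# source; the helper name `naive_top_beam` is illustrative only and not part
# of the operator's API.
def naive_top_beam(ids, scores, source_lod, beam_size):
    """Return (ids, scores) of the top beam_size candidates per source."""
    picked_ids, picked_scores = [], []
    for start, end in zip(source_lod[:-1], source_lod[1:]):
        # flatten all candidates of the prefixes that belong to one source
        flat = [(s, i) for row in range(start, end)
                for i, s in zip(ids[row], scores[row])]
        flat.sort(reverse=True)  # sort by score, descending
        picked_scores += [s for s, _ in flat[:beam_size]]
        picked_ids += [i for _, i in flat[:beam_size]]
    return picked_ids, picked_scores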


if __name__ == '__main__':
    unittest.main()