change data type of beam_search op (#7374)

detection_output_fixbug
Qiao Longfei 7 years ago committed by GitHub
parent 91f80f792d
commit efe06caa3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -39,7 +39,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
std::map<size_t /*offset*/, std::vector<Item>> hash;
framework::LoD new_lod;
auto *ids_data = selected_ids->mutable_data<int>(platform::CPUPlace());
auto *ids_data = selected_ids->mutable_data<int64_t>(platform::CPUPlace());
auto *scores_data =
selected_scores->mutable_data<float>(platform::CPUPlace());
@ -66,7 +66,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) {
auto *pre_ids_data = pre_ids.data<int>();
auto *pre_ids_data = pre_ids.data<int64_t>();
for (size_t offset = 0; offset < items->size(); offset++) {
auto prefix_id = pre_ids_data[offset];
@ -127,7 +127,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
auto abs_lod = framework::ToAbsOffset(ids.lod());
PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL);
auto *ids_data = ids.data<int>();
auto *ids_data = ids.data<int64_t>();
auto *scores_data = scores.data<float>();
size_t instance_dim = 1;

@ -37,13 +37,13 @@ class BeamSearchOpTester(unittest.TestCase):
print 'lod', selected_ids.lod()
def _create_pre_ids(self):
np_data = np.array([[1, 2, 3, 4]], dtype='int32')
np_data = np.array([[1, 2, 3, 4]], dtype='int64')
tensor = create_tensor(self.scope, "pre_ids", np_data)
def _create_ids(self):
self.lod = [[0, 1, 4], [0, 1, 2, 3, 4]]
np_data = np.array(
[[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int32')
[[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int64')
tensor = create_tensor(self.scope, "ids", np_data)
tensor.set_lod(self.lod)

Loading…
Cancel
Save