|
|
|
@ -13,52 +13,57 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <queue>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
static constexpr int TopKPosPaddingId = -1;
|
|
|
|
|
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void get_topk_pos(const T* data, int length, int k, int* pos) {
|
|
|
|
|
size_t real_k = k < length ? k : length;
|
|
|
|
|
|
|
|
|
|
std::vector<T> v(data, data + length);
|
|
|
|
|
|
|
|
|
|
std::vector<int> topk_pos;
|
|
|
|
|
T min_val = std::numeric_limits<T>::lowest();
|
|
|
|
|
while (topk_pos.size() < real_k) {
|
|
|
|
|
T max_val = min_val;
|
|
|
|
|
int max_pos = -1;
|
|
|
|
|
for (int i = 0; i < length; ++i) {
|
|
|
|
|
if (v[i] > max_val) {
|
|
|
|
|
max_pos = i;
|
|
|
|
|
max_val = v[i];
|
|
|
|
|
static void get_topk_pos(const T* data, int length, int k, int* pos) {
|
|
|
|
|
VLOG(3) << "length: " << length << " , k : " << k;
|
|
|
|
|
|
|
|
|
|
std::priority_queue<std::pair<T, int>, std::vector<std::pair<T, int>>,
|
|
|
|
|
std::greater<std::pair<T, int>>>
|
|
|
|
|
topk_queue;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < length; ++i) {
|
|
|
|
|
T elem = data[i];
|
|
|
|
|
if (topk_queue.size() < static_cast<size_t>(k)) {
|
|
|
|
|
topk_queue.emplace(elem, i);
|
|
|
|
|
} else {
|
|
|
|
|
if (elem >= topk_queue.top().first) {
|
|
|
|
|
// replace top node if found a bigger value
|
|
|
|
|
topk_queue.pop();
|
|
|
|
|
topk_queue.emplace(elem, i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert(max_pos >= 0);
|
|
|
|
|
|
|
|
|
|
topk_pos.push_back(max_pos);
|
|
|
|
|
v[max_pos] = min_val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert(topk_pos.size() > 0);
|
|
|
|
|
while (topk_pos.size() < (size_t)k) {
|
|
|
|
|
topk_pos.push_back(-1);
|
|
|
|
|
// reversely assign value
|
|
|
|
|
int real_k = topk_queue.size();
|
|
|
|
|
for (int i = real_k - 1; i >= 0; --i) {
|
|
|
|
|
pos[i] = topk_queue.top().second;
|
|
|
|
|
topk_queue.pop();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < topk_pos.size(); ++i) {
|
|
|
|
|
pos[i] = topk_pos[i];
|
|
|
|
|
// if length of data is less than k, fill TopKPosPaddingId at the end of pos.
|
|
|
|
|
for (int i = real_k; i < k; ++i) {
|
|
|
|
|
pos[i] = TopKPosPaddingId;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
} // namespace details
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
|
|
|
|
@ -70,20 +75,29 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* out = context.Output<LoDTensor>("Out");
|
|
|
|
|
auto* pos = context.Output<Tensor>("pos");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in->lod().empty(), false,
|
|
|
|
|
"Input(X) Tensor of SequenceTopkAvgPoolingOp does not "
|
|
|
|
|
"contain LoD information.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(row->lod().empty(), false,
|
|
|
|
|
"Input(ROW) Tensor of SequenceTopkAvgPoolingOp does not "
|
|
|
|
|
"contain LoD information.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(col->lod().empty(), false,
|
|
|
|
|
"Input(COLUMN) Tensor of SequenceTopkAvgPoolingOp does "
|
|
|
|
|
"not contain LoD information.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in->lod().empty(), false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) Tensor of SequenceTopkAvgPoolingOp does not "
|
|
|
|
|
"contain LoD information."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
row->lod().empty(), false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(ROW) Tensor of SequenceTopkAvgPoolingOp does not "
|
|
|
|
|
"contain LoD information."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
col->lod().empty(), false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(COLUMN) Tensor of SequenceTopkAvgPoolingOp does "
|
|
|
|
|
"not contain LoD information."));
|
|
|
|
|
|
|
|
|
|
auto channel_num = context.Attr<int>("channel_num");
|
|
|
|
|
auto topks = context.Attr<std::vector<int>>("topks");
|
|
|
|
|
auto k_num = topks.size();
|
|
|
|
|
auto max_k = topks[topks.size() - 1];
|
|
|
|
|
PADDLE_ENFORCE_GE(max_k, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expected max_k >= 0, but received %d.", max_k));
|
|
|
|
|
std::vector<int> vec_pos_shape;
|
|
|
|
|
auto in_lod = in->lod()[0];
|
|
|
|
|
|
|
|
|
@ -116,7 +130,10 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
|
|
|
|
|
int row_size = row_lod[i + 1] - row_lod[i];
|
|
|
|
|
int col_size = col_lod[i + 1] - col_lod[i];
|
|
|
|
|
PADDLE_ENFORCE_EQ(total_size, channel_num * row_size * col_size,
|
|
|
|
|
"size wrong in sequence_topk_avg_pooling_op!");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Expected total_size == channel_num * row_size * "
|
|
|
|
|
"col_size, but got %d != %d.",
|
|
|
|
|
total_size, channel_num * row_size * col_size));
|
|
|
|
|
|
|
|
|
|
int feature_num = row_size * col_size;
|
|
|
|
|
for (int j = 0; j < channel_num; ++j) {
|
|
|
|
@ -130,14 +147,14 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto out_slice_data = dout_data + row_lod[i] * channel_num * k_num +
|
|
|
|
|
r * channel_num * k_num + j * k_num;
|
|
|
|
|
|
|
|
|
|
get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data);
|
|
|
|
|
if (pos_slice_data[0] == -1) {
|
|
|
|
|
details::get_topk_pos<T>(row_data, col_size, max_k, pos_slice_data);
|
|
|
|
|
if (pos_slice_data[0] == TopKPosPaddingId) {
|
|
|
|
|
sum_data[0] = 0.0;
|
|
|
|
|
} else {
|
|
|
|
|
sum_data[0] = row_data[pos_slice_data[0]];
|
|
|
|
|
}
|
|
|
|
|
for (int k = 1; k < max_k; ++k) {
|
|
|
|
|
if (pos_slice_data[k] == -1) {
|
|
|
|
|
if (pos_slice_data[k] == TopKPosPaddingId) {
|
|
|
|
|
sum_data[k] = sum_data[k - 1];
|
|
|
|
|
} else {
|
|
|
|
|
sum_data[k] = sum_data[k - 1] + row_data[pos_slice_data[k]];
|
|
|
|
@ -206,7 +223,7 @@ class SequenceTopkAvgPoolingGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
for (size_t m = 0; m < k_num; ++m) {
|
|
|
|
|
for (int k = 0; k < topks[m]; ++k) {
|
|
|
|
|
if (pos_slice_data[k] == -1) {
|
|
|
|
|
if (pos_slice_data[k] == TopKPosPaddingId) {
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
in_slice_data[pos_slice_data[k]] += row_data[m] / topks[m];
|
|
|
|
|