|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -22,6 +23,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/tensor.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/platform/device_context.h"
|
|
|
|
|
#include "paddle/fluid/platform/variant.h"
|
|
|
|
|
|
|
|
|
|
#if defined(_WIN32)
|
|
|
|
|
#include <intrin.h>
|
|
|
|
@ -98,24 +100,7 @@ inline int clz(const T& value) {
|
|
|
|
|
|
|
|
|
|
// Returns the bit width of size_t minus clz(x).
// NOTE(review): presumably the 1-based position of the highest set bit of x;
// confirm against the definition of clz, which is not visible here.
inline size_t FindLastSet(size_t x) {
  const size_t total_bits = sizeof(size_t) * 8;
  const auto leading_zeros = clz(x);
  return total_bits - leading_zeros;
}
|
|
|
|
|
#endif // !_WIN32
|
|
|
|
|
// Abstract interface for a single bit code. Concrete code types derive from
// this class and implement every query below.
class Code {
 public:
  virtual ~Code() = default;

  // Returns the weight index stored for position `bit` of the code.
  virtual size_t calc_index(int bit) const = 0;

  // Returns the binary label stored for position `bit` of the code.
  virtual bool calc_bit(int bit) const = 0;

  // Returns the number of valid positions in the code.
  virtual int get_length() const = 0;
};
|
|
|
|
|
// set a CodeTable interface to create multiple code table
|
|
|
|
|
class CodeTable {
|
|
|
|
|
public:
|
|
|
|
|
virtual std::unique_ptr<Code> get_code(int64_t code) const = 0;
|
|
|
|
|
virtual size_t size() const = 0;
|
|
|
|
|
virtual int get_max_code_length() const = 0;
|
|
|
|
|
virtual ~CodeTable() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SimpleCode : public Code {
|
|
|
|
|
class SimpleCode {
|
|
|
|
|
public:
|
|
|
|
|
SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
|
|
|
|
|
: c_(static_cast<size_t>(ids[code]) + num_classes) {}
|
|
|
|
@ -137,16 +122,16 @@ class SimpleCode : public Code {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CustomCode : public Code {
|
|
|
|
|
class CustomCode {
|
|
|
|
|
public:
|
|
|
|
|
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
|
|
|
|
|
const int64_t* ids, int index)
|
|
|
|
|
: ids_(ids), index_(index) {
|
|
|
|
|
ptable_ = ptable.Slice(index, index + 1);
|
|
|
|
|
pcode_ = pcode.Slice(index, index + 1);
|
|
|
|
|
const int64_t* ids, int index) {
|
|
|
|
|
seq_len_ = ptable.dims()[1];
|
|
|
|
|
ptable_data_ = ptable.data<T>() + seq_len_ * index;
|
|
|
|
|
pcode_data_ = pcode.data<T>() + seq_len_ * index;
|
|
|
|
|
}
|
|
|
|
|
/**
|
|
|
|
|
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
|
|
|
|
|
* Here the id of root should be 1 rather than 0, thus the encoding of class c
|
|
|
|
|
* is `c + num_classes` and all siblings can get the same weight index using
|
|
|
|
|
* prefixes.
|
|
|
|
|
* Weight index is the prefixes of encoding, thus leave out the right most
|
|
|
|
@ -154,36 +139,37 @@ class CustomCode : public Code {
|
|
|
|
|
* Binary classification path is the suffixes of encoding, thus leave out the
|
|
|
|
|
* left most bit in calc_bit.
|
|
|
|
|
*/
|
|
|
|
|
size_t calc_index(int bit) const { return ptable_.data<T>()[bit]; }
|
|
|
|
|
bool calc_bit(int bit) const { return pcode_.data<T>()[bit]; }
|
|
|
|
|
int get_length() const {
|
|
|
|
|
int length = 0;
|
|
|
|
|
size_t calc_index(int bit) const { return ptable_data_[bit]; }
|
|
|
|
|
bool calc_bit(int bit) const { return pcode_data_[bit]; }
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < static_cast<int>(ptable_.dims()[1]); i++) {
|
|
|
|
|
if (ptable_.data<T>()[i] >= 0) {
|
|
|
|
|
length++;
|
|
|
|
|
} else {
|
|
|
|
|
return length;
|
|
|
|
|
}
|
|
|
|
|
// NOTE: this function is not thread-safe.
|
|
|
|
|
int get_length() const {
|
|
|
|
|
if (length_ < 0) {
|
|
|
|
|
auto len = seq_len_;
|
|
|
|
|
length_ =
|
|
|
|
|
static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len,
|
|
|
|
|
[](const T& val) { return val < 0; }) -
|
|
|
|
|
ptable_data_);
|
|
|
|
|
}
|
|
|
|
|
return length;
|
|
|
|
|
return length_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
framework::Tensor ptable_;
|
|
|
|
|
framework::Tensor pcode_;
|
|
|
|
|
const int64_t* ids_;
|
|
|
|
|
const int index_;
|
|
|
|
|
int64_t seq_len_;
|
|
|
|
|
const T* ptable_data_;
|
|
|
|
|
const T* pcode_data_;
|
|
|
|
|
mutable int length_{-1};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SimpleCodeTable : public CodeTable {
|
|
|
|
|
class SimpleCodeTable {
|
|
|
|
|
public:
|
|
|
|
|
SimpleCodeTable(size_t num_classes, const int64_t* ids)
|
|
|
|
|
: num_classes_(num_classes), ids_(ids) {}
|
|
|
|
|
std::unique_ptr<Code> get_code(int64_t code) const {
|
|
|
|
|
std::unique_ptr<Code> coder(new SimpleCode(code, num_classes_, ids_));
|
|
|
|
|
return coder;
|
|
|
|
|
|
|
|
|
|
SimpleCode get_code(int64_t code) const {
|
|
|
|
|
return SimpleCode(code, num_classes_, ids_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t size() const { return num_classes_; }
|
|
|
|
|
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
|
|
|
|
|
|
|
|
|
@ -193,15 +179,14 @@ class SimpleCodeTable : public CodeTable {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CustomCodeTable : public CodeTable {
|
|
|
|
|
class CustomCodeTable {
|
|
|
|
|
public:
|
|
|
|
|
CustomCodeTable(const framework::Tensor& ptable,
|
|
|
|
|
const framework::Tensor& pcode, const int64_t* ids)
|
|
|
|
|
: ptable_(ptable), pcode_(pcode), ids_(ids) {}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Code> get_code(int64_t code) const {
|
|
|
|
|
std::unique_ptr<Code> coder(new CustomCode<T>(ptable_, pcode_, ids_, code));
|
|
|
|
|
return coder;
|
|
|
|
|
CustomCode<T> get_code(int64_t code) const {
|
|
|
|
|
return CustomCode<T>(ptable_, pcode_, ids_, code);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
|
|
|
|
@ -215,19 +200,21 @@ class CustomCodeTable : public CodeTable {
|
|
|
|
|
const int64_t* ids_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// A code table is exactly one of two alternatives: a SimpleCodeTable or a
// CustomCodeTable<int64_t>, stored by value in a boost::variant.
using CodeTable = boost::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MatrixBitCodeFunctor {
|
|
|
|
|
public:
|
|
|
|
|
MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
|
|
|
|
|
: num_classes_(num_classes),
|
|
|
|
|
ids_(ids),
|
|
|
|
|
code_table_(new SimpleCodeTable(num_classes, ids)) {}
|
|
|
|
|
code_table_(SimpleCodeTable(num_classes, ids)) {}
|
|
|
|
|
|
|
|
|
|
MatrixBitCodeFunctor(const framework::Tensor& ptable,
|
|
|
|
|
const framework::Tensor& pcode, const int64_t* ids)
|
|
|
|
|
: num_classes_(static_cast<size_t>(ptable.dims()[1])),
|
|
|
|
|
ids_(ids),
|
|
|
|
|
code_table_(new CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
|
|
|
|
|
code_table_(CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
|
|
|
|
|
/* For j < code_length
|
|
|
|
|
tmat(i, j) += vec(0, index(i, j))
|
|
|
|
|
*/
|
|
|
|
@ -277,7 +264,7 @@ class MatrixBitCodeFunctor {
|
|
|
|
|
|
|
|
|
|
size_t num_classes_;
|
|
|
|
|
const int64_t* ids_;
|
|
|
|
|
std::unique_ptr<CodeTable> code_table_;
|
|
|
|
|
CodeTable code_table_;
|
|
|
|
|
};
|
|
|
|
|
} // namespace math
|
|
|
|
|
} // namespace operators
|
|
|
|
|