|
|
|
@ -124,11 +124,12 @@ class SimpleCode {
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CustomCode {
|
|
|
|
|
public:
|
|
|
|
|
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
|
|
|
|
|
const int64_t* ids, int index) {
|
|
|
|
|
seq_len_ = ptable.dims()[1];
|
|
|
|
|
ptable_data_ = ptable.data<T>() + seq_len_ * index;
|
|
|
|
|
pcode_data_ = pcode.data<T>() + seq_len_ * index;
|
|
|
|
|
CustomCode(const framework::Tensor& path_table,
|
|
|
|
|
const framework::Tensor& path_code, const int64_t* ids,
|
|
|
|
|
int index) {
|
|
|
|
|
seq_len_ = path_table.dims()[1];
|
|
|
|
|
path_table_data_ = path_table.data<T>() + seq_len_ * index;
|
|
|
|
|
path_code_data_ = path_code.data<T>() + seq_len_ * index;
|
|
|
|
|
}
|
|
|
|
|
/**
|
|
|
|
|
* Here the id of root should be 1 rather than 0, thus the encoding of class c
|
|
|
|
@ -139,25 +140,25 @@ class CustomCode {
|
|
|
|
|
* Binary classification path is the suffixes of encoding, thus leave out the
|
|
|
|
|
* left most bit in calc_bit.
|
|
|
|
|
*/
|
|
|
|
|
size_t calc_index(int bit) const { return ptable_data_[bit]; }
|
|
|
|
|
bool calc_bit(int bit) const { return pcode_data_[bit]; }
|
|
|
|
|
size_t calc_index(int bit) const { return path_table_data_[bit]; }
|
|
|
|
|
bool calc_bit(int bit) const { return path_code_data_[bit]; }
|
|
|
|
|
|
|
|
|
|
// NOTE: this function is not thread-safe.
|
|
|
|
|
int get_length() const {
|
|
|
|
|
if (length_ < 0) {
|
|
|
|
|
auto len = seq_len_;
|
|
|
|
|
length_ =
|
|
|
|
|
static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len,
|
|
|
|
|
[](const T& val) { return val < 0; }) -
|
|
|
|
|
ptable_data_);
|
|
|
|
|
length_ = static_cast<int>(
|
|
|
|
|
std::find_if(path_table_data_, path_table_data_ + len,
|
|
|
|
|
[](const T& val) { return val < 0; }) -
|
|
|
|
|
path_table_data_);
|
|
|
|
|
}
|
|
|
|
|
return length_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int64_t seq_len_;
|
|
|
|
|
const T* ptable_data_;
|
|
|
|
|
const T* pcode_data_;
|
|
|
|
|
const T* path_table_data_;
|
|
|
|
|
const T* path_code_data_;
|
|
|
|
|
mutable int length_{-1};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -214,7 +215,7 @@ class MatrixBitCodeFunctor {
|
|
|
|
|
const framework::Tensor& path_code, const int64_t* ids)
|
|
|
|
|
: num_classes_(static_cast<size_t>(path_table.dims()[1])),
|
|
|
|
|
ids_(ids),
|
|
|
|
|
code_table_(CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
|
|
|
|
|
code_table_(CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
|
|
|
|
|
/* For j < code_length
|
|
|
|
|
tmat(i, j) += vec(0, index(i, j))
|
|
|
|
|
*/
|
|
|
|
|