|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -109,7 +110,7 @@ class Code {
|
|
|
|
|
// set a CodeTable interface to create multiple code table
|
|
|
|
|
class CodeTable {
|
|
|
|
|
public:
|
|
|
|
|
virtual std::unique_ptr<Code> get_code(int64_t code) const = 0;
|
|
|
|
|
virtual Code* get_code(int64_t code) const = 0;
|
|
|
|
|
virtual size_t size() const = 0;
|
|
|
|
|
virtual int get_max_code_length() const = 0;
|
|
|
|
|
virtual ~CodeTable() {}
|
|
|
|
@ -180,14 +181,23 @@ class SimpleCodeTable : public CodeTable {
|
|
|
|
|
public:
|
|
|
|
|
SimpleCodeTable(size_t num_classes, const int64_t* ids)
|
|
|
|
|
: num_classes_(num_classes), ids_(ids) {}
|
|
|
|
|
std::unique_ptr<Code> get_code(int64_t code) const {
|
|
|
|
|
std::unique_ptr<Code> coder(new SimpleCode(code, num_classes_, ids_));
|
|
|
|
|
return coder;
|
|
|
|
|
|
|
|
|
|
Code* get_code(int64_t code) const {
|
|
|
|
|
auto it = codes_.find(code);
|
|
|
|
|
if (it != codes_.end()) {
|
|
|
|
|
return it->second.get();
|
|
|
|
|
}
|
|
|
|
|
auto* result = new SimpleCode(code, num_classes_, ids_);
|
|
|
|
|
codes_.emplace(code, std::unique_ptr<Code>(result));
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t size() const { return num_classes_; }
|
|
|
|
|
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
mutable std::map<int64_t, std::unique_ptr<Code>> codes_;
|
|
|
|
|
|
|
|
|
|
size_t num_classes_;
|
|
|
|
|
const int64_t* ids_;
|
|
|
|
|
};
|
|
|
|
@ -199,9 +209,14 @@ class CustomCodeTable : public CodeTable {
|
|
|
|
|
const framework::Tensor& pcode, const int64_t* ids)
|
|
|
|
|
: ptable_(ptable), pcode_(pcode), ids_(ids) {}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Code> get_code(int64_t code) const {
|
|
|
|
|
std::unique_ptr<Code> coder(new CustomCode<T>(ptable_, pcode_, ids_, code));
|
|
|
|
|
return coder;
|
|
|
|
|
Code* get_code(int64_t code) const {
|
|
|
|
|
auto it = codes_.find(code);
|
|
|
|
|
if (it != codes_.end()) {
|
|
|
|
|
return it->second.get();
|
|
|
|
|
}
|
|
|
|
|
auto* result = new CustomCode<T>(ptable_, pcode_, ids_, code);
|
|
|
|
|
codes_.emplace(code, std::unique_ptr<Code>(result));
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
|
|
|
|
@ -210,6 +225,7 @@ class CustomCodeTable : public CodeTable {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
mutable std::unordered_map<int64_t, std::unique_ptr<Code>> codes_;
|
|
|
|
|
const framework::Tensor& ptable_;
|
|
|
|
|
const framework::Tensor& pcode_;
|
|
|
|
|
const int64_t* ids_;
|
|
|
|
|