|
|
|
@ -50,7 +50,7 @@ namespace math {
|
|
|
|
|
for j < codeLength:
|
|
|
|
|
op(a(i, j), b(0, index(i, j)))
|
|
|
|
|
*/
|
|
|
|
|
template <class CodeTable, class Op, typename T, typename Place>
|
|
|
|
|
template <class CodeTable, class Op, typename T>
|
|
|
|
|
static void AddByBitCodeT(Op op, CodeTable code_table,
|
|
|
|
|
const framework::Tensor& codes, framework::Tensor& a,
|
|
|
|
|
framework::Tensor& b) {
|
|
|
|
@ -72,11 +72,11 @@ static void AddByBitCodeT(Op op, CodeTable code_table,
|
|
|
|
|
/* For j < codeLength:
|
|
|
|
|
a(i, j) += b(0, index(i, j))
|
|
|
|
|
*/
|
|
|
|
|
template <typename T, typename Place>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
|
|
|
|
|
framework::Tensor& a, const framework::Tensor& b) {
|
|
|
|
|
auto op = [](T& t, T& v) { t += v; };
|
|
|
|
|
AddByBitCodeT<T, Place>(op, SimpleCodeTable(num_classes), codes, a, b);
|
|
|
|
|
AddByBitCodeT<T>(op, SimpleCodeTable(num_classes), codes, a, b);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|