@@ -89,6 +89,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
                                   const framework::Tensor& weight,
                                   const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat->dims()[0];
   size_t tmat_width = tmat->dims()[1];
   size_t input_width = input.dims()[1];
@@ -99,13 +101,12 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_row = input_value + input_width * i;
     for (int j = 0; j < code_length; ++j) {
       size_t index = code->calc_index(j);
+      const T* weight_row = weight_value + weight_width * index;
       T sum = static_cast<T>(0.0);
-      for (size_t k = 0; k < input_width; ++k) {
-        sum += weight_value[weight_width * index + k] *
-               input_value[input_width * i + k];
-      }
+      sum = blas.DOT(input_width, weight_row, input_row);
       tmat_value[i * tmat_width + j] += sum;
     }
   }
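
Note on this hunk: the removed k-loop computed the inner product of one weight row with one input row, and the two new pointers merely name the row bases that loop indexed implicitly. The added blas.DOT call hands the same reduction to the BLAS backend, which can vectorize it. A minimal sketch of the contract being relied on (naive_dot is an illustrative stand-in, not Paddle's API; a BLAS dot may differ from it only by floating-point reassociation):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // What the removed k-loop computed, and what blas.DOT is assumed to
    // return: sum over k of x[k] * y[k] across n elements.
    template <typename T>
    T naive_dot(size_t n, const T* x, const T* y) {
      T sum = static_cast<T>(0.0);
      for (size_t k = 0; k < n; ++k) sum += x[k] * y[k];
      return sum;
    }

    int main() {
      std::vector<float> weight_row = {1.f, 2.f, 3.f};
      std::vector<float> input_row = {4.f, 5.f, 6.f};
      // 1*4 + 2*5 + 3*6 = 32; every value here is exactly representable.
      assert(naive_dot(3, weight_row.data(), input_row.data()) == 32.f);
      return 0;
    }
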
@@ -115,6 +116,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::Tensor* weight,
                                             const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
@@ -122,16 +125,25 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   auto tmat_value = tmat.data<T>();
   auto weight_value = weight->data<T>();
   auto input_value = input.data<T>();
 
+  std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
+
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_value_row = input_value + input_width * i;
+    const T* tmat_row = tmat_value + i * tmat_width;
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code->calc_index(j);
-      for (size_t k = 0; k < input_width; ++k) {
-        weight_value[weight_width * index + k] +=
-            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
-      }
+      ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
     }
   }
+  for (auto& op : ops) {
+    auto& op_in_row = op.second;
+    for (auto& pair : op_in_row) {
+      auto& scale = pair.first;
+      auto* input_row = pair.second;
+      T* weight_row = weight_value + op.first * weight_width;
+      blas.AXPY(input_width, scale, input_row, weight_row);
+    }
+  }
 }
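
Note on this hunk: the removed code scattered tmat(i, j) * input_row into weight row `index` one scalar at a time, revisiting rows in whatever order the codes produced them. The new code first buckets every (scale, input-row pointer) update by destination row in `ops`, then flushes each row's updates back to back with AXPY (y += alpha * x), so a destination row stays hot in cache across all of its updates. A self-contained sketch of that two-phase pattern (axpy below is an illustrative stand-in for blas.AXPY, not Paddle's wrapper):

    #include <cstddef>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    // Stand-in for blas.AXPY: y += alpha * x over n elements.
    template <typename T>
    void axpy(size_t n, T alpha, const T* x, T* y) {
      for (size_t k = 0; k < n; ++k) y[k] += alpha * x[k];
    }

    int main() {
      const size_t width = 4;
      std::vector<float> input_row(width, 1.f);
      std::vector<float> weight(2 * width, 0.f);  // two weight rows

      // Phase 1: bucket (scale, source-row) pairs by destination row,
      // mirroring ops[code->calc_index(j)].emplace_back(tmat_row[j], ...).
      std::unordered_map<int, std::vector<std::pair<float, const float*>>> ops;
      ops[0].emplace_back(0.5f, input_row.data());
      ops[1].emplace_back(2.0f, input_row.data());
      ops[1].emplace_back(1.0f, input_row.data());

      // Phase 2: apply each row's pending updates consecutively.
      for (auto& op : ops) {
        float* weight_row = weight.data() + op.first * width;
        for (auto& pair : op.second) {
          axpy(width, pair.first, pair.second, weight_row);
        }
      }
      // weight row 0 is now all 0.5f, row 1 is all 3.0f.
      return 0;
    }
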
@@ -140,6 +152,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::SelectedRows* weight,
                                             const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
@@ -147,17 +161,28 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   auto tmat_value = tmat.data<T>();
   auto weight_value = weight->mutable_value()->data<T>();
   auto input_value = input.data<T>();
 
+  std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
+  ops.reserve(weight->rows().size());
+
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_value_row = input_value + input_width * i;
+    const T* tmat_row = tmat_value + i * tmat_width;
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code->calc_index(j);
-      for (size_t k = 0; k < input_width; ++k) {
-        int64_t row_index = weight->GetIndexFromId(static_cast<int64_t>(index));
-        weight_value[row_index * weight_width + k] +=
-            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
-      }
+      ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
     }
   }
+
+  for (auto& row : weight->rows()) {
+    auto& op_in_row = ops[row];
+    for (auto& pair : op_in_row) {
+      auto& scale = pair.first;
+      auto* input_row = pair.second;
+      blas.AXPY(input_width, scale, input_row, weight_value);
+    }
+    weight_value += weight_width;
+  }
 }
+
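
Note on the SelectedRows variant: ops.reserve pre-sizes the hash map to the number of rows that can receive updates, and the flush loop replaces the per-element GetIndexFromId lookup of the removed code with a running pointer. This relies on the usual SelectedRows layout, where rows()[k] names the row whose data occupies the k-th weight_width-sized chunk of the value tensor, so walking rows() in order while advancing weight_value visits each stored row exactly once. A toy model of that layout (illustrative types, not Paddle's):

    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t width = 3;
      // rows[k] is the id of the k-th stored row; the k-th chunk of
      // `value` holds that row's data.
      std::vector<int64_t> rows = {7, 2, 9};
      std::vector<float> value(rows.size() * width, 0.f);

      float* row_data = value.data();
      for (int64_t id : rows) {
        // Updates bucketed under `id` would be applied to
        // row_data[0..width) here, e.g. via AXPY.
        (void)id;
        row_data += width;  // advance to the next stored row
      }
      return 0;
    }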