local_add_cudnn_lstm
JiabinYang 6 years ago
parent a507845a77
commit ba9ff508e8

@ -119,6 +119,33 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
}
}
// template <typename T>
// void MatrixBitCodeFunctor<T>::MulGradSparseWeight(const framework::Tensor&
// tmat,
// framework::SelectedRows* weight,
// const framework::Tensor& input) {
// size_t num_samples = tmat.dims()[0];
// size_t input_width = input.dims()[1];
// size_t tmat_width = tmat.dims()[1];
// size_t weight_width = weight->dims()[1];
// auto tmat_value = tmat.data<T>();
// auto weight_value = weight->data<T>();
// auto input_value = input.data<T>();
// for (size_t i = 0; i < num_samples; ++i) {
// auto code = code_table->get_code(i);
// int code_length = code->get_length();
// for (int j = 0; j < code_length; ++j) {
// // size_t index = code->calc_index(j);
// for (size_t k = 0; k < input_width; ++k) {
// weight_value[j * weight_width + k] +=
// tmat_value[i * tmat_width + j] * input_value[input_width * i +
// k];
// }
// }
// }
// }
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight,

Loading…
Cancel
Save