local_add_cudnn_lstm
JiabinYang 6 years ago
parent e9be3366a9
commit 0fca16847c

@ -102,6 +102,8 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::LoDTensor& tmat,
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
size_t weight_width = weight->dims()[1];
VLOG(30) << "sparse w_grad dims is [" << weight->dims()[0] << " ,"
<< weight->dims()[1] << " ]";
auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
@ -127,6 +129,8 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::LoDTensor& tmat,
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
size_t weight_width = weight->value().dims()[1];
VLOG(30) << "sparse w_grad dims is: [" << weight->value().dims()[0] << " ,"
<< weight->value().dims()[1] << " ]";
auto tmat_value = tmat.data<T>();
auto weight_value = weight->mutable_value()->data<T>();
auto input_value = input.data<T>();

Loading…
Cancel
Save