clean a little bit code.

avx_docs
xutianbing 8 years ago
parent 86fa8c0528
commit df66957ec3

@ -232,7 +232,7 @@ public:
/// input grad and output grad have the same batch_size
CHECK_EQ(inouts[0].dims_[0], inputs[1].dims_[0]);
/// dim of output = dim of input * context_length
CHECK_EQ(inputs[1].dims_[1], inputs[0].dims_[1] * context_length_);
CHECK_EQ(inputs[1].dims_[1], inouts[0].dims_[1] * context_length_);
typename SequenceT<Device>::type seq_vec(
inputs[0].dims_[0], reinterpret_cast<int*>(inputs[0].getData()));

@ -256,7 +256,7 @@ __global__ void KeContextProjectionBackwardWeight(const real* out_grad,
for (int seqId = idy; seqId < num_sequences; seqId += THREADS_Y) {
int seq_start = sequence[seqId];
int seq_end = sequence[seqId+1];
output_r = const_cast<real*>(out_grad)
output_r = const_cast<real*>(out_grad)
+ seq_start * w_dim * context_length;
if (context_start < 0) {

Loading…
Cancel
Save