|
|
|
@ -232,7 +232,7 @@ public:
|
|
|
|
|
/// input grad and output grad have the same batch_size
|
|
|
|
|
CHECK_EQ(inouts[0].dims_[0], inputs[1].dims_[0]);
|
|
|
|
|
/// dim of output = dim of input * context_length
|
|
|
|
|
CHECK_EQ(inputs[1].dims_[1], inputs[0].dims_[1] * context_length_);
|
|
|
|
|
CHECK_EQ(inputs[1].dims_[1], inouts[0].dims_[1] * context_length_);
|
|
|
|
|
|
|
|
|
|
typename SequenceT<Device>::type seq_vec(
|
|
|
|
|
inputs[0].dims_[0], reinterpret_cast<int*>(inputs[0].getData()));
|
|
|
|
|