|
|
@ -67,9 +67,10 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
|
|
|
|
softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
|
|
|
|
softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
|
|
|
|
softmax_logits.set_lod(logits_lod);
|
|
|
|
softmax_logits.set_lod(logits_lod);
|
|
|
|
int rank = logits->dims().size();
|
|
|
|
int rank = logits->dims().size();
|
|
|
|
|
|
|
|
int axis_dim = logits->dims()[rank - 1];
|
|
|
|
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
|
|
|
|
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
|
|
|
|
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
|
|
|
|
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
|
|
|
|
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, -1, &in_2d,
|
|
|
|
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, axis_dim, &in_2d,
|
|
|
|
&out_2d);
|
|
|
|
&out_2d);
|
|
|
|
|
|
|
|
|
|
|
|
// ctc needs sequences data stored in transposed padding format
|
|
|
|
// ctc needs sequences data stored in transposed padding format
|
|
|
|