@@ -128,63 +128,93 @@ class WarpCTCKernel : public framework::OpKernel<T> {
     auto* warpctc_grad = ctx.Output<Tensor>("WarpCTCGrad");
     auto* loss = ctx.Output<Tensor>("Loss");
-    const size_t level = 0;
-    auto logits_lod = framework::ToAbsOffset(logits->lod());
-    auto logits_dims = logits->dims();
-    PADDLE_ENFORCE_EQ(logits_dims[0],
-                      static_cast<int64_t>(logits_lod[level].back()),
-                      "The first dimension of Input(Logits) should be equal to "
-                      "the sum of all sequences' lengths.");
-    auto label_lod = framework::ToAbsOffset(label->lod());
-    auto label_dims = label->dims();
-    PADDLE_ENFORCE_EQ(
-        label_dims[0], label->numel(),
-        "The width of each timestep in Input(Label) should be 1.");
-    const size_t num_sequences = logits_lod[level].size() - 1;
-    PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
-                      "The number of sequences of Input(Logits) should be "
-                      "equal to that of Input(Label).");
-    const size_t sequence_width = logits->numel() / logits_dims[0];
+    size_t num_sequences, sequence_width, max_sequence_length;
+    framework::Vector<size_t> logits_lod;
+    framework::Vector<size_t> label_lod;
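+    // Two input modes: explicit length tensors (already-padded inputs) or
+    // LoD metadata carried on the LoDTensor inputs themselves.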
+    if (ctx.HasInput("LogitsLength") && ctx.HasInput("LabelLength")) {
+      num_sequences = logits->dims()[1];
+      sequence_width = logits->dims()[2];
+      max_sequence_length = logits->dims()[0];
+      auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
+      auto* labels_length = ctx.Input<framework::Tensor>("LabelLength");
+      framework::Tensor logits_length_cpu;
+      framework::Tensor labels_length_cpu;
+      framework::TensorCopy(*logits_length, platform::CPUPlace(),
+                            &logits_length_cpu);
+      framework::TensorCopy(*labels_length, platform::CPUPlace(),
+                            &labels_length_cpu);
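+      // Rebuild absolute LoD offsets as prefix sums of the lengths.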
+      logits_lod.push_back(0);
+      label_lod.push_back(0);
+      for (size_t i = 0; i < num_sequences; i++) {
+        logits_lod.push_back(logits_lod[i] +
+                             logits_length_cpu.data<int64_t>()[i]);
+        label_lod.push_back(label_lod[i] +
+                            labels_length_cpu.data<int64_t>()[i]);
+      }
+    } else {
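+      // LoD path: sequence lengths are implicit in the level-0 LoD offsets.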
+      logits_lod = framework::ToAbsOffset(logits->lod())[0];
+      auto logits_dims = logits->dims();
+      PADDLE_ENFORCE_EQ(
+          logits_dims[0], static_cast<int64_t>(logits_lod.back()),
+          "The first dimension of Input(Logits) should be equal to "
+          "the sum of all sequences' lengths.");
+      label_lod = framework::ToAbsOffset(label->lod())[0];
+      auto label_dims = label->dims();
+      PADDLE_ENFORCE_EQ(
+          label_dims[0], label->numel(),
+          "The width of each timestep in Input(Label) should be 1.");
+      num_sequences = logits_lod.size() - 1;
+      PADDLE_ENFORCE_EQ(num_sequences, label_lod.size() - 1,
+                        "The number of sequences of Input(Logits) should be "
+                        "equal to that of Input(Label).");
+      sequence_width = logits->numel() / logits_dims[0];
+      max_sequence_length = math::MaximumSequenceLength(logits_lod);
+    }
     auto loss_dims =
         framework::make_ddim({static_cast<int64_t>(num_sequences), 1});
     // warpctc needs sequences data stored in transposed padding format
     LoDTensor warpctc_logits;
-    const size_t max_sequence_length =
-        math::MaximumSequenceLength(logits_lod[level]);
     auto warpctc_logits_dims =
         framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                               static_cast<int64_t>(num_sequences),
                               static_cast<int64_t>(sequence_width)});
     warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
-    LoDTensor cpu_pad_value;
-    T* pad_value_data =
-        cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
-    *pad_value_data = static_cast<T>(0);
-    LoDTensor pad_value;
-    if (platform::is_cpu_place(ctx.GetPlace())) {
-      pad_value = cpu_pad_value;
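+    // With LogitsLength the input is already padded and can be copied
+    // through unchanged; otherwise the LoD sequences must be padded first.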
+    if (ctx.HasInput("LogitsLength")) {
+      TensorCopySync(*logits, ctx.GetPlace(), &warpctc_logits);
     } else {
-      TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
+      LoDTensor cpu_pad_value;
+      T* pad_value_data =
+          cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
+      *pad_value_data = static_cast<T>(0);
+      LoDTensor pad_value;
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        pad_value = cpu_pad_value;
+      } else {
+        TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
+      }
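+      // Pack the LoD sequences into the transposed padded layout
+      // [max_sequence_length, num_sequences, sequence_width] that
+      // warpctc expects.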
+      math::PaddingLoDTensorFunctor<DeviceContext, T>()(
+          ctx.template device_context<DeviceContext>(), *logits,
+          &warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
+          math::kLengthBatchWidth);
     }
-    math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
-        pad_value, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth);
     const T* warpctc_logits_data = warpctc_logits.data<T>();
     std::vector<int> warpctc_label_lengths(num_sequences);
     std::vector<int> warpctc_logits_lengths(num_sequences);
     for (size_t i = 0; i < num_sequences; ++i) {
-      warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
-      warpctc_logits_lengths[i] =
-          logits_lod[level][i + 1] - logits_lod[level][i];
+      warpctc_label_lengths[i] = label_lod[i + 1] - label_lod[i];
+      warpctc_logits_lengths[i] = logits_lod[i + 1] - logits_lod[i];
     }
     // warpctc computes loss and gradient in one call, gradient data also stored
@@ -199,6 +229,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
     // warpctc accesses labels in CPU memory
     Tensor warpctc_label;
     TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
     const int* warpctc_label_data = warpctc_label.data<int>();
     // warpctc stores loss in CPU memory
     Tensor warpctc_loss;
@@ -227,14 +258,53 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
-    math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *warpctc_grad,
-        logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
-    const T* loss_grad_data = loss_grad->data<T>();
-    math::ScaleLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), loss_grad_data,
-        logits_grad);
+    if (ctx.HasInput("LogitsLength")) {
+      size_t max_seq_length = warpctc_grad->dims()[0];
+      size_t num_sequences = warpctc_grad->dims()[1];
+      size_t seq_width = warpctc_grad->dims()[2];
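+      // Bring the padded gradient and the upstream loss gradient to CPU
+      // so the per-sequence scaling below can run on the host.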
+      LoDTensor logits_grad_with_lod;
+      auto logits_grad_dims =
+          framework::make_ddim({static_cast<int64_t>(max_seq_length),
+                                static_cast<int64_t>(num_sequences),
+                                static_cast<int64_t>(seq_width)});
+      T* logits_grad_cpu_data = logits_grad_with_lod.mutable_data<T>(
+          logits_grad_dims, platform::CPUPlace());
+      TensorCopySync(*warpctc_grad, platform::CPUPlace(),
+                     &logits_grad_with_lod);
+      Tensor loss_grad_cpu;
+      loss_grad_cpu.mutable_data<T>(loss_grad->dims(), platform::CPUPlace());
+      TensorCopySync(*loss_grad, platform::CPUPlace(), &loss_grad_cpu);
+      LoDTensor scaled_logits;
+      T* scaled_logits_data =
+          scaled_logits.mutable_data<T>(logits_grad_dims, platform::CPUPlace());
+      const T* loss_grad_data = loss_grad_cpu.data<T>();
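+      // Chain rule: scale every timestep and channel of sequence j by the
+      // loss gradient of sequence j.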
+      for (size_t i = 0; i < max_seq_length; ++i) {
+        for (size_t j = 0; j < num_sequences; ++j) {
+          for (size_t k = 0; k < seq_width; ++k) {
+            size_t idx = i * (num_sequences * seq_width) + j * seq_width + k;
+            scaled_logits_data[idx] =
+                logits_grad_cpu_data[idx] * loss_grad_data[j];
+          }
+        }
+      }
+      TensorCopySync(scaled_logits, ctx.GetPlace(), logits_grad);
+    } else {
+      math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
+          ctx.template device_context<DeviceContext>(), *warpctc_grad,
+          logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
+      const T* loss_grad_data = loss_grad->data<T>();
+      math::ScaleLoDTensorFunctor<DeviceContext, T>()(
+          ctx.template device_context<DeviceContext>(), loss_grad_data,
+          logits_grad);
+    }
   }
 };