|
|
|
@ -186,8 +186,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// warpctc accesses labels in CPU memory
|
|
|
|
|
Tensor warpctc_label;
|
|
|
|
|
TensorCopy(*label, platform::CPUPlace(), ctx.device_context(),
|
|
|
|
|
&warpctc_label);
|
|
|
|
|
TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
|
|
|
|
|
const int* warpctc_label_data = warpctc_label.data<int>();
|
|
|
|
|
// warpctc stores loss in CPU memory
|
|
|
|
|
Tensor warpctc_loss;
|
|
|
|
|