|
|
|
@ -73,11 +73,12 @@ class WarpCTCFunctor {
|
|
|
|
|
"Bytes of workspace got by warp-ctc function, "
|
|
|
|
|
"get_workspace_size(), should be larger than 0.");
|
|
|
|
|
|
|
|
|
|
Tensor workspace;
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
|
|
|
|
|
float* workspace_data = workspace.mutable_data<float>(
|
|
|
|
|
Tensor workspace = ctx.AllocateTmpTensor<float, DeviceContext>(
|
|
|
|
|
framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
dev_ctx);
|
|
|
|
|
float* workspace_data = workspace.data<float>();
|
|
|
|
|
math::SetConstant<DeviceContext, float>()(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), &workspace,
|
|
|
|
|
static_cast<float>(0));
|
|
|
|
@ -186,8 +187,10 @@ class WarpCTCKernel : public framework::OpKernel<T> {
|
|
|
|
|
framework::make_ddim({static_cast<int64_t>(max_sequence_length),
|
|
|
|
|
static_cast<int64_t>(num_sequences),
|
|
|
|
|
static_cast<int64_t>(sequence_width)});
|
|
|
|
|
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
Tensor warpctc_logits_tmp =
|
|
|
|
|
ctx.AllocateTmpTensor<T, DeviceContext>(warpctc_logits_dims, dev_ctx);
|
|
|
|
|
warpctc_logits.ShareDataWith(warpctc_logits_tmp);
|
|
|
|
|
if (ctx.HasInput("LogitsLength")) {
|
|
|
|
|
TensorCopySync(*logits, ctx.GetPlace(), &warpctc_logits);
|
|
|
|
|
} else {
|
|
|
|
|