|
|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/top_k_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/assert.h"
|
|
|
|
|
#include "paddle/fluid/platform/cuda_primitives.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
|
|
|
|
|
sh_topk[tid] = topk[*beam];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Temporary solution: set up `mask` for the __shfl_sync call below.
|
|
|
|
|
unsigned mask = 0u;
|
|
|
|
|
CREATE_SHFL_MASK(mask, true);
|
|
|
|
|
|
|
|
|
|
if (maxid[0] / 32 == warp) {
|
|
|
|
|
if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break;
|
|
|
|
|
if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|