|
|
|
@ -106,7 +106,11 @@ class TargetAssignKernel : public framework::OpKernel<T> {
|
|
|
|
|
int64_t k = x->dims()[2];
|
|
|
|
|
|
|
|
|
|
auto x_lod = x->lod().back();
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA)
|
|
|
|
|
size_t* x_lod_data = x_lod.MutableData(ctx.GetPlace());
|
|
|
|
|
#else
|
|
|
|
|
size_t* x_lod_data = x_lod.data();
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
TargetAssignFunctor<T, WT> functor(x_data, match_idx_data, x_lod_data,
|
|
|
|
|
mismatch_value, n, m, p, k, out_data,
|
|
|
|
@ -121,7 +125,11 @@ class TargetAssignKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(neg_indices->lod().size(), 1UL);
|
|
|
|
|
const int* neg_idx_data = neg_indices->data<int>();
|
|
|
|
|
auto neg_lod = neg_indices->lod().back();
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA)
|
|
|
|
|
size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace());
|
|
|
|
|
#else
|
|
|
|
|
size_t* neg_lod_data = neg_lod.data();
|
|
|
|
|
#endif
|
|
|
|
|
NegTargetAssignFunctor<DeviceContext, T, WT> neg_trg_functor;
|
|
|
|
|
neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, n, m, k,
|
|
|
|
|
mismatch_value, out_data, out_wt_data);
|
|
|
|
|