|
|
|
@ -42,7 +42,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
|
|
|
|
|
for (auto i = 0; i < rows; i++) {
|
|
|
|
|
int32_t k = index[i];
|
|
|
|
|
PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
|
|
|
|
|
PADDLE_ENFORCE_LT(k, ins.size(),
|
|
|
|
|
PADDLE_ENFORCE_LT((size_t)k, ins.size(),
|
|
|
|
|
"index exceeds the number of candidate tensors.");
|
|
|
|
|
memory::Copy(place, out->data<T>() + i * cols, place,
|
|
|
|
|
ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
|
|
|
|
|