|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -54,6 +54,9 @@ void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template void GatherNd<double, int>(double *input, int *indices, double *output, const size_t &output_dim0,
|
|
|
|
|
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
|
|
|
|
|
int *batch_strides, cudaStream_t stream);
|
|
|
|
|
template void GatherNd<float, int>(float *input, int *indices, float *output, const size_t &output_dim0,
|
|
|
|
|
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
|
|
|
|
|
int *batch_strides, cudaStream_t stream);
|
|
|
|
@ -73,6 +76,9 @@ template void GatherNd<unsigned char, int>(unsigned char *input, int *indices, u
|
|
|
|
|
template void GatherNd<bool, int>(bool *input, int *indices, bool *output, const size_t &output_dim0,
|
|
|
|
|
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
|
|
|
|
|
int *batch_strides, cudaStream_t stream);
|
|
|
|
|
template void GatherNd<double, int64_t>(double *input, int64_t *indices, double *output, const size_t &output_dim0,
|
|
|
|
|
const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices,
|
|
|
|
|
int64_t *batch_strides, cudaStream_t stream);
|
|
|
|
|
template void GatherNd<float, int64_t>(float *input, int64_t *indices, float *output, const size_t &output_dim0,
|
|
|
|
|
const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices,
|
|
|
|
|
int64_t *batch_strides, cudaStream_t stream);
|
|
|
|
|