|
|
|
@ -12,26 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <thrust/execution_policy.h>
|
|
|
|
|
#include <thrust/reduce.h>
|
|
|
|
|
#include "paddle/operators/accuracy_op.h"
|
|
|
|
|
#include "paddle/platform/cuda_helper.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
using platform::PADDLE_CUDA_NUM_THREADS;
|
|
|
|
|
|
|
|
|
|
__global__ void AccuracySingleKernel(const int N, const int D, const int top_k,
|
|
|
|
|
const int* Xdata, const int* labelData,
|
|
|
|
|
float* accuracy) {
|
|
|
|
|
int correct = 0;
|
|
|
|
|
for (int row = 0; row < N; row++) {
|
|
|
|
|
const int label = labelData[row];
|
|
|
|
|
for (int col = 0; col < D; col++) {
|
|
|
|
|
const int pred = Xdata[row * D + col];
|
|
|
|
|
if (pred == label) {
|
|
|
|
|
++correct;
|
|
|
|
|
template <int BlockSize>
|
|
|
|
|
__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
|
|
|
|
|
const int* labeldata, float* accuracy) {
|
|
|
|
|
int count = 0;
|
|
|
|
|
__shared__ int total[BlockSize];
|
|
|
|
|
|
|
|
|
|
// support only 1 block
|
|
|
|
|
for (int i = threadIdx.x; i < (N); i += BlockSize) {
|
|
|
|
|
for (int j = 0; j < D; ++j) {
|
|
|
|
|
if (Xdata[i * D + j] == labeldata[i]) {
|
|
|
|
|
++count;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
*accuracy = static_cast<float>(correct) / static_cast<float>(N);
|
|
|
|
|
total[threadIdx.x] = count;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
// reduce the count with init value 0, and output accuracy.
|
|
|
|
|
int result = thrust::reduce(thrust::device, total, total + BlockSize, 0);
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
*accuracy = static_cast<float>(result) / static_cast<float>(N);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -57,8 +69,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
|
|
|
|
|
label_data, accuracy_data);
|
|
|
|
|
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
|
|
|
|
|
num_samples, infer_width, inference_data, label_data, accuracy_data);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|