|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
#include <thrust/reduce.h>
|
|
|
|
|
#include "paddle/operators/accuracy_op.h"
|
|
|
|
|
#include "paddle/platform/cuda_helper.h"
|
|
|
|
|
#include "paddle/platform/gpu_info.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
int num_samples = static_cast<int>(inference->dims()[0]);
|
|
|
|
|
size_t infer_width = inference->dims()[1];
|
|
|
|
|
PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
|
|
|
|
|
// cudaMemset((void**)&correct_data, 0, sizeof(float));
|
|
|
|
|
auto stream = ctx.cuda_device_context().stream();
|
|
|
|
|
platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);
|
|
|
|
|
|
|
|
|
|
if (num_samples == 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
|
|
|
|
|
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
|
|
|
|
|
cudaMemcpyHostToDevice, stream);
|
|
|
|
|
|
|
|
|
|
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
|
|
|
|
|
1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
AccuracyCudaKernel<
|
|
|
|
|
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
|
|
|
|
|
num_samples, infer_width, indices_data, label_data, correct_data,
|
|
|
|
|
accuracy_data);
|
|
|
|
|
|
|
|
|
|
int d_num_samples, d_num_correct;
|
|
|
|
|
float d_accuracy;
|
|
|
|
|
cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
|
|
|
|
|
cudaMemcpyDeviceToHost);
|
|
|
|
|
cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
|
|
|
|
|
cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
|
|
|
|
|
cudaMemcpyDeviceToHost);
|
|
|
|
|
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
|
|
|
|
|
cudaMemcpyDeviceToHost, stream);
|
|
|
|
|
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
|
|
|
|
|
cudaMemcpyDeviceToHost, stream);
|
|
|
|
|
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
|
|
|
|
|
cudaMemcpyDeviceToHost, stream);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|