|
|
|
@ -78,45 +78,44 @@ public:
|
|
|
|
|
useGpu(arguments[0].deviceId));
|
|
|
|
|
const MatrixPtr errorMat2 = Matrix::create(output->getHeight(),
|
|
|
|
|
1,
|
|
|
|
|
/* trans= */ false, false);
|
|
|
|
|
// useGpu(arguments[0].deviceId));
|
|
|
|
|
/* trans= */ false,
|
|
|
|
|
false);
|
|
|
|
|
|
|
|
|
|
errorMat->zeroMem();
|
|
|
|
|
|
|
|
|
|
if (label != nullptr) {
|
|
|
|
|
errorMat->classificationError(*output, *label); // top-1 error
|
|
|
|
|
size_t height = output->getHeight();
|
|
|
|
|
size_t width = 5; // config_.num_results();
|
|
|
|
|
size_t width = 5;
|
|
|
|
|
|
|
|
|
|
IVector::resizeOrCreate(maxIds_, height * width,
|
|
|
|
|
useGpu(arguments[0].deviceId));
|
|
|
|
|
Matrix::resizeOrCreate(maxValues_, height, width, false,
|
|
|
|
|
useGpu(arguments[0].deviceId));
|
|
|
|
|
IVector::resizeOrCreate(
|
|
|
|
|
maxIds_, height * width, useGpu(arguments[0].deviceId));
|
|
|
|
|
Matrix::resizeOrCreate(
|
|
|
|
|
maxValues_, height, width, false, useGpu(arguments[0].deviceId));
|
|
|
|
|
output->rowMax(*maxIds_, *maxValues_); // top-5 values
|
|
|
|
|
|
|
|
|
|
int* ids;
|
|
|
|
|
int* lbl;
|
|
|
|
|
int* ids = nullptr;
|
|
|
|
|
int* lbl = nullptr;
|
|
|
|
|
if (useGpu(arguments[0].deviceId)) {
|
|
|
|
|
IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
|
|
|
|
|
hl_memcpy_device2host((void*)dest->getData(),
|
|
|
|
|
(void*)maxIds_->getData(),
|
|
|
|
|
sizeof(int) * maxIds_->getSize());
|
|
|
|
|
(void*)maxIds_->getData(),
|
|
|
|
|
sizeof(int) * maxIds_->getSize());
|
|
|
|
|
ids = dest->getData();
|
|
|
|
|
|
|
|
|
|
IVectorPtr dest2 = IVector::create(label->getSize(), false);
|
|
|
|
|
hl_memcpy_device2host((void*)dest2->getData(),
|
|
|
|
|
(void*)label->getData(),
|
|
|
|
|
sizeof(int) * label->getSize());
|
|
|
|
|
(void*)label->getData(),
|
|
|
|
|
sizeof(int) * label->getSize());
|
|
|
|
|
lbl = dest2->getData();
|
|
|
|
|
} else {
|
|
|
|
|
ids = maxIds_->getData();
|
|
|
|
|
lbl = label->getData();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// real* result = errorMat->getData();
|
|
|
|
|
real* result2 = errorMat2->getData();
|
|
|
|
|
for (size_t i = 0; i < height; ++i) {
|
|
|
|
|
// result[i] = (ids[i * width] != lbl[i]); // top-1 error
|
|
|
|
|
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error
|
|
|
|
|
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error
|
|
|
|
|
for (size_t j = 1; j < width; ++j) {
|
|
|
|
|
if (result2[i] == 0.0) {
|
|
|
|
|
break;
|
|
|
|
@ -141,10 +140,8 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void printStats(std::ostream& os) const {
|
|
|
|
|
os << "top_1_error="
|
|
|
|
|
<< (numSamples_ ? totalScore_ / numSamples_ : 0)
|
|
|
|
|
<< " top_5_error="
|
|
|
|
|
<< (numSamples_ ? totalScore2_ / numSamples_ : 0);
|
|
|
|
|
os << "top_1_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0)
|
|
|
|
|
<< " top_5_error=" << (numSamples_ ? totalScore2_ / numSamples_ : 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual real evalImp(std::vector<Argument>& arguments) {
|
|
|
|
@ -156,7 +153,6 @@ public:
|
|
|
|
|
mergeResultsOfAllClients(client);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
IVectorPtr maxIds_;
|
|
|
|
|
MatrixPtr maxValues_;
|
|
|
|
|