add input sparse data check for sparse layer at runtime (#247)

* add input sparse data check for sparse layer at runtime,
to avoid invalid data access at pserver end while doing prefetch

* remote sparse design support binary sparse and float saprse both
avx_docs
backyes 9 years ago committed by Yu Yang
parent d1d52bb7d4
commit 46bd5f53e3

@ -227,12 +227,18 @@ void CacheRowCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB,
void SparsePrefetchRowCpuMatrix::addRows(const unsigned int* ids, size_t len) { void SparsePrefetchRowCpuMatrix::addRows(const unsigned int* ids, size_t len) {
std::vector<unsigned int>& localIndices = indexDictHandle_->localIndices; std::vector<unsigned int>& localIndices = indexDictHandle_->localIndices;
for (size_t i = 0; i < len; i ++) {
CHECK_LT(*(ids + i), this->getHeight())
<< "id:" << *(ids + i) << "Height:" << this->getHeight()
<< "sparse id value exceeds the max input dimension, "
<< "it could be caused invalid input data samples";
}
localIndices.insert(localIndices.end(), ids, ids + len); localIndices.insert(localIndices.end(), ids, ids + len);
} }
void SparsePrefetchRowCpuMatrix::addRows(MatrixPtr input) { void SparsePrefetchRowCpuMatrix::addRows(MatrixPtr input) {
CpuSparseMatrix* mat = dynamic_cast<CpuSparseMatrix*>(input.get()); CpuSparseMatrix* mat = dynamic_cast<CpuSparseMatrix*>(input.get());
CHECK(mat) << "only support non value sparse matrix"; CHECK(mat) << "only support sparse matrix";
addRows(reinterpret_cast<const unsigned int*>(mat->getCols()), addRows(reinterpret_cast<const unsigned int*>(mat->getCols()),
mat->getElementCnt()); mat->getElementCnt());
} }
@ -243,7 +249,13 @@ void SparsePrefetchRowCpuMatrix::addRows(IVectorPtr ids) {
int* index = ids->getData(); int* index = ids->getData();
for (size_t i = 0; i < numSamples; ++i) { for (size_t i = 0; i < numSamples; ++i) {
if (index[i] == -1) continue; if (index[i] == -1) continue;
localIndices.push_back((unsigned int)index[i]);
unsigned int id = (unsigned int)index[i];
CHECK_LT(id, this->getHeight())
<< "id:" << id << "Height:" << this->getHeight()
<< "sparse id value exceeds the max input dimension, "
<< "it could be caused invalid input data samples";
localIndices.push_back(id);
} }
} }

Loading…
Cancel
Save