@ -227,12 +227,18 @@ void CacheRowCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB,
void SparsePrefetchRowCpuMatrix : : addRows ( const unsigned int * ids , size_t len ) {
void SparsePrefetchRowCpuMatrix : : addRows ( const unsigned int * ids , size_t len ) {
std : : vector < unsigned int > & localIndices = indexDictHandle_ - > localIndices ;
std : : vector < unsigned int > & localIndices = indexDictHandle_ - > localIndices ;
for ( size_t i = 0 ; i < len ; i + + ) {
CHECK_LT ( * ( ids + i ) , this - > getHeight ( ) )
< < " id: " < < * ( ids + i ) < < " Height: " < < this - > getHeight ( )
< < " sparse id value exceeds the max input dimension, "
< < " it could be caused invalid input data samples " ;
}
localIndices . insert ( localIndices . end ( ) , ids , ids + len ) ;
localIndices . insert ( localIndices . end ( ) , ids , ids + len ) ;
}
}
void SparsePrefetchRowCpuMatrix : : addRows ( MatrixPtr input ) {
void SparsePrefetchRowCpuMatrix : : addRows ( MatrixPtr input ) {
CpuSparseMatrix * mat = dynamic_cast < CpuSparseMatrix * > ( input . get ( ) ) ;
CpuSparseMatrix * mat = dynamic_cast < CpuSparseMatrix * > ( input . get ( ) ) ;
CHECK ( mat ) < < " only support non value sparse matrix " ;
CHECK ( mat ) < < " only support sparse matrix" ;
addRows ( reinterpret_cast < const unsigned int * > ( mat - > getCols ( ) ) ,
addRows ( reinterpret_cast < const unsigned int * > ( mat - > getCols ( ) ) ,
mat - > getElementCnt ( ) ) ;
mat - > getElementCnt ( ) ) ;
}
}
@ -243,7 +249,13 @@ void SparsePrefetchRowCpuMatrix::addRows(IVectorPtr ids) {
int * index = ids - > getData ( ) ;
int * index = ids - > getData ( ) ;
for ( size_t i = 0 ; i < numSamples ; + + i ) {
for ( size_t i = 0 ; i < numSamples ; + + i ) {
if ( index [ i ] = = - 1 ) continue ;
if ( index [ i ] = = - 1 ) continue ;
localIndices . push_back ( ( unsigned int ) index [ i ] ) ;
unsigned int id = ( unsigned int ) index [ i ] ;
CHECK_LT ( id , this - > getHeight ( ) )
< < " id: " < < id < < " Height: " < < this - > getHeight ( )
< < " sparse id value exceeds the max input dimension, "
< < " it could be caused invalid input data samples " ;
localIndices . push_back ( id ) ;
}
}
}
}