@ -42,6 +42,7 @@
# include "dataset/util/status.h"
# include "dataset/util/task_manager.h"
# include "dataset/util/wait_post.h"
# include "utils/system/crc32c.h"
namespace mindspore {
namespace dataset {
@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder()
builder_data_schema_ = std : : make_unique < DataSchema > ( ) ;
}
bool ValidateFirstRowCrc ( const std : : string & filename ) {
std : : ifstream reader ;
reader . open ( filename ) ;
if ( ! reader ) {
return false ;
}
// read data
int64_t record_length = 0 ;
( void ) reader . read ( reinterpret_cast < char * > ( & record_length ) , static_cast < std : : streamsize > ( sizeof ( int64_t ) ) ) ;
// read crc from file
uint32_t masked_crc = 0 ;
( void ) reader . read ( reinterpret_cast < char * > ( & masked_crc ) , static_cast < std : : streamsize > ( sizeof ( uint32_t ) ) ) ;
// generate crc from data
uint32_t generated_crc =
system : : Crc32c : : GetMaskCrc32cValue ( reinterpret_cast < char * > ( & record_length ) , sizeof ( int64_t ) ) ;
return masked_crc = = generated_crc ;
}
// Checks the builder settings before the op is created and collects every problem into one error message.
// The checks are: the number of workers must be positive; unless equal rows per shard is requested there must be
// at least as many dataset files as devices; the device id must be smaller than the number of devices and the
// number of devices must be at least 1; every dataset file must pass ValidateFirstRowCrc.
// @return Status::OK() when no check failed, otherwise a kUnexpectedError status carrying the collected messages
Status TFReaderOp::Builder::ValidateInputs() const {
  std::string err_msg;

  // Each check appends text only on failure, so an empty err_msg means that every check passed.
  if (builder_num_workers_ <= 0) {
    err_msg += " Number of parallel workers is smaller or equal to 0 \n ";
  }

  if (!builder_equal_rows_per_shard_ &&
      builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
    err_msg += " Not enough tfrecord files provided \n ";
  }

  if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
    err_msg += " Wrong sharding configs \n ";
  }

  // Collect the files whose first record header fails the crc check (or which cannot be opened).
  // The vector is sized for the worst case and trimmed to the number of files actually copied.
  std::vector<std::string> invalid_files(builder_dataset_files_list_.size());
  auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(),
                         [](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
  invalid_files.resize(std::distance(invalid_files.begin(), it));

  if (!invalid_files.empty()) {
    err_msg += " The following files either cannot be opened, or are not valid tfrecord files: \n ";

    // list the offending files, one per line
    std::string accumulated_filenames = std::accumulate(
      invalid_files.begin(), invalid_files.end(), std::string(" "),
      [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + " \n "; });
    err_msg += accumulated_filenames;
  }

  return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
RETURN_IF_NOT_OK ( LoadExample ( & tf_file , & new_tensor_table , rows_read ) ) ;
rows_read + + ;
}
// ignore crc footer
( void ) reader . ignore ( static_cast < std : : streamsize > ( sizeof ( int32_t ) ) ) ;
rows_total + + ;