@ -43,7 +43,8 @@ Debugger::Debugger()
device_id_ ( 0 ) ,
num_step_ ( 0 ) ,
debugger_enabled_ ( false ) ,
is_dataset_graph_ ( false ) { }
is_dataset_graph_ ( false ) ,
partial_memory_ ( false ) { }
void Debugger : : Init ( const uint32_t device_id ) {
// access lock for public method
@ -57,6 +58,7 @@ void Debugger::EnableDebugger() {
// reset some of the class members
num_step_ = 0 ;
debugger_enabled_ = false ;
partial_memory_ = false ;
grpc_client_ = nullptr ;
debug_services_ = nullptr ;
@ -72,7 +74,8 @@ void Debugger::EnableDebugger() {
MS_LOG ( WARNING ) < < " Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger. " ;
return ;
}
// configure host
// configure grpc host
const char * env_host_str = std : : getenv ( " MS_DEBUGGER_HOST " ) ;
std : : string host ;
if ( env_host_str ! = nullptr ) {
@ -82,7 +85,7 @@ void Debugger::EnableDebugger() {
MS_LOG ( WARNING ) < < " Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost " ;
host = " localhost " ;
}
// configure port
// configure grpc port
const char * env_port_str = std : : getenv ( " MS_DEBUGGER_PORT " ) ;
std : : string port ;
if ( env_port_str ! = nullptr ) {
@ -93,6 +96,27 @@ void Debugger::EnableDebugger() {
port = " 50051 " ;
}
// configure partial memory reuse
const char * env_partial_mem_str = std : : getenv ( " MS_DEBUGGER_PARTIAL_MEM " ) ;
if ( env_partial_mem_str ! = nullptr ) {
MS_LOG ( INFO ) < < " Getenv MS_DEBUGGER_PARTIAL_MEM: " < < env_partial_mem_str ;
if ( std : : strcmp ( env_partial_mem_str , " 1 " ) = = 0 ) {
partial_memory_ = true ;
}
}
// switch memory reuse on or off
auto context_ptr = MsContext : : GetInstance ( ) ;
MS_EXCEPTION_IF_NULL ( context_ptr ) ;
context_ptr - > set_enable_mem_reuse ( partial_memory_ ) ;
// print some message about memory reuse to user
if ( partial_memory_ ) {
MS_LOG ( WARNING ) < < " Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
" step. 2. Tensor values are only available for nodes that are watched by any watchpoint. " ;
} else {
MS_LOG ( WARNING ) < < " Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
" usage for large models. " ;
}
// initialize grpc client
grpc_client_ = std : : make_unique < GrpcClient > ( host , port ) ;
debug_services_ = std : : make_unique < DebugServices > ( ) ;
@ -106,6 +130,7 @@ void Debugger::Reset() {
num_step_ = 0 ;
debugger_enabled_ = false ;
is_dataset_graph_ = false ;
partial_memory_ = false ;
graph_ptr_ = nullptr ;
grpc_client_ = nullptr ;
debug_services_ = nullptr ;
@ -317,11 +342,10 @@ void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCon
[ ] ( WatchNode node ) - > std : : tuple < std : : string , bool > {
return make_tuple ( node . node_name ( ) , node . node_type ( ) = = " scope " ) ;
} ) ;
debug_services_ - > add_watchpoint ( id , condition . condition ( ) , check_node_list ) ;
debug_services_ - > AddWatchpoint ( id , condition . condition ( ) , check_node_list ) ;
}
void Debugger : : RemoveWatchpoint ( const int32_t id ) { debug_services_ - > remove_w atchpoint( id ) ; }
void Debugger : : RemoveWatchpoint ( const int32_t id ) { debug_services_ - > RemoveW atchpoint( id ) ; }
std : : list < TensorProto > Debugger : : LoadTensors ( const ProtoVector < TensorProto > & tensors ) const {
std : : vector < std : : string > name ;
@ -335,7 +359,7 @@ std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &ten
// ret_name will contain tensor names that are found in TensorLoader
// items in ret_name will be in the same order with tensors if found
debug_services_ - > read_nodes_t ensors( name , & ret_name , & data_ptr , & data_size , & dtype , & shape ) ;
debug_services_ - > ReadNodesT ensors( name , & ret_name , & data_ptr , & data_size , & dtype , & shape ) ;
std : : list < TensorProto > tensor_list ;
unsigned int result_index = 0 ;
@ -384,8 +408,7 @@ std::list<WatchpointHit> Debugger::CheckWatchpoints() const {
std : : vector < int > condition ;
std : : vector < unsigned int > watchpoint_id ;
debug_services_ - > check_watchpoints ( & name , & slot , & data_ptr , & data_size , & condition , & watchpoint_id ) ;
debug_services_ - > CheckWatchpoints ( & name , & slot , & data_ptr , & data_size , & condition , & watchpoint_id ) ;
std : : list < WatchpointHit > hits ;
for ( unsigned int i = 0 ; i < name . size ( ) ; i + + ) {
WatchpointHit hit ;
@ -494,4 +517,6 @@ std::string GetTensorFullName(const TensorProto &tensor) {
return node_name + " : " + tensor . slot ( ) + ( tensor . iter ( ) = = " " ? " " : " : " + tensor . iter ( ) ) ;
}
bool Debugger : : partial_memory ( ) { return partial_memory_ ; }
} // namespace mindspore