|
|
|
@ -43,7 +43,8 @@ Debugger::Debugger()
|
|
|
|
|
device_id_(0),
|
|
|
|
|
num_step_(0),
|
|
|
|
|
debugger_enabled_(false),
|
|
|
|
|
is_dataset_graph_(false) {}
|
|
|
|
|
is_dataset_graph_(false),
|
|
|
|
|
partial_memory_(false) {}
|
|
|
|
|
|
|
|
|
|
void Debugger::Init(const uint32_t device_id) {
|
|
|
|
|
// access lock for public method
|
|
|
|
@ -57,6 +58,7 @@ void Debugger::EnableDebugger() {
|
|
|
|
|
// reset some of the class members
|
|
|
|
|
num_step_ = 0;
|
|
|
|
|
debugger_enabled_ = false;
|
|
|
|
|
partial_memory_ = false;
|
|
|
|
|
grpc_client_ = nullptr;
|
|
|
|
|
debug_services_ = nullptr;
|
|
|
|
|
|
|
|
|
@ -72,7 +74,8 @@ void Debugger::EnableDebugger() {
|
|
|
|
|
MS_LOG(WARNING) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// configure host
|
|
|
|
|
|
|
|
|
|
// configure grpc host
|
|
|
|
|
const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
|
|
|
|
|
std::string host;
|
|
|
|
|
if (env_host_str != nullptr) {
|
|
|
|
@ -82,7 +85,7 @@ void Debugger::EnableDebugger() {
|
|
|
|
|
MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
|
|
|
|
|
host = "localhost";
|
|
|
|
|
}
|
|
|
|
|
// configure port
|
|
|
|
|
// configure grpc port
|
|
|
|
|
const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
|
|
|
|
|
std::string port;
|
|
|
|
|
if (env_port_str != nullptr) {
|
|
|
|
@ -93,6 +96,27 @@ void Debugger::EnableDebugger() {
|
|
|
|
|
port = "50051";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// configure partial memory reuse
|
|
|
|
|
const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
|
|
|
|
|
if (env_partial_mem_str != nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
|
|
|
|
|
if (std::strcmp(env_partial_mem_str, "1") == 0) {
|
|
|
|
|
partial_memory_ = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// switch memory reuse on or off
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
context_ptr->set_enable_mem_reuse(partial_memory_);
|
|
|
|
|
// print some message about memory reuse to user
|
|
|
|
|
if (partial_memory_) {
|
|
|
|
|
MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
|
|
|
|
|
"step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(WARNING) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
|
|
|
|
|
"usage for large models.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// initialize grpc client
|
|
|
|
|
grpc_client_ = std::make_unique<GrpcClient>(host, port);
|
|
|
|
|
debug_services_ = std::make_unique<DebugServices>();
|
|
|
|
@ -106,6 +130,7 @@ void Debugger::Reset() {
|
|
|
|
|
num_step_ = 0;
|
|
|
|
|
debugger_enabled_ = false;
|
|
|
|
|
is_dataset_graph_ = false;
|
|
|
|
|
partial_memory_ = false;
|
|
|
|
|
graph_ptr_ = nullptr;
|
|
|
|
|
grpc_client_ = nullptr;
|
|
|
|
|
debug_services_ = nullptr;
|
|
|
|
@ -317,11 +342,10 @@ void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCon
|
|
|
|
|
[](WatchNode node) -> std::tuple<std::string, bool> {
|
|
|
|
|
return make_tuple(node.node_name(), node.node_type() == "scope");
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
debug_services_->add_watchpoint(id, condition.condition(), check_node_list);
|
|
|
|
|
debug_services_->AddWatchpoint(id, condition.condition(), check_node_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->remove_watchpoint(id); }
|
|
|
|
|
void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
|
|
|
|
|
|
|
|
|
|
std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
|
|
|
|
|
std::vector<std::string> name;
|
|
|
|
@ -335,7 +359,7 @@ std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &ten
|
|
|
|
|
|
|
|
|
|
// ret_name will contain tensor names that are found in TensorLoader
|
|
|
|
|
// items in ret_name will be in the same order with tensors if found
|
|
|
|
|
debug_services_->read_nodes_tensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
|
|
|
|
|
debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
|
|
|
|
|
|
|
|
|
|
std::list<TensorProto> tensor_list;
|
|
|
|
|
unsigned int result_index = 0;
|
|
|
|
@ -384,8 +408,7 @@ std::list<WatchpointHit> Debugger::CheckWatchpoints() const {
|
|
|
|
|
std::vector<int> condition;
|
|
|
|
|
std::vector<unsigned int> watchpoint_id;
|
|
|
|
|
|
|
|
|
|
debug_services_->check_watchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
|
|
|
|
|
|
|
|
|
|
debug_services_->CheckWatchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
|
|
|
|
|
std::list<WatchpointHit> hits;
|
|
|
|
|
for (unsigned int i = 0; i < name.size(); i++) {
|
|
|
|
|
WatchpointHit hit;
|
|
|
|
@ -494,4 +517,6 @@ std::string GetTensorFullName(const TensorProto &tensor) {
|
|
|
|
|
return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Debugger::partial_memory() { return partial_memory_; }
|
|
|
|
|
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|