Keep parameters of the previous step in TensorLoader

Add name truncation to support MindInsight loading parameters

Refactor and address review comments
pull/2592/head
Shida He 5 years ago
parent c9929fd8a1
commit cb4c74c7c0

@@ -313,4 +313,10 @@ message TensorProto {
// If the tensor content transferring is finished.
optional bool finished = 6;
// The iteration of the tensor. Supported: "prev" or leave empty.
optional string iter = 7;
// If the tensor name should be truncated.
optional bool truncate = 8;
}
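The two new fields are what the front end sets when it asks to view a tensor's previous-step value or to have the reported name truncated. An illustrative fragment using the protobuf-generated setters (the surrounding request-building code and the namespace are omitted/assumed; only set_iter and set_truncate come from this hunk):

TensorProto tensor;           // message generated from the definition above
tensor.set_iter("prev");      // request the value captured in the previous step
tensor.set_truncate(true);    // request a truncated name, e.g. for MindInsight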

File diff suppressed because it is too large.

@@ -72,9 +72,9 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
// suspend the execution after a debug_op
void PostDebugOp();
DebugServices *get_debug_services();
DebugServices *debug_services() const;
bool debugger_enabled();
bool debugger_enabled() const;
private:
// private constructor for singleton
@@ -92,7 +92,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
void CheckDatasetGraph();
// serialize graph and get proto
GraphProto GetGraphProto();
GraphProto GetGraphProto() const;
// send graph and enter command wait loop
void SendGraphAndSuspend(const GraphProto &graph_proto);
@@ -102,16 +102,6 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
// break if RunCMD
void CommandLoop();
// process reply and command type
DebuggerCommand GetCommand(const EventReply &reply);
// parse other data out of EventReply
ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply);
WatchCondition GetWatchcondition(const EventReply &reply);
int32_t GetWatchpointID(const EventReply &reply);
bool GetWatchpointDelete(const EventReply &reply);
ProtoVector<TensorProto> GetTensors(const EventReply &reply);
// set what nodes and conditions to watch
void SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id);
@@ -119,14 +109,14 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
void RemoveWatchpoint(const int32_t id);
// load tensor for view command
std::list<TensorProto> LoadTensors(const ProtoVector<TensorProto> &tensors);
std::list<TensorProto> LoadTensors(const ProtoVector<TensorProto> &tensors) const;
// terminate training process
void Exit();
// analyze tensors and check watchpoint conditions
// return names of tensors and what condition they hit
std::list<WatchpointHit> CheckWatchpoints();
std::list<WatchpointHit> CheckWatchpoints() const;
// send watchpoints that hit and enter command wait loop
void SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points);
@@ -155,5 +145,18 @@ ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph);
// for getting proto DataType from Type of Tensor
DataType GetDebuggerNumberDataType(const TypePtr &type);
// process reply and command type
DebuggerCommand GetCommand(const EventReply &reply);
// parse other data out of EventReply
ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply);
WatchCondition GetWatchcondition(const EventReply &reply);
int32_t GetWatchpointID(const EventReply &reply);
bool GetWatchpointDelete(const EventReply &reply);
ProtoVector<TensorProto> GetTensors(const EventReply &reply);
// get the full name of a tensor, which is the name used in TensorLoader
std::string GetTensorFullName(const TensorProto &tensor);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_
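GetCommand, the EventReply parsers, and GetTensorFullName move out of the Debugger class into free functions in the mindspore namespace. How they tie into the new proto fields is not shown in this header; a rough, hypothetical sketch of a previous-step lookup, assuming GetTensorFullName returns the key used by TensorLoader and that previous-step copies live under a ":prev" suffix (see the tensor_loader.h hunk below). This is illustrative glue only, not the implementation in this commit:

std::string key = GetTensorFullName(tensor_proto);
if (tensor_proto.iter() == "prev") {
  key += ":prev";  // select the copy kept from the previous iteration
}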

@@ -21,6 +21,7 @@
#include <map>
#include <tuple>
#include <string>
#include <utility>
#include "debug/tensor_data.h"
namespace mindspore {
class TensorLoader {
@@ -29,7 +30,15 @@ class TensorLoader {
~TensorLoader() {}
bool LoadNewTensor(std::shared_ptr<TensorData> tensor) {
bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
if (keep_prev) {
// add prev step tensor into current step map with ":prev" suffix
auto handle = prev_tensor_list_map.extract(tensor->GetName());
if (!handle.empty()) {
handle.key() = tensor->GetName() + ":prev";
tensor_list_map.insert(std::move(handle));
}
}
tensor_list.push_back(tensor);
tensor_list_map.insert({tensor->GetName(), tensor});
return true;
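LoadNewTensor re-keys the previous step's entry with std::map::extract (C++17), which moves the map node so the stored shared_ptr is neither copied nor reallocated. A self-contained illustration of the extract-and-rename idiom, with plain strings standing in for TensorData and an illustrative tensor name:

#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
  std::map<std::string, std::string> prev{{"fc1.weight", "step N-1"}};
  std::map<std::string, std::string> cur;
  auto handle = prev.extract("fc1.weight");  // detach the node; empty if the key is absent
  if (!handle.empty()) {
    handle.key() = "fc1.weight:prev";        // rename the key in place
    cur.insert(std::move(handle));           // reattach to the current map
  }
  std::cout << cur.begin()->first << " -> " << cur.begin()->second << "\n";
}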
@@ -53,16 +62,20 @@ class TensorLoader {
}
bool EmptyTensor() {
tensor_list_map.clear();
prev_tensor_list_map.clear();
tensor_list_map.swap(prev_tensor_list_map);
tensor_list.clear();
return true;
}
void EmptyPrevTensor() { prev_tensor_list_map.clear(); }
void set_iter_num(uint32_t iter_num) { this->iter_num = iter_num; }
private:
std::vector<std::shared_ptr<TensorData>> tensor_list;
std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map;
std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map;
uint32_t iter_num;
};
} // namespace mindspore

@@ -370,10 +370,10 @@ bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &file
#ifdef ENABLE_DEBUGGER
bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order,
const std::string &host_fmt, const std::vector<int> &host_shape,
TypeId host_type, size_t slot, Debugger *debugger) const {
TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const {
bool ret = false;
DebugServices *debug_services = debugger->get_debug_services();
DebugServices *debug_services = debugger->debug_services();
TensorLoader *tensor_loader = debug_services->get_tensor_loader();
if (trans_flag) {
@@ -390,7 +390,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
tensor_data->SetExecutionOrder(execution_order);
tensor_data->SetTensor(out_tensor);
tensor_data->SetSlot(slot);
ret = tensor_loader->LoadNewTensor(tensor_data);
ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
} else {
mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape);
size_t host_size = out_tensor->data().nbytes();
@@ -401,7 +401,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
tensor_data->SetExecutionOrder(execution_order);
tensor_data->SetTensor(out_tensor);
tensor_data->SetSlot(slot);
ret = tensor_loader->LoadNewTensor(tensor_data);
ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]";
}

@@ -46,7 +46,8 @@ class AscendDeviceAddress : public DeviceAddress {
#endif
#ifdef ENABLE_DEBUGGER
bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt,
const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const;
const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger,
bool keep_prev) const;
#endif
private:

@@ -322,7 +322,8 @@ void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
[](size_t inner_item) { return SizeToInt(inner_item); });
}
auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger);
auto ret =
ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger, false);
if (!ret) {
MS_LOG(ERROR) << "LoadMemToHost: flag:" << trans_flag << ", tensor_name:" << tensor_name
<< ", host_format:" << format << ".!";
@@ -356,7 +357,8 @@ void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger)
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
[](size_t inner_item) { return SizeToInt(inner_item); });
}
auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger);
auto ret =
ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger, true);
if (!ret) {
MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", path:" << tensor_name
<< ", host_format:" << format << ".!";

@@ -799,12 +799,13 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
#ifdef ENABLE_DEBUGGER
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
DebugServices *debug_services = debugger_->get_debug_services();
DebugServices *debug_services = debugger_->debug_services();
TensorLoader *tensor_loader = debug_services->get_tensor_loader();
tensor_loader->EmptyTensor();
uint32_t iter_num = tensor_loader->GetIterNum();
tensor_loader->set_iter_num(++iter_num);
(void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get());
tensor_loader->EmptyPrevTensor();
#endif
MS_LOG(INFO) << "Finish!";
}
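Note the ordering in LoadTensor: EmptyTensor runs first so the loader's maps are swapped at the step boundary, LoadData then repopulates the current map through LoadMemToHost (LoadParameters passes keep_prev = true so each parameter's old value is re-inserted under "<name>:prev", while LoadOutput passes false), and EmptyPrevTensor finally releases any previous-step tensors that were not re-keyed. This is the same sequence as the toy walkthrough after the tensor_loader.h hunk above.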
