Handle GetNext output tensors as normal lifelong tensors

pull/8603/head
laiyongqiang 4 years ago
parent 95ce3eab08
commit b8821bb2f3

@@ -273,8 +273,6 @@ void Somas::GetNextOutputProcess(const session::KernelGraph *graph) {
     if (iter != nodes_map_.end()) {
       auto getnext_output_tensors = iter->second->output_tensors_;
       for (auto &tensor : getnext_output_tensors) {
-        if (tensor->contiguous_) continue;
-        tensor->offset_ = total_size;
         total_size += tensor->GetAlignedSize();
         tensor->lifelong_value_ = kLifeLongGraphAll;
         tensor->type_ = kGetNextOutput;
@@ -282,8 +280,6 @@ void Somas::GetNextOutputProcess(const session::KernelGraph *graph) {
     }
   }
-  this->get_next_size_ = total_size;
   MS_LOG(INFO) << "Special Tensor total size: GetNext Output " << total_size;
 }
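For context, the trimmed body of GetNextOutputProcess now only tags each GetNext output as a lifelong tensor and accumulates the size for the log line; it no longer pre-assigns offsets or records a get_next_size_ prefix. A minimal standalone sketch of that behaviour, using simplified stand-in types rather than the real Somas classes:

// Simplified stand-ins; not the actual MindSpore Somas classes.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

enum TensorType { kCommon, kGetNextOutput };
enum LifeLong { kLifeLongNone, kLifeLongGraphAll };

struct Tensor {
  size_t aligned_size = 0;
  TensorType type = kCommon;
  LifeLong lifelong = kLifeLongNone;
  size_t GetAlignedSize() const { return aligned_size; }
};

// Mirrors the trimmed loop: tag GetNext outputs as lifelong, sum their sizes
// for logging only, and assign no offsets.
size_t TagGetNextOutputs(const std::vector<std::shared_ptr<Tensor>> &outputs) {
  size_t total_size = 0;
  for (const auto &tensor : outputs) {
    total_size += tensor->GetAlignedSize();
    tensor->lifelong = kLifeLongGraphAll;
    tensor->type = kGetNextOutput;
  }
  return total_size;
}

int main() {
  std::vector<std::shared_ptr<Tensor>> outputs{std::make_shared<Tensor>(), std::make_shared<Tensor>()};
  outputs[0]->aligned_size = 512;
  outputs[1]->aligned_size = 1024;
  std::cout << "Special Tensor total size: GetNext Output " << TagGetNextOutputs(outputs) << std::endl;
  return 0;
}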
@@ -947,7 +943,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   // Solver info -- moved here because we set sizes to zero in ref node preprocessing (was before in GetSomasTensors())
   MS_LOG(INFO) << "Start Loop to create solver info";
   for (auto tensor : tensors_list_) {
-    if (tensor->GetSolverTensorDesc() != nullptr && tensor->type_ != kGetNextOutput) {
+    if (tensor->GetSolverTensorDesc() != nullptr) {
       SomasSolverTensorDescPtr pSolverTensor = tensor->GetSolverTensorDesc();
       solver_tensor_desc_list_.insert(
         std::pair<size_t, SomasSolverTensorDescPtr>(pSolverTensor->index_, pSolverTensor));
@@ -972,7 +968,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   // Update solver_tensor_desc offset to tensors list
   for (const auto &tensor : tensors_list_) {
-    tensor->SetOffset(get_next_size_);
+    tensor->SetOffset();
   }
   // Ref Node Postprocessing
@ -995,7 +991,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
MS_LOG(INFO) << "\nEnd Solving Postprocessing for Ref Node"; MS_LOG(INFO) << "\nEnd Solving Postprocessing for Ref Node";
// Set mem_offset_ value by solver result // Set mem_offset_ value by solver result
mem_offset_ = static_cast<size_t>(somas_solver_->GetMaxOffset()) + get_next_size_; mem_offset_ = static_cast<size_t>(somas_solver_->GetMaxOffset());
if (save_graphs_) { if (save_graphs_) {
std::string mem_pool_file_path = std::string mem_pool_file_path =

@@ -81,9 +81,6 @@ class Somas {
   // total Offset
   size_t mem_offset_;
-  // getnext op output size
-  size_t get_next_size_;
   // Memory base addr
   uint8_t *mem_base_addr_{nullptr};

@@ -108,9 +108,9 @@ class SomasTensor {
   bool IsGap() { return type_ == kGap; }
   // Computing functions
-  void SetOffset(size_t start_offset = 0) {
-    if (aligned_size_ != 0 && type_ != kGetNextOutput) {
-      offset_ = start_offset + solver_tensor_desc_->offset_;
+  void SetOffset() {
+    if (aligned_size_ != 0) {
+      offset_ = solver_tensor_desc_->offset_;
     }
   }
   SomasSolverTensorDescPtr GetSolverTensorDesc();
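Taken together with the .cc changes, offset assignment is now uniform: SetOffset() copies the solver result verbatim (no start_offset shift, no kGetNextOutput escape hatch), and mem_offset_ is just the solver's max offset because the solver already placed the lifelong GetNext outputs alongside the other lifelong tensors instead of reserving a fixed prefix for them. A minimal standalone sketch of that flow, again with simplified stand-in types and made-up sizes:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-ins for SomasSolverTensorDesc / SomasTensor.
struct SolverTensorDesc {
  size_t offset = 0;
};

struct Tensor {
  size_t aligned_size = 0;
  size_t offset = 0;
  std::shared_ptr<SolverTensorDesc> solver_desc = std::make_shared<SolverTensorDesc>();

  // Mirrors the patched SetOffset(): copy the solver offset, whatever the
  // tensor's type.
  void SetOffset() {
    if (aligned_size != 0) {
      offset = solver_desc->offset;
    }
  }
};

int main() {
  // Pretend the solver placed a lifelong GetNext output at 4096 and an
  // ordinary activation at 0; both read back their offsets the same way.
  std::vector<std::shared_ptr<Tensor>> tensors{std::make_shared<Tensor>(), std::make_shared<Tensor>()};
  tensors[0]->aligned_size = 512;
  tensors[0]->solver_desc->offset = 4096;
  tensors[1]->aligned_size = 1024;
  tensors[1]->solver_desc->offset = 0;

  size_t max_offset = 0;
  for (const auto &t : tensors) {
    t->SetOffset();
    max_offset = std::max(max_offset, t->offset + t->aligned_size);
  }
  // Equivalent of mem_offset_: the solver's peak alone, with no extra
  // get_next_size_ added on top.
  std::cout << "mem_offset = " << max_offset << std::endl;  // 4608
  return 0;
}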
