|
|
@ -100,8 +100,10 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const {
|
|
|
|
int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const {
|
|
|
|
int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const {
|
|
|
|
auto *var_desc = GetVarDesc(var);
|
|
|
|
auto *var_desc = GetVarDesc(var);
|
|
|
|
auto shapes = var_desc->GetShape();
|
|
|
|
auto shapes = var_desc->GetShape();
|
|
|
|
|
|
|
|
auto sizeof_dtype = static_cast<int64_t>(SizeOfType(var_desc->GetDataType()));
|
|
|
|
return std::accumulate(shapes.begin(), shapes.end(), static_cast<int64_t>(1),
|
|
|
|
return std::accumulate(shapes.begin(), shapes.end(), static_cast<int64_t>(1),
|
|
|
|
std::multiplies<int64_t>());
|
|
|
|
std::multiplies<int64_t>()) *
|
|
|
|
|
|
|
|
sizeof_dtype;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
|
|
|
|
void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
|
|
|
|