|
|
|
@ -14,12 +14,14 @@
|
|
|
|
|
"""
|
|
|
|
|
This module privides a memory usage calculate function for user.
|
|
|
|
|
The purpose of this API is to allow users to estimate memory usage of
|
|
|
|
|
a program under a special batch size, then user can set appropriate
|
|
|
|
|
batch size to fully utilize a GPU.
|
|
|
|
|
a program under a special batch size, then user can set appropriate
|
|
|
|
|
batch size to fully utilize a GPU.
|
|
|
|
|
|
|
|
|
|
This API is still under active development and may change drastically.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import six
|
|
|
|
|
|
|
|
|
|
from .. import core
|
|
|
|
|
from ..framework import Program, Variable
|
|
|
|
|
|
|
|
|
@ -45,15 +47,15 @@ def memory_usage(program, batch_size):
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
program(Program): The current Program.
|
|
|
|
|
batch_size(int): The current input data batch_size.
|
|
|
|
|
|
|
|
|
|
batch_size(int): The current input data batch_size.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
min_total_memory(float): the estimate memory usage lower bound.
|
|
|
|
|
max_total_memory(float): the estimate memory usage upper bound.
|
|
|
|
|
unit_str(string): the unit of estimate usage result.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
>>> import paddle.fluid as fluid
|
|
|
|
|
>>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
|
|
|
|
|
fluid.default_main_program(), batch_size=10)
|
|
|
|
@ -72,7 +74,7 @@ def memory_usage(program, batch_size):
|
|
|
|
|
|
|
|
|
|
# Get the var_name list of first block and calculate
|
|
|
|
|
total_memory = 0.0
|
|
|
|
|
for var in program.global_block().vars.itervalues():
|
|
|
|
|
for var in six.itervalues(program.global_block().vars):
|
|
|
|
|
data_count = 1
|
|
|
|
|
for x in var.shape:
|
|
|
|
|
if x == -1:
|
|
|
|
@ -81,10 +83,10 @@ def memory_usage(program, batch_size):
|
|
|
|
|
data_count *= x
|
|
|
|
|
var_memory = data_count * dtype_to_size[var.dtype]
|
|
|
|
|
if DEBUG:
|
|
|
|
|
print "%s memory usage: %d" % (var.name, var_memory)
|
|
|
|
|
print("%s memory usage: %d" % (var.name, var_memory))
|
|
|
|
|
total_memory += var_memory
|
|
|
|
|
if DEBUG:
|
|
|
|
|
print "total memory usage: %.2f" % (total_memory)
|
|
|
|
|
print("total memory usage: %.2f" % (total_memory))
|
|
|
|
|
|
|
|
|
|
# Convert appropriate unit
|
|
|
|
|
unit_str = "B"
|
|
|
|
|