|
|
|
@ -20,6 +20,8 @@ batch size to fully utilize a GPU.
|
|
|
|
|
This API is still under active development and may change drastically.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import six
|
|
|
|
|
|
|
|
|
|
from .. import core
|
|
|
|
|
from ..framework import Program, Variable
|
|
|
|
|
|
|
|
|
@ -72,7 +74,7 @@ def memory_usage(program, batch_size):
|
|
|
|
|
|
|
|
|
|
# Get the var_name list of first block and calculate
|
|
|
|
|
total_memory = 0.0
|
|
|
|
|
for var in program.global_block().vars.itervalues():
|
|
|
|
|
for var in six.itervalues(program.global_block().vars):
|
|
|
|
|
data_count = 1
|
|
|
|
|
for x in var.shape:
|
|
|
|
|
if x == -1:
|
|
|
|
@ -81,10 +83,10 @@ def memory_usage(program, batch_size):
|
|
|
|
|
data_count *= x
|
|
|
|
|
var_memory = data_count * dtype_to_size[var.dtype]
|
|
|
|
|
if DEBUG:
|
|
|
|
|
print "%s memory usage: %d" % (var.name, var_memory)
|
|
|
|
|
print("%s memory usage: %d" % (var.name, var_memory))
|
|
|
|
|
total_memory += var_memory
|
|
|
|
|
if DEBUG:
|
|
|
|
|
print "total memory usage: %.2f" % (total_memory)
|
|
|
|
|
print("total memory usage: %.2f" % (total_memory))
|
|
|
|
|
|
|
|
|
|
# Convert appropriate unit
|
|
|
|
|
unit_str = "B"
|
|
|
|
|