|
|
|
@ -70,23 +70,37 @@ def memory_usage(program, batch_size):
|
|
|
|
|
if not isinstance(program, Program):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"Calculating Memory Usage requires Program as its Parameter."
|
|
|
|
|
"But you passed in %s" % (type(prgram)))
|
|
|
|
|
"But you passed in %s" % (type(program)))
|
|
|
|
|
if batch_size <= 0:
|
|
|
|
|
raise ValueError("The batch size need to be positive.")
|
|
|
|
|
|
|
|
|
|
# Get the var_name list of first block and calculate
|
|
|
|
|
total_memory = 0.0
|
|
|
|
|
for var in six.itervalues(program.global_block().vars):
|
|
|
|
|
data_count = 1
|
|
|
|
|
for x in var.shape:
|
|
|
|
|
if x == -1:
|
|
|
|
|
data_count *= batch_size
|
|
|
|
|
else:
|
|
|
|
|
data_count *= x
|
|
|
|
|
var_memory = data_count * dtype_to_size[var.dtype]
|
|
|
|
|
if DEBUG:
|
|
|
|
|
print("%s memory usage: %d" % (var.name, var_memory))
|
|
|
|
|
total_memory += var_memory
|
|
|
|
|
processed_var_names = set()
|
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
|
for var_name in op.output_arg_names:
|
|
|
|
|
if var_name in processed_var_names:
|
|
|
|
|
continue
|
|
|
|
|
processed_var_names.add(var_name)
|
|
|
|
|
var = program.global_block().vars[var_name]
|
|
|
|
|
if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
data_count = 1
|
|
|
|
|
neg_dim_count = 0
|
|
|
|
|
for x in var.shape:
|
|
|
|
|
if x < 0:
|
|
|
|
|
if neg_dim_count >= 1:
|
|
|
|
|
raise ValueError("Var %s has more than one negtive dim."
|
|
|
|
|
% (var_name))
|
|
|
|
|
neg_dim_count += 1
|
|
|
|
|
data_count *= batch_size * (-x)
|
|
|
|
|
else:
|
|
|
|
|
data_count *= x
|
|
|
|
|
var_memory = data_count * dtype_to_size[var.dtype]
|
|
|
|
|
if DEBUG:
|
|
|
|
|
print("%s memory usage: %d" % (var.name, var_memory))
|
|
|
|
|
total_memory += var_memory
|
|
|
|
|
if DEBUG:
|
|
|
|
|
print("total memory usage: %.2f" % (total_memory))
|
|
|
|
|
|
|
|
|
|