|
|
@ -87,8 +87,13 @@ def memory_usage(program, batch_size):
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
data_count = 1
|
|
|
|
data_count = 1
|
|
|
|
|
|
|
|
neg_dim_count = 0
|
|
|
|
for x in var.shape:
|
|
|
|
for x in var.shape:
|
|
|
|
if x < 0:
|
|
|
|
if x < 0:
|
|
|
|
|
|
|
|
if neg_dim_count >= 1:
|
|
|
|
|
|
|
|
raise ValueError("Var %s has more than one negtive dim."
|
|
|
|
|
|
|
|
% (var_name))
|
|
|
|
|
|
|
|
neg_dim_count += 1
|
|
|
|
data_count *= batch_size * (-x)
|
|
|
|
data_count *= batch_size * (-x)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
data_count *= x
|
|
|
|
data_count *= x
|
|
|
|