infer2
fengjiayi 7 years ago
parent fb08e163cf
commit bbcf1ad263

@ -87,8 +87,13 @@ def memory_usage(program, batch_size):
continue continue
data_count = 1 data_count = 1
neg_dim_count = 0
for x in var.shape: for x in var.shape:
if x < 0: if x < 0:
if neg_dim_count >= 1:
raise ValueError("Var %s has more than one negtive dim."
% (var_name))
neg_dim_count += 1
data_count *= batch_size * (-x) data_count *= batch_size * (-x)
else: else:
data_count *= x data_count *= x

Loading…
Cancel
Save