|
|
|
@ -326,9 +326,12 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
|
|
|
|
if idx >= len(valid_dataloader):
|
|
|
|
|
break
|
|
|
|
|
images = batch[0]
|
|
|
|
|
others = batch[-4:]
|
|
|
|
|
start = time.time()
|
|
|
|
|
preds = model(images, others)
|
|
|
|
|
if "SRN" in str(model.head):
|
|
|
|
|
others = batch[-4:]
|
|
|
|
|
preds = model(images, others)
|
|
|
|
|
else:
|
|
|
|
|
preds = model(images)
|
|
|
|
|
|
|
|
|
|
batch = [item.numpy() for item in batch]
|
|
|
|
|
# Obtain usable results from post-processing methods
|
|
|
|
|