You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

158 lines
5.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Run the detector on a video file / webcam stream.
# Instead of iterating over batches, iterate over the frames of the video.
from __future__ import division
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import cv2
from util import *
import argparse
import os
import os.path as osp
from darknet import Darknet
import pickle as pkl
import pandas as pd
import random
# Command-line arguments
def arg_parse():
    """Parse the command-line options of the YOLO v3 detection script.

    Returns:
        argparse.Namespace: parsed options with the attributes ``images``,
        ``det``, ``bs``, ``confidence``, ``nms_thresh``, ``cfgfile``,
        ``weightsfile`` and ``reso``.

    Note:
        ``bs``, ``confidence`` and ``nms_thresh`` are converted by argparse
        itself, so a malformed value is rejected with a usage message
        (``SystemExit``) at parse time instead of being handed on as a string.
    """
    parser = argparse.ArgumentParser(description='YOLO v3 Detection Module')

    # --images: input image, or directory of input images.
    parser.add_argument(
        "--images", dest='images',
        help="Image / Directory containing images to perform detection upon",
        default="imgs", type=str)

    # --det: directory in which detection results are stored.
    parser.add_argument(
        "--det", dest='det',
        help="Image / Directory to store detections to",
        default="det", type=str)

    # --bs: batch size.
    parser.add_argument(
        "--bs", dest="bs", help="Batch size", default=1, type=int)

    # --confidence: objectness confidence used to filter predictions.
    parser.add_argument(
        "--confidence", dest="confidence",
        help="Object Confidence to filter predictions",
        default=0.5, type=float)

    # --nms_thresh: non-maximum-suppression threshold.
    parser.add_argument(
        "--nms_thresh", dest="nms_thresh", help="NMS Threshhold",
        default=0.4, type=float)

    # --cfg: alternative network configuration file.
    parser.add_argument(
        "--cfg", dest='cfgfile', help="Config file",
        default="cfg/yolov3.cfg", type=str)

    # --weights: file holding the network weights.
    parser.add_argument(
        "--weights", dest='weightsfile', help="weightsfile",
        default="yolov3.weights", type=str)

    # --reso: network input resolution (speed / accuracy trade-off).
    # Kept as a string so existing consumers of ``args.reso`` see the same type.
    parser.add_argument(
        "--reso", dest='reso',
        help="Input resolution of the network. Increase to increase accuracy. Decrease to increase speed",
        default="416", type=str)

    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: load the network, then run detection frame by frame
    # on a video file and display every frame with its bounding boxes drawn.
    args = arg_parse()
    images = args.images          # parsed, but not used by the video loop below
    batch_size = int(args.bs)     # parsed, but not used by the video loop below
    confidence = float(args.confidence)
    nms_thesh = float(args.nms_thresh)
    CUDA = torch.cuda.is_available()
    num_classes = 80  # class count handed to write_results
    classes = load_classes("data/coco.names")

    # Load the colour palette once up front; it does not change between frames.
    # NOTE(review): unpickling executes whatever the file contains - only use
    # a "pallete" file from a trusted source.
    with open("pallete", "rb") as palette_file:
        colors = pkl.load(palette_file)

    # Build the network and load its weights.
    print("正在加载网络QAQ")
    model = Darknet(args.cfgfile)
    model.load_weights(args.weightsfile)
    print("网络加载成功QvQ")
    model.net_info["height"] = args.reso
    inp_dim = int(model.net_info["height"])
    assert inp_dim % 32 == 0
    assert inp_dim > 32

    # Move the model to the GPU when one is available.
    if CUDA:
        model.cuda()

    # Switch the model to evaluation mode.
    model.eval()

    def write(x, results):
        """Draw one detection on the frame and return the frame.

        A rectangle is drawn between the two box corners in a colour picked
        at random from ``colors``; a filled rectangle at the top-left corner
        carries the class name taken from ``classes``.

        Args:
            x: one row of the detection tensor. Elements 1-2 and 3-4 are used
                as the two box corners and the last element as the class index.
            results: the frame to draw on (modified in place).

        Returns:
            The same frame object that was passed in.
        """
        # Plain Python ints: OpenCV point arguments must be numbers, not
        # 0-dim tensors. int() truncates exactly like Tensor.int().
        c1 = tuple(int(v) for v in x[1:3])
        c2 = tuple(int(v) for v in x[3:5])
        img = results  # only a single frame is handled
        cls = int(x[-1])
        color = random.choice(colors)
        label = "{0}".format(classes[cls])
        cv2.rectangle(img, c1, c2, color, 1)
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
        c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4
        cv2.rectangle(img, c1, c2, color, -1)  # -1 draws a filled rectangle
        cv2.putText(img, label, (c1[0], c1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,225,225], 1)
        return img

    def _show_and_poll_quit(image):
        """Display the frame and return True when the user pressed 'q'."""
        cv2.imshow("", image)
        key = cv2.waitKey(1)
        return key & 0xFF == ord('q')

    # Detection phase.
    videofile = "workingcell.mp4"      # path of the video file to process
    cap = cv2.VideoCapture(videofile)  # open the video / camera stream
    # If the capture could not be opened the loop body below never runs.
    frames = 0  # number of frames processed so far
    start = time.time()

    try:
        # Iterate over the video one frame at a time.
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            img = prep_image(frame, inp_dim)
            im_dim = frame.shape[1], frame.shape[0]
            # Repeat the pair so it lines up with the four box coordinates.
            im_dim = torch.FloatTensor(im_dim).repeat(1, 2)
            if CUDA:
                im_dim = im_dim.cuda()
                img = img.cuda()

            # Inference only: torch.no_grad() disables gradient tracking
            # (the old ``volatile=True`` flag no longer has that effect).
            with torch.no_grad():
                output = model(Variable(img), CUDA)
            output = write_results(output, confidence, num_classes, nms_conf=nms_thesh)

            # An int result (instead of a tensor) is handled as "nothing to
            # draw": show the untouched frame and move on.
            if isinstance(output, int):
                frames += 1
                print("视频的FPS为 {:5.4f}".format(frames / (time.time() - start)))
                # Pressing 'q' stops the video.
                if _show_and_poll_quit(frame):
                    break
                continue

            # Clamp the box coordinates to the network input size, then scale
            # them by frame size / input size.
            # NOTE(review): assumes prep_image resizes the frame without
            # padding - confirm against util.prep_image.
            output[:, 1:5] = torch.clamp(output[:, 1:5], 0.0, float(inp_dim))
            im_dim = im_dim.repeat(output.size(0), 1) / inp_dim
            output[:, 1:5] *= im_dim

            for detection in output:
                write(detection, frame)

            if _show_and_poll_quit(frame):
                break
            frames += 1
            print(time.time() - start)
            print("视频的FPS为 {:5.4f}".format(frames / (time.time() - start)))
    finally:
        # Always give the capture device and the display window back.
        cap.release()
        cv2.destroyAllWindows()