From 6943bb38274a51fbe79bf4171e2b7ede32dddd88 Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Mon, 30 Nov 2020 13:48:17 +0800 Subject: [PATCH] Fix openpose net eval bug --- model_zoo/official/cv/openpose/eval.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/cv/openpose/eval.py b/model_zoo/official/cv/openpose/eval.py index 8fa68486eb..212d32b339 100644 --- a/model_zoo/official/cv/openpose/eval.py +++ b/model_zoo/official/cv/openpose/eval.py @@ -211,7 +211,7 @@ def compute_connections(pafs, all_peaks, img_len, cfg): cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:] cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:] - if cand_a and cand_b: + if cand_a.shape[0] > 0 and cand_b.shape[0] > 0: candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg) connections = np.zeros((0, 3)) @@ -346,7 +346,7 @@ def detect(img, network): cv2.imwrite(save_path, heatmaps[i]*255) all_peaks = compute_peaks_from_heatmaps(heatmaps) - if not all_peaks: + if all_peaks.shape[0] == 0: return np.empty((0, len(JointType), 3)), np.empty(0) all_connections = compute_connections(pafs, all_peaks, map_w, params) subsets = grouping_key_points(all_connections, all_peaks, params) @@ -359,7 +359,7 @@ def detect(img, network): def draw_person_pose(orig_img, poses): orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB) - if not poses: + if poses.shape[0] == 0: return orig_img limb_colors = [ @@ -426,7 +426,7 @@ def _eval(): img_id = int((img_id.asnumpy())[0]) poses, scores = detect(img, network) - if poses: + if poses.shape[0] > 0: #print("got poses") for index, pose in enumerate(poses): data = dict()