Set stop_gradient=True for some variables in SSD API. (#9396)

helinwang-patch-1
qingqing01 7 years ago committed by GitHub
parent e0b5691e41
commit 123cf165fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -134,6 +134,7 @@ def detection_output(loc,
scores = nn.softmax(input=scores)
scores = ops.reshape(x=scores, shape=old_shape)
scores = nn.transpose(scores, perm=[0, 2, 1])
scores.stop_gradient = True
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
helper.append_op(
type="multiclass_nms",
@ -148,6 +149,7 @@ def detection_output(loc,
'score_threshold': score_threshold,
'nms_eta': 1.0
})
nmsed_outs.stop_gradient = True
return nmsed_outs
@ -837,4 +839,6 @@ def multi_box_head(inputs,
mbox_locs_concat = tensor.concat(mbox_locs, axis=1)
mbox_confs_concat = tensor.concat(mbox_confs, axis=1)
box.stop_gradient = True
var.stop_gradient = True
return mbox_locs_concat, mbox_confs_concat, box, var

Loading…
Cancel
Save