|
|
|
@ -64,10 +64,6 @@ def auc(input, label, curve='ROC', num_thresholds=200):
|
|
|
|
|
topk_indices = helper.create_tmp_variable(dtype="int64")
|
|
|
|
|
topk_out, topk_indices = nn.topk(input, k=k)
|
|
|
|
|
auc_out = helper.create_tmp_variable(dtype="float32")
|
|
|
|
|
if correct is None:
|
|
|
|
|
correct = helper.create_tmp_variable(dtype="int64")
|
|
|
|
|
if total is None:
|
|
|
|
|
total = helper.create_tmp_variable(dtype="int64")
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type="accuracy",
|
|
|
|
|
inputs={
|
|
|
|
|