commit
8cdb42c2b3
@ -0,0 +1,179 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/precision_recall_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class PrecisionRecallOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("MaxProbs"),
|
||||
"Input(MaxProbs) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Indices"),
|
||||
"Input(Indices) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Labels"),
|
||||
"Input(Labels) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"),
|
||||
"Output(BatchMetrics) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("AccumMetrics"),
|
||||
"Output(AccumMetrics) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"),
|
||||
"Output(AccumStatesInfo) should not be null.");
|
||||
|
||||
int64_t cls_num =
|
||||
static_cast<int64_t>(ctx->Attrs().Get<int>("class_number"));
|
||||
auto max_probs_dims = ctx->GetInputDim("MaxProbs");
|
||||
auto labels_dims = ctx->GetInputDim("Labels");
|
||||
|
||||
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
|
||||
"Each instance contains one max probability, so the "
|
||||
"shape of Input(MaxProbs) should be [batch_size, 1].");
|
||||
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims,
|
||||
"The shape of Input(Indices) should be [batch_size, 1].");
|
||||
PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0],
|
||||
"The 1st dimension of Input(MaxProbs) and "
|
||||
"Input(Labels) both are batch_size and the shape should "
|
||||
"be the same.");
|
||||
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
|
||||
"The 2nd dimension of Input(Labels) contains instance "
|
||||
"label and the shape should be equal to 1.");
|
||||
if (ctx->HasInput("Weights")) {
|
||||
auto weights_dims = ctx->GetInputDim("Weights");
|
||||
PADDLE_ENFORCE_EQ(weights_dims,
|
||||
framework::make_ddim({max_probs_dims[0], 1}),
|
||||
"The shape of Input(Weights) should be "
|
||||
"[batch_size, 1].");
|
||||
}
|
||||
if (ctx->HasInput("StatesInfo")) {
|
||||
auto states_dims = ctx->GetInputDim("StatesInfo");
|
||||
PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
|
||||
"The shape of Input(StatesInfo) should be "
|
||||
"[class_number, 4].");
|
||||
}
|
||||
|
||||
// Layouts of BatchMetrics and AccumMetrics both are:
|
||||
// [
|
||||
// macro average precision, macro average recall, macro average F1 score,
|
||||
// micro average precision, micro average recall, micro average F1 score
|
||||
// ]
|
||||
ctx->SetOutputDim("BatchMetrics", {6});
|
||||
ctx->SetOutputDim("AccumMetrics", {6});
|
||||
// Shape of AccumStatesInfo is [class_number, 4]
|
||||
// The layout of each row is:
|
||||
// [ TP, FP, TN, FN ]
|
||||
ctx->SetOutputDim("AccumStatesInfo", {cls_num, 4});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::DataType IndicateDataType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type());
|
||||
}
|
||||
};
|
||||
|
||||
class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
PrecisionRecallOpMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("MaxProbs",
|
||||
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
|
||||
"where N is the batch size. Each row contains the max probability "
|
||||
"of an instance which computed by the previous top_k (k=1) "
|
||||
"operator.");
|
||||
AddInput("Indices",
|
||||
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
|
||||
"where N is the batch size. Each row contains the corresponding "
|
||||
"index which computed by the previous top_k (k=1) operator.");
|
||||
AddInput("Labels",
|
||||
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
|
||||
"where N is the batch size. Each element is a label and the "
|
||||
"value should be in [0, class_number - 1].");
|
||||
AddInput("Weights",
|
||||
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
|
||||
"where N is the batch size. This input is optional. If provided, "
|
||||
"weight of instance would be considered when computing metrics.")
|
||||
.AsDispensable();
|
||||
AddInput("StatesInfo",
|
||||
"(Tensor, default Tensor<int>), a 2-D tensor with shape D x 4, "
|
||||
"where D is the number of classes. This input is optional. If "
|
||||
"provided, current state will be accumulated to this state and "
|
||||
"the accumulation state will be as the output state.")
|
||||
.AsDispensable();
|
||||
AddOutput("BatchMetrics",
|
||||
"(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
|
||||
"This output tensor contains metrics for current batch data."
|
||||
"The layout is [macro average precision, macro average recall, "
|
||||
"macro f1 score, micro average precision, micro average recall, "
|
||||
"micro f1 score]");
|
||||
AddOutput("AccumMetrics",
|
||||
"(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
|
||||
"This output tensor contains metrics for accumulated data."
|
||||
"The layout is [macro average precision, macro average recall, "
|
||||
"macro f1 score, micro average precision, micro average recall, "
|
||||
"micro f1 score]");
|
||||
AddOutput("AccumStatesInfo",
|
||||
"(Tensor, default Tensor<float>), a 2-D tensor with shape D x 4, "
|
||||
"where D is equal to class number. This output tensor contains "
|
||||
"accumulated state variables used to compute metrics. The layout "
|
||||
"for each class is [true positives, false positives, "
|
||||
"true negatives, false negatives].");
|
||||
AddAttr<int>("class_number", "Number of classes to be evaluated.");
|
||||
AddComment(R"DOC(
|
||||
When given 'Input(Indices)' and 'Input(Labels)', this operator can be used
|
||||
to compute various metrics including:
|
||||
- macro average precision
|
||||
- macro average recall
|
||||
- macro f1 score
|
||||
- micro average precision
|
||||
- micro average recall
|
||||
- micro f1 score
|
||||
|
||||
To compute the above metrics, we need to do statistics for true positives,
|
||||
false positives and false negatives. Here count of true negatives is not
|
||||
necessary, but counting it may provide potential usage and the cost is
|
||||
trivial, so the operator also provides count of true negatives.
|
||||
|
||||
We define state as a 2-D tensor with shape [class_number, 4]. Each row of a
|
||||
state contains statistic variables for corresponding class. Layout of each row
|
||||
is: TP(true positives), FP(false positives), TN(true negatives),
|
||||
FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be
|
||||
calculated by given weight instead of instance count.
|
||||
|
||||
This operator also supports metrics computing for cross-batch situation. To
|
||||
achieve this, 'Input(StatesInfo)' should be provided. State of current batch
|
||||
data will be accumulated to 'Input(StatesInfo)' and 'Output(AccumStatesInfo)'
|
||||
is the accumulation state.
|
||||
|
||||
'Output(BatchMetrics)' is metrics of current batch data while
|
||||
'Output(AccumStatesInfo)' is metrics of accumulation data.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
|
||||
ops::PrecisionRecallOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
precision_recall,
|
||||
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>);
|
@ -0,0 +1,161 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
||||
|
||||
enum StateVariable { TP = 0, FP, TN, FN };
|
||||
|
||||
template <typename Place, typename T>
|
||||
class PrecisionRecallKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* in0 = ctx.Input<Tensor>("Indices");
|
||||
auto* in1 = ctx.Input<Tensor>("Labels");
|
||||
auto* in2 = ctx.Input<Tensor>("Weights");
|
||||
auto* in3 = ctx.Input<Tensor>("StatesInfo");
|
||||
auto* out0 = ctx.Output<Tensor>("BatchMetrics");
|
||||
auto* out1 = ctx.Output<Tensor>("AccumMetrics");
|
||||
auto* out2 = ctx.Output<Tensor>("AccumStatesInfo");
|
||||
|
||||
const int* ids_data = in0->data<int>();
|
||||
const int* labels_data = in1->data<int>();
|
||||
size_t cls_num = static_cast<size_t>(ctx.Attr<int>("class_number"));
|
||||
const T* weights_data = in2 ? in2->data<T>() : nullptr;
|
||||
const T* states_data = in3 ? in3->data<T>() : nullptr;
|
||||
double* batch_metrics_data = out0->mutable_data<double>(ctx.GetPlace());
|
||||
double* accum_metrics_data = out1->mutable_data<double>(ctx.GetPlace());
|
||||
out2->mutable_data<T>(ctx.GetPlace());
|
||||
auto accum_states = EigenMatrix<T>::From(*out2);
|
||||
accum_states.setZero();
|
||||
T* accum_states_data = out2->data<T>();
|
||||
|
||||
size_t sample_num = in0->dims()[0];
|
||||
size_t state_var_num = 4; // TP FP TN FN
|
||||
|
||||
// get states info for current batch
|
||||
for (size_t i = 0; i < sample_num; ++i) {
|
||||
size_t idx = ids_data[i];
|
||||
size_t label = labels_data[i];
|
||||
|
||||
PADDLE_ENFORCE(idx >= 0 && idx < cls_num,
|
||||
"Class index of each instance should be in "
|
||||
"[0, class_number).");
|
||||
PADDLE_ENFORCE(label >= 0 && label < cls_num,
|
||||
"Label of each instance should be in [0, class_number).");
|
||||
|
||||
T w = weights_data ? weights_data[i] : 1.0;
|
||||
if (idx == label) {
|
||||
accum_states_data[idx * state_var_num + TP] += w;
|
||||
for (size_t j = 0; j < cls_num; ++j) {
|
||||
accum_states_data[j * state_var_num + TN] += w;
|
||||
}
|
||||
accum_states_data[idx * state_var_num + TN] -= w;
|
||||
} else {
|
||||
accum_states_data[label * state_var_num + FN] += w;
|
||||
accum_states_data[idx * state_var_num + FP] += w;
|
||||
for (size_t j = 0; j < cls_num; ++j) {
|
||||
accum_states_data[j * state_var_num + TN] += w;
|
||||
}
|
||||
accum_states_data[idx * state_var_num + TN] -= w;
|
||||
accum_states_data[label * state_var_num + TN] -= w;
|
||||
}
|
||||
}
|
||||
|
||||
ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num,
|
||||
cls_num);
|
||||
|
||||
if (states_data) {
|
||||
for (size_t i = 0; i < cls_num; ++i) {
|
||||
for (size_t j = 0; j < state_var_num; ++j) {
|
||||
size_t idx = i * state_var_num + j;
|
||||
accum_states_data[idx] += states_data[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num,
|
||||
cls_num);
|
||||
}
|
||||
|
||||
// expose to be reused
|
||||
static inline T CalcPrecision(T tp_count, T fp_count) {
|
||||
if (tp_count > 0.0 || fp_count > 0.0) {
|
||||
return tp_count / (tp_count + fp_count);
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
static inline T CalcRecall(T tp_count, T fn_count) {
|
||||
if (tp_count > 0.0 || fn_count > 0.0) {
|
||||
return tp_count / (tp_count + fn_count);
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
static inline T CalcF1Score(T precision, T recall) {
|
||||
if (precision > 0.0 || recall > 0.0) {
|
||||
return 2 * precision * recall / (precision + recall);
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
protected:
|
||||
void ComputeMetrics(const T* states_data, double* metrics_data,
|
||||
size_t state_var_num, size_t cls_num) const {
|
||||
T total_tp_count = 0;
|
||||
T total_fp_count = 0;
|
||||
T total_fn_count = 0;
|
||||
T macro_avg_precision = 0.0;
|
||||
T macro_avg_recall = 0.0;
|
||||
|
||||
for (size_t i = 0; i < cls_num; ++i) {
|
||||
T tp_count = states_data[i * state_var_num + TP];
|
||||
T fp_count = states_data[i * state_var_num + FP];
|
||||
T fn_count = states_data[i * state_var_num + FN];
|
||||
total_tp_count += tp_count;
|
||||
total_fp_count += fp_count;
|
||||
total_fn_count += fn_count;
|
||||
macro_avg_precision += CalcPrecision(tp_count, fp_count);
|
||||
macro_avg_recall += CalcRecall(tp_count, fn_count);
|
||||
}
|
||||
macro_avg_precision /= cls_num;
|
||||
macro_avg_recall /= cls_num;
|
||||
T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall);
|
||||
|
||||
T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count);
|
||||
T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count);
|
||||
T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall);
|
||||
|
||||
// fill metrics data
|
||||
metrics_data[0] = macro_avg_precision;
|
||||
metrics_data[1] = macro_avg_recall;
|
||||
metrics_data[2] = macro_f1_score;
|
||||
metrics_data[3] = micro_avg_precision;
|
||||
metrics_data[4] = micro_avg_recall;
|
||||
metrics_data[5] = micro_f1_score;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,173 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def calc_precision(tp_count, fp_count):
|
||||
if tp_count > 0.0 or fp_count > 0.0:
|
||||
return tp_count / (tp_count + fp_count)
|
||||
return 1.0
|
||||
|
||||
|
||||
def calc_recall(tp_count, fn_count):
|
||||
if tp_count > 0.0 or fn_count > 0.0:
|
||||
return tp_count / (tp_count + fn_count)
|
||||
return 1.0
|
||||
|
||||
|
||||
def calc_f1_score(precision, recall):
|
||||
if precision > 0.0 or recall > 0.0:
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_states(idxs, labels, cls_num, weights=None):
|
||||
ins_num = idxs.shape[0]
|
||||
# TP FP TN FN
|
||||
states = np.zeros((cls_num, 4)).astype('float32')
|
||||
for i in xrange(ins_num):
|
||||
w = weights[i] if weights is not None else 1.0
|
||||
idx = idxs[i][0]
|
||||
label = labels[i][0]
|
||||
if idx == label:
|
||||
states[idx][0] += w
|
||||
for j in xrange(cls_num):
|
||||
states[j][2] += w
|
||||
states[idx][2] -= w
|
||||
else:
|
||||
states[label][3] += w
|
||||
states[idx][1] += w
|
||||
for j in xrange(cls_num):
|
||||
states[j][2] += w
|
||||
states[label][2] -= w
|
||||
states[idx][2] -= w
|
||||
return states
|
||||
|
||||
|
||||
def compute_metrics(states, cls_num):
|
||||
total_tp_count = 0.0
|
||||
total_fp_count = 0.0
|
||||
total_fn_count = 0.0
|
||||
macro_avg_precision = 0.0
|
||||
macro_avg_recall = 0.0
|
||||
for i in xrange(cls_num):
|
||||
total_tp_count += states[i][0]
|
||||
total_fp_count += states[i][1]
|
||||
total_fn_count += states[i][3]
|
||||
macro_avg_precision += calc_precision(states[i][0], states[i][1])
|
||||
macro_avg_recall += calc_recall(states[i][0], states[i][3])
|
||||
metrics = []
|
||||
macro_avg_precision /= cls_num
|
||||
macro_avg_recall /= cls_num
|
||||
metrics.append(macro_avg_precision)
|
||||
metrics.append(macro_avg_recall)
|
||||
metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall))
|
||||
micro_avg_precision = calc_precision(total_tp_count, total_fp_count)
|
||||
metrics.append(micro_avg_precision)
|
||||
micro_avg_recall = calc_recall(total_tp_count, total_fn_count)
|
||||
metrics.append(micro_avg_recall)
|
||||
metrics.append(calc_f1_score(micro_avg_precision, micro_avg_recall))
|
||||
return np.array(metrics).astype('float32')
|
||||
|
||||
|
||||
class TestPrecisionRecallOp_0(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "precision_recall"
|
||||
ins_num = 64
|
||||
cls_num = 10
|
||||
max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
|
||||
idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
|
||||
(ins_num, 1)).astype('int32')
|
||||
labels = np.random.choice(xrange(cls_num), ins_num).reshape(
|
||||
(ins_num, 1)).astype('int32')
|
||||
states = get_states(idxs, labels, cls_num)
|
||||
metrics = compute_metrics(states, cls_num)
|
||||
|
||||
self.attrs = {'class_number': cls_num}
|
||||
|
||||
self.inputs = {'MaxProbs': max_probs, 'Indices': idxs, 'Labels': labels}
|
||||
|
||||
self.outputs = {
|
||||
'BatchMetrics': metrics,
|
||||
'AccumMetrics': metrics,
|
||||
'AccumStatesInfo': states
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestPrecisionRecallOp_1(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "precision_recall"
|
||||
ins_num = 64
|
||||
cls_num = 10
|
||||
max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
|
||||
idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
|
||||
(ins_num, 1)).astype('int32')
|
||||
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
|
||||
labels = np.random.choice(xrange(cls_num), ins_num).reshape(
|
||||
(ins_num, 1)).astype('int32')
|
||||
|
||||
states = get_states(idxs, labels, cls_num, weights)
|
||||
metrics = compute_metrics(states, cls_num)
|
||||
|
||||
self.attrs = {'class_number': cls_num}
|
||||
|
||||
self.inputs = {
|
||||
'MaxProbs': max_probs,
|
||||
'Indices': idxs,
|
||||
'Labels': labels,
|
||||
'Weights': weights
|
||||
}
|
||||
|
||||
self.outputs = {
|
||||
'BatchMetrics': metrics,
|
||||
'AccumMetrics': metrics,
|
||||
'AccumStatesInfo': states
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestPrecisionRecallOp_2(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "precision_recall"
|
||||
ins_num = 64
|
||||
cls_num = 10
|
||||
max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
|
||||
idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
|
||||
(ins_num, 1)).astype('int32')
|
||||
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
|
||||
labels = np.random.choice(xrange(cls_num), ins_num).reshape(
|
||||
(ins_num, 1)).astype('int32')
|
||||
states = np.random.randint(0, 30, (cls_num, 4)).astype('float32')
|
||||
|
||||
accum_states = get_states(idxs, labels, cls_num, weights)
|
||||
batch_metrics = compute_metrics(accum_states, cls_num)
|
||||
accum_states += states
|
||||
accum_metrics = compute_metrics(accum_states, cls_num)
|
||||
|
||||
self.attrs = {'class_number': cls_num}
|
||||
|
||||
self.inputs = {
|
||||
'MaxProbs': max_probs,
|
||||
'Indices': idxs,
|
||||
'Labels': labels,
|
||||
'Weights': weights,
|
||||
'StatesInfo': states
|
||||
}
|
||||
|
||||
self.outputs = {
|
||||
'BatchMetrics': batch_metrics,
|
||||
'AccumMetrics': accum_metrics,
|
||||
'AccumStatesInfo': accum_states
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue