Add mean IOU op. (#10519)
* Add mean_iou op. * Add unitest for mean iou op. * Add optional collections of confusion matrix and mean_iou. * Fix cuda kernel. * Refine code. 1. Merge computing in GPU to two kernel. 2. Use wrong array and correct array instead of confusion matrix. * Add python api and fix cuda kernel. * Fix comments. * Small fix. * Small fix.wangkuiyi-patch-1
parent
f790b96d6f
commit
6fcdb240fa
@ -0,0 +1,110 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/mean_iou_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class MeanIoUOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Predictions"),
|
||||
"Input (Predictions) of MeanIoU op should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Labels"),
|
||||
"Input (labels) of MeanIoU op should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("OutMeanIou"),
|
||||
"Output (OutMeanIou) of MeanIoU op should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("OutWrong"),
|
||||
"Output (OutWrong) of MeanIoU op should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("OutCorrect"),
|
||||
"Output (OutWrong) of MeanIoU op should not be null.");
|
||||
|
||||
int64_t num_classes =
|
||||
static_cast<int64_t>(ctx->Attrs().Get<int>("num_classes"));
|
||||
|
||||
ctx->SetOutputDim("OutMeanIou", {1});
|
||||
ctx->SetOutputDim("OutWrong", {num_classes});
|
||||
ctx->SetOutputDim("OutCorrect", {num_classes});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<Tensor>("Predictions")->type()),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class MeanIoUOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Predictions",
|
||||
"(Tensor), A Tensor of prediction results for semantic labels"
|
||||
" with type int32 or int64. The rank should be greater than 1.");
|
||||
AddInput(
|
||||
"Labels",
|
||||
"(Tensor), A Tensor of ground truth labels with type int32 or int64."
|
||||
"Its shape should be the same as Input(Predictions).");
|
||||
AddInput("InWrongs",
|
||||
"(vector<Tensor>), A list of Tensor with shape "
|
||||
"[num_classes]. They are used to collect wrong number among "
|
||||
"batches. Empty list is also valid here.")
|
||||
.AsDuplicable()
|
||||
.AsDispensable();
|
||||
AddInput(
|
||||
"InCorrects",
|
||||
"(vector<Tensor>), A list of Tensor with shape "
|
||||
"[num_classes]. They are used to collect correct number among batches. "
|
||||
"Empty list is also valid here.")
|
||||
.AsDuplicable()
|
||||
.AsDispensable();
|
||||
AddInput("InMeanIou",
|
||||
"(vector<Tensor>), A list of Tensor that Output(mean_iou) should "
|
||||
"be added to. Empty list is also valid here.")
|
||||
.AsDuplicable()
|
||||
.AsDispensable();
|
||||
AddOutput("OutMeanIou",
|
||||
"(vector<Tensor>), A Tensor representing the"
|
||||
" mean intersection-over-union with shape [1].");
|
||||
AddOutput("OutWrong", "(Tensor), A Tensor with shape [num_classes]. ");
|
||||
AddOutput("OutCorrect", "(Tensor), A Tensor with shape [num_classes]. ");
|
||||
AddAttr<int>("num_classes", "(int), The possible number of labels.");
|
||||
|
||||
AddComment(R"DOC(
|
||||
mean-IOU Operator.
|
||||
Mean Intersection-Over-Union is a common evaluation metric for
|
||||
semantic image segmentation, which first computes the IOU for each
|
||||
semantic class and then computes the average over classes.
|
||||
IOU is defined as follows:
|
||||
IOU = true_positive / (true_positive + false_positive + false_negative).
|
||||
It is based on pixel level area while "IOU Similarity Operator"
|
||||
is based on area of rectangle.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(mean_iou, ops::MeanIoUOp, ops::MeanIoUOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(mean_iou, ops::MeanIoUKernel<int>,
|
||||
ops::MeanIoUKernel<int32_t>,
|
||||
ops::MeanIoUKernel<int64_t>);
|
@ -0,0 +1,164 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/operators/mean_iou_op.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
#include "paddle/fluid/platform/gpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using platform::PADDLE_CUDA_NUM_THREADS;
|
||||
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
template <typename T>
|
||||
__global__ void CountCUDAKernel(const int num_classes, const int count,
|
||||
const T* predictions, const T* labels,
|
||||
int* wrong, int* correct) {
|
||||
extern __shared__ int blcok_cache[];
|
||||
int* wrong_c = blcok_cache;
|
||||
int* correct_c = blcok_cache + num_classes;
|
||||
// init cache
|
||||
for (int i = threadIdx.x; i < num_classes * 2; i += blockDim.x) {
|
||||
blcok_cache[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
T pred;
|
||||
T label;
|
||||
CUDA_1D_KERNEL_LOOP(i, count) {
|
||||
pred = predictions[i];
|
||||
label = labels[i];
|
||||
if (pred == label) {
|
||||
atomicAdd(correct_c + pred, 1);
|
||||
} else {
|
||||
atomicAdd(wrong_c + pred, 1);
|
||||
atomicAdd(wrong_c + label, 1);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = threadIdx.x; i < num_classes; i += blockDim.x) {
|
||||
atomicAdd(wrong + i, wrong_c[i]);
|
||||
atomicAdd(correct + i, correct_c[i]);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void ComputeIoUCUDAKernel(const int num_classes, int* wrong,
|
||||
int* correct, float* ious, float* iou) {
|
||||
__shared__ int valid_count_c;
|
||||
if (threadIdx.x == 0) {
|
||||
valid_count_c = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
CUDA_1D_KERNEL_LOOP(i, num_classes) {
|
||||
int wrong_n = wrong[i];
|
||||
int correct_n = correct[i];
|
||||
int denominator = wrong_n + correct_n;
|
||||
if (denominator > 0) {
|
||||
atomicAdd(&valid_count_c, 1);
|
||||
ious[i] = static_cast<float>(correct_n) / denominator;
|
||||
} else {
|
||||
ious[i] = 0;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
float iou_sum = 0;
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
iou_sum += ious[i];
|
||||
}
|
||||
iou[0] += iou_sum / valid_count_c;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
|
||||
.eigen_device();
|
||||
// get input and output tensor
|
||||
auto* predictions = ctx.Input<Tensor>("Predictions");
|
||||
auto* labels = ctx.Input<Tensor>("Labels");
|
||||
auto* out_mean_iou = ctx.Output<Tensor>("OutMeanIou");
|
||||
auto* out_wrong = ctx.Output<Tensor>("OutWrong");
|
||||
auto* out_correct = ctx.Output<Tensor>("OutCorrect");
|
||||
int num_classes = static_cast<int>(ctx.Attr<int>("num_classes"));
|
||||
|
||||
// Get data ptr
|
||||
const T* predictions_data = predictions->data<T>();
|
||||
const T* labels_data = labels->data<T>();
|
||||
int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace());
|
||||
int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace());
|
||||
float* out_mean_iou_data =
|
||||
out_mean_iou->mutable_data<float>(ctx.GetPlace());
|
||||
|
||||
// Get Eigen tensor
|
||||
auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou);
|
||||
auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
|
||||
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
|
||||
|
||||
// Temporary tensor
|
||||
Tensor ious;
|
||||
float* ious_data = ious.mutable_data<float>(
|
||||
{static_cast<int64_t>(num_classes)}, ctx.GetPlace());
|
||||
auto ious_t = EigenTensor<float, 1>::From(ious);
|
||||
|
||||
// Init out_wrong, out_correct and out_mean_iou
|
||||
out_wrong_t.device(place) = out_wrong_t.constant(0);
|
||||
out_correct_t.device(place) = out_correct_t.constant(0);
|
||||
out_mean_iou_t.device(place) = out_mean_iou_t.constant(0.0f);
|
||||
|
||||
// collect pre wrong, correct and mean_iou
|
||||
auto in_mean_ious = ctx.MultiInput<Tensor>("InMeanIou");
|
||||
for (int i = 0; i < in_mean_ious.size(); ++i) {
|
||||
out_mean_iou_t.device(place) +=
|
||||
EigenTensor<float, 1>::From(*in_mean_ious[i]);
|
||||
}
|
||||
auto in_wrongs = ctx.MultiInput<Tensor>("InWrongs");
|
||||
for (int i = 0; i < in_wrongs.size(); ++i) {
|
||||
out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]);
|
||||
}
|
||||
auto in_corrects = ctx.MultiInput<Tensor>("InCorrects");
|
||||
for (int i = 0; i < in_corrects.size(); ++i) {
|
||||
out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]);
|
||||
}
|
||||
// compute
|
||||
auto stream = ctx.cuda_device_context().stream();
|
||||
int block = PADDLE_CUDA_NUM_THREADS;
|
||||
int grid = (predictions->numel() + block - 1) / block;
|
||||
int cache_size = (num_classes * 2 + 1) * sizeof(int);
|
||||
CountCUDAKernel<T><<<grid, block, cache_size, stream>>>(
|
||||
num_classes, predictions->numel(), predictions_data, labels_data,
|
||||
out_wrong_data, out_correct_data);
|
||||
ctx.device_context().Wait();
|
||||
ComputeIoUCUDAKernel<<<1, block, 0, stream>>>(num_classes, out_wrong_data,
|
||||
out_correct_data, ious_data,
|
||||
out_mean_iou_data);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(mean_iou, ops::MeanIoUCUDAOpKernel<int>,
|
||||
ops::MeanIoUCUDAOpKernel<int64_t>,
|
||||
ops::MeanIoUCUDAOpKernel<int32_t>);
|
@ -0,0 +1,117 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T, int D, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
|
||||
|
||||
template <typename T>
|
||||
class MeanIoUKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
|
||||
.eigen_device();
|
||||
// get input and output tensor
|
||||
auto* predictions = ctx.Input<Tensor>("Predictions");
|
||||
auto* labels = ctx.Input<Tensor>("Labels");
|
||||
auto* out_mean_iou = ctx.Output<Tensor>("OutMeanIou");
|
||||
auto* out_wrong = ctx.Output<Tensor>("OutWrong");
|
||||
auto* out_correct = ctx.Output<Tensor>("OutCorrect");
|
||||
int num_classes = static_cast<int>(ctx.Attr<int>("num_classes"));
|
||||
|
||||
// get data ptr
|
||||
const T* predictions_data = predictions->data<T>();
|
||||
const T* labels_data = labels->data<T>();
|
||||
float* out_mean_iou_data =
|
||||
out_mean_iou->mutable_data<float>(ctx.GetPlace());
|
||||
int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace());
|
||||
int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace());
|
||||
|
||||
// get eigen tensor
|
||||
auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou);
|
||||
auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
|
||||
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
|
||||
|
||||
// Tmp tensor
|
||||
Tensor denominator;
|
||||
Tensor valid_count;
|
||||
Tensor iou_sum;
|
||||
|
||||
// get data ptr of tmp tensor
|
||||
int* denominator_data = denominator.mutable_data<int>(
|
||||
{static_cast<int64_t>(num_classes)}, ctx.GetPlace());
|
||||
int* valid_count_data = valid_count.mutable_data<int>({1}, ctx.GetPlace());
|
||||
float* iou_sum_data = iou_sum.mutable_data<float>({1}, ctx.GetPlace());
|
||||
|
||||
// get eigen tensor of tmp tensor
|
||||
auto denominator_t = EigenTensor<int, 1>::From(denominator);
|
||||
auto valid_count_t = EigenTensor<int, 1>::From(valid_count);
|
||||
auto iou_sum_t = EigenTensor<float, 1>::From(iou_sum);
|
||||
|
||||
// init out_wrong, out_correct and out_mean_iou
|
||||
out_wrong_t = out_wrong_t.constant(0);
|
||||
out_correct_t = out_correct_t.constant(0);
|
||||
out_mean_iou_t = out_mean_iou_t.constant(0);
|
||||
|
||||
// collect pre wrong, correct and mean_iou
|
||||
auto in_mean_ious = ctx.MultiInput<Tensor>("InMeanIou");
|
||||
for (size_t i = 0; i < in_mean_ious.size(); ++i) {
|
||||
out_mean_iou_t.device(place) +=
|
||||
EigenTensor<float, 1>::From(*in_mean_ious[i]);
|
||||
}
|
||||
auto in_wrongs = ctx.MultiInput<Tensor>("InWrongs");
|
||||
for (size_t i = 0; i < in_wrongs.size(); ++i) {
|
||||
out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]);
|
||||
}
|
||||
auto in_corrects = ctx.MultiInput<Tensor>("InCorrects");
|
||||
for (size_t i = 0; i < in_corrects.size(); ++i) {
|
||||
out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]);
|
||||
}
|
||||
|
||||
// compute
|
||||
for (int64_t i = 0; i < predictions->numel(); ++i) {
|
||||
if (predictions_data[i] == labels_data[i]) {
|
||||
out_correct_data[predictions_data[i]] += 1;
|
||||
} else {
|
||||
out_wrong_data[labels_data[i]] += 1;
|
||||
out_wrong_data[predictions_data[i]] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
denominator_t = out_wrong_t + out_correct_t;
|
||||
valid_count_t =
|
||||
(denominator_t > denominator_t.constant(0.0f)).cast<int>().sum();
|
||||
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
if (denominator_data[i] == 0) {
|
||||
denominator_data[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
iou_sum_t =
|
||||
(out_correct_t.cast<float>() / denominator_t.cast<float>()).sum();
|
||||
out_mean_iou_data[0] += (iou_sum_data[0] / valid_count_data[0]);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,114 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def compute_mean_iou(predictions, labels, num_classes, in_wrongs, in_corrects,
|
||||
in_mean_ious):
|
||||
assert predictions.shape == labels.shape
|
||||
predictions = predictions.flatten()
|
||||
labels = labels.flatten()
|
||||
|
||||
out_wrong = np.zeros([num_classes]).astype("int32")
|
||||
for _, wrong in in_wrongs:
|
||||
out_wrong += wrong
|
||||
out_correct = np.zeros([num_classes]).astype("int32")
|
||||
for _, correct in in_corrects:
|
||||
out_correct += correct
|
||||
|
||||
for pred, label in zip(predictions, labels):
|
||||
if pred == label:
|
||||
out_correct[pred] += 1
|
||||
else:
|
||||
out_wrong[pred] += 1
|
||||
out_wrong[label] += 1
|
||||
|
||||
denominator = out_wrong + out_correct
|
||||
valid_count = (denominator != 0).sum()
|
||||
denominator = np.where(denominator > 0, denominator,
|
||||
np.ones(denominator.shape))
|
||||
mean_iou = (out_correct / denominator).sum() / valid_count
|
||||
|
||||
for _, in_mean_iou in in_mean_ious:
|
||||
mean_iou += in_mean_iou
|
||||
return mean_iou, out_wrong, out_correct
|
||||
|
||||
|
||||
class TestMeanIOUOp(OpTest):
|
||||
def setUp(self):
|
||||
self.config()
|
||||
self.op_type = "mean_iou"
|
||||
predictions = np.random.randint(0, self.num_classes,
|
||||
self.image_size).astype("int32")
|
||||
labels = np.random.randint(0, self.num_classes,
|
||||
self.image_size).astype("int32")
|
||||
|
||||
in_wrongs = []
|
||||
for i in range(self.in_wrong_num):
|
||||
in_wrongs.append(("in_wrong_%d" % i, np.random.randint(
|
||||
0, 10, [self.num_classes]).astype("int32")))
|
||||
|
||||
in_corrects = []
|
||||
for i in range(self.in_correct_num):
|
||||
in_corrects.append(("in_correct_%d" % i, np.random.randint(
|
||||
0, 10, [self.num_classes]).astype("int32")))
|
||||
|
||||
in_mean_ious = []
|
||||
for i in range(self.in_mean_iou_num):
|
||||
in_mean_ious.append(("in_mean_iou_%d" % i, np.random.uniform(
|
||||
0, 1, [1]).astype("float32")))
|
||||
|
||||
self.inputs = {
|
||||
'Predictions': predictions,
|
||||
'Labels': labels,
|
||||
'InWrongs': in_wrongs,
|
||||
'InCorrects': in_corrects,
|
||||
'InMeanIou': in_mean_ious
|
||||
}
|
||||
self.attrs = {'num_classes': long(self.num_classes)}
|
||||
mean_iou, out_wrong, out_correct = compute_mean_iou(
|
||||
predictions, labels, self.num_classes, in_wrongs, in_corrects,
|
||||
in_mean_ious)
|
||||
self.outputs = {
|
||||
'OutMeanIou': mean_iou,
|
||||
'OutWrong': out_wrong,
|
||||
'OutCorrect': out_correct
|
||||
}
|
||||
|
||||
def config(self):
|
||||
self.num_classes = 10
|
||||
self.image_size = [128, 128]
|
||||
self.in_wrong_num = 0
|
||||
self.in_correct_num = 0
|
||||
self.in_mean_iou_num = 0
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestCase1(TestMeanIOUOp):
|
||||
def config(self):
|
||||
self.num_classes = 5
|
||||
self.image_size = [100, 128]
|
||||
self.in_wrong_num = 2
|
||||
self.in_correct_num = 2
|
||||
self.in_mean_iou_num = 2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue