|
|
|
@ -50,10 +50,7 @@ class MiouPrecision(Metric):
|
|
|
|
|
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
|
|
|
|
|
predict_in = self._convert_data(inputs[0])
|
|
|
|
|
label_in = self._convert_data(inputs[1])
|
|
|
|
|
if predict_in.shape[1] != self._num_class:
|
|
|
|
|
raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
|
|
|
|
|
'classes'.format(self._num_class, predict_in.shape[1]))
|
|
|
|
|
pred = np.argmax(predict_in, axis=1)
|
|
|
|
|
pred = predict_in
|
|
|
|
|
label = label_in
|
|
|
|
|
if len(label.flatten()) != len(pred.flatten()):
|
|
|
|
|
print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
|
|
|
|
|