From dfe2f94993140fc71f44d478a833830856375488 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Sun, 12 Apr 2020 20:07:55 +0800 Subject: [PATCH] Enhance some op error message (#23711) * enhance error msg. test=develop * print invalid argment * update comment, test=develop * fix enforce not meet specification * fix enforce not meet specification, test=develop --- .../detection/locality_aware_nms_op.cc | 62 +++++---- .../detection/mine_hard_examples_op.cc | 131 +++++++++++------- .../detection/roi_perspective_transform_op.cc | 73 ++++++---- python/paddle/fluid/layers/detection.py | 27 +++- python/paddle/fluid/layers/distributions.py | 48 +++++-- .../tests/unittests/test_distributions.py | 77 ++++++++++ .../unittests/test_locality_aware_nms_op.py | 55 ++++++++ .../test_roi_perspective_transform_op.py | 38 +++++ 8 files changed, 401 insertions(+), 110 deletions(-) diff --git a/paddle/fluid/operators/detection/locality_aware_nms_op.cc b/paddle/fluid/operators/detection/locality_aware_nms_op.cc index 36e9d60280..8422a1fa6c 100644 --- a/paddle/fluid/operators/detection/locality_aware_nms_op.cc +++ b/paddle/fluid/operators/detection/locality_aware_nms_op.cc @@ -26,37 +26,49 @@ class LocalityAwareNMSOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("BBoxes"), true, - "Input(BBoxes) of MultiClassNMS should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput("Scores"), true, - "Input(Scores) of MultiClassNMS should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - "Output(Out) of MultiClassNMS should not be null."); + OP_INOUT_CHECK(ctx->HasInput("BBoxes"), "Input", "BBoxes", + "locality_aware_nms"); + OP_INOUT_CHECK(ctx->HasInput("Scores"), "Input", "Scores", + "locality_aware_nms"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "locality_aware_nms"); auto box_dims = ctx->GetInputDim("BBoxes"); auto score_dims = ctx->GetInputDim("Scores"); auto score_size = score_dims.size(); if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(score_size, 3, "The rank of Input(Scores) must be 3"); - PADDLE_ENFORCE_EQ(box_dims.size(), 3, - "The rank of Input(BBoxes) must be 3"); - - PADDLE_ENFORCE_EQ(box_dims[2] == 4 || box_dims[2] == 8 || - box_dims[2] == 16 || box_dims[2] == 24 || - box_dims[2] == 32, - true, - "The last dimension of Input(BBoxes) must be 4 or 8, " - "represents the layout of coordinate " - "[xmin, ymin, xmax, ymax] or " - "4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or " - "8 points: [xi, yi] i= 1,2,...,8 or " - "12 points: [xi, yi] i= 1,2,...,12 or " - "16 points: [xi, yi] i= 1,2,...,16"); - PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2], - "The 2nd dimension of Input(BBoxes) must be equal to " - "last dimension of Input(Scores), which represents the " - "predicted bboxes."); + PADDLE_ENFORCE_EQ( + score_size, 3, + platform::errors::InvalidArgument( + "The rank of Input(Scores) must be 3. But received %d.", + score_size)); + PADDLE_ENFORCE_EQ( + box_dims.size(), 3, + platform::errors::InvalidArgument( + "The rank of Input(BBoxes) must be 3. But received %d.", + box_dims.size())); + PADDLE_ENFORCE_EQ( + box_dims[2] == 4 || box_dims[2] == 8 || box_dims[2] == 16 || + box_dims[2] == 24 || box_dims[2] == 32, + true, platform::errors::InvalidArgument( + "The last dimension of Input(BBoxes) must be 4 or 8, " + "represents the layout of coordinate " + "[xmin, ymin, xmax, ymax] or " + "4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or " + "8 points: [xi, yi] i= 1,2,...,8 or " + "12 points: [xi, yi] i= 1,2,...,12 or " + "16 points: [xi, yi] i= 1,2,...,16. " + "But received %d.", + box_dims[2])); + PADDLE_ENFORCE_EQ( + box_dims[1], score_dims[2], + platform::errors::InvalidArgument( + "The 2nd dimension of Input(BBoxes) must be equal to " + "last dimension of Input(Scores), which represents the " + "predicted bboxes. But received the 2nd dimension of " + "Input(BBoxes) was %d, last dimension of Input(Scores) was %d.", + box_dims[1], score_dims[2])); } // Here the box_dims[0] is not the real dimension of output. // It will be rewritten in the computing kernel. diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index 24d59c94cb..d69a7435cd 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -165,85 +165,118 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("ClsLoss"), - "Input(ClsLoss) of MineHardExamplesOp should not be null."); - PADDLE_ENFORCE( - ctx->HasInput("MatchIndices"), - "Input(MatchIndices) of MineHardExamplesOp should not be null."); - PADDLE_ENFORCE( - ctx->HasInput("MatchDist"), - "Input(MatchDist) of MineHardExamplesOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("NegIndices"), - "Output(NegIndices) of MineHardExamplesOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("UpdatedMatchIndices"), - "Output(UpdatedMatchIndices) of MineHardExamplesOp should " - "not be null."); + OP_INOUT_CHECK(ctx->HasInput("ClsLoss"), "Input", "ClsLoss", + "mine_hard_examples"); + OP_INOUT_CHECK(ctx->HasInput("MatchIndices"), "Input", "MatchIndices", + "mine_hard_examples"); + OP_INOUT_CHECK(ctx->HasInput("MatchDist"), "Input", "MatchDist", + "mine_hard_examples"); + OP_INOUT_CHECK(ctx->HasOutput("NegIndices"), "Output", "NegIndices", + "mine_hard_examples"); + OP_INOUT_CHECK(ctx->HasOutput("UpdatedMatchIndices"), "Output", + "UpdatedMatchIndices", "mine_hard_examples"); auto cls_loss_dims = ctx->GetInputDim("ClsLoss"); auto idx_dims = ctx->GetInputDim("MatchIndices"); auto dis_dims = ctx->GetInputDim("MatchDist"); PADDLE_ENFORCE_EQ(cls_loss_dims.size(), 2UL, - "The shape of ClsLoss is [N, Np]."); - PADDLE_ENFORCE_EQ(idx_dims.size(), 2UL, - "The shape of MatchIndices is [N, Np]."); + platform::errors::InvalidArgument( + "The shape of ClsLoss is [N, Np]. But received %d.", + cls_loss_dims.size())); + PADDLE_ENFORCE_EQ( + idx_dims.size(), 2UL, + platform::errors::InvalidArgument( + "The shape of MatchIndices is [N, Np]. But received %d.", + idx_dims.size())); PADDLE_ENFORCE_EQ(dis_dims.size(), 2UL, - "The shape of MatchDist is [N, Np]."); + platform::errors::InvalidArgument( + "The shape of MatchDist is [N, Np]. But received %d.", + dis_dims.size())); if (ctx->HasInput("LocLoss")) { auto loc_loss_dims = ctx->GetInputDim("LocLoss"); PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL, - "The shape of LocLoss is [N, Np]."); + platform::errors::InvalidArgument( + "The shape of LocLoss is [N, Np]. But received %d.", + loc_loss_dims.size())); if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - cls_loss_dims[0], loc_loss_dims[0], - "Batch size of ClsLoss and LocLoss must be the same."); - PADDLE_ENFORCE_EQ( - cls_loss_dims[1], loc_loss_dims[1], - "Prior box number of ClsLoss and LocLoss must be the same."); + PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0], + platform::errors::InvalidArgument( + "Batch size of ClsLoss and LocLoss must be the " + "same. But received batch size of ClsLoss was " + "%d, batch size of LocLoss was %d.", + cls_loss_dims[0], loc_loss_dims[0])); + PADDLE_ENFORCE_EQ(cls_loss_dims[1], loc_loss_dims[1], + platform::errors::InvalidArgument( + "Prior box number of ClsLoss and LocLoss must be " + "the same. But received box number of ClsLoss " + "was %d, box number of LocLoss was %d.", + cls_loss_dims[1], loc_loss_dims[1])); } } if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - cls_loss_dims[0], idx_dims[0], - "Batch size of ClsLoss and MatchIndices must be the same."); - PADDLE_ENFORCE_EQ( - cls_loss_dims[1], idx_dims[1], - "Prior box number of ClsLoss and MatchIndices must be the same."); - - PADDLE_ENFORCE_EQ( - cls_loss_dims[0], dis_dims[0], - "Batch size of ClsLoss and MatchDist must be the same."); + PADDLE_ENFORCE_EQ(cls_loss_dims[0], idx_dims[0], + platform::errors::InvalidArgument( + "Batch size of ClsLoss and MatchIndices must be " + "the same. But received batch size of ClsLoss was " + "%d, batch size of MatchIndices was %d.", + cls_loss_dims[0], idx_dims[0])); PADDLE_ENFORCE_EQ( cls_loss_dims[1], idx_dims[1], - "Prior box number of ClsLoss and MatchDist must be the same."); + platform::errors::InvalidArgument( + "Prior box number of ClsLoss and " + "MatchIndices must be the same. But received box number of " + "ClsLoss was %d, box number of MatchIndices was %d.", + cls_loss_dims[1], idx_dims[1])); + + PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0], + platform::errors::InvalidArgument( + "Batch size of ClsLoss and MatchDist must be the " + "same. But received batch size of ClsLoss was %d, " + "batch size of MatchDist was %d.", + cls_loss_dims[0], dis_dims[0])); + PADDLE_ENFORCE_EQ(cls_loss_dims[1], idx_dims[1], + platform::errors::InvalidArgument( + "Prior box number of ClsLoss and MatchDist must be " + "the same. But received box number of ClsLoss was " + "%d, box number of MatchDist was %d.", + cls_loss_dims[1], idx_dims[1])); } auto mining_type = GetMiningType(ctx->Attrs().Get("mining_type")); PADDLE_ENFORCE_NE(mining_type, MiningType::kNone, - "mining_type must be hard_example or max_negative"); + platform::errors::InvalidArgument( + "mining_type must be hard_example or max_negative")); if (mining_type == MiningType::kMaxNegative) { auto neg_pos_ratio = ctx->Attrs().Get("neg_pos_ratio"); auto neg_dist_threshold = ctx->Attrs().Get("neg_dist_threshold"); - PADDLE_ENFORCE_GT( - neg_pos_ratio, 0.0f, - "neg_pos_ratio must greater than zero in max_negative mode"); - PADDLE_ENFORCE_LT( - neg_dist_threshold, 1.0f, - "neg_dist_threshold must less than one in max_negative mode"); - PADDLE_ENFORCE_GT( - neg_dist_threshold, 0.0f, - "neg_dist_threshold must greater than zero in max_negative mode"); + PADDLE_ENFORCE_GT(neg_pos_ratio, 0.0f, + platform::errors::InvalidArgument( + "neg_pos_ratio must greater than zero in " + "max_negative mode. But received %f.", + neg_pos_ratio)); + PADDLE_ENFORCE_LT(neg_dist_threshold, 1.0f, + platform::errors::InvalidArgument( + "neg_dist_threshold must less than one in " + "max_negative mode. But received %f.", + neg_dist_threshold)); + PADDLE_ENFORCE_GT(neg_dist_threshold, 0.0f, + platform::errors::InvalidArgument( + "neg_dist_threshold must greater " + "than zero in max_negative mode. But received %f.", + neg_dist_threshold)); } else if (mining_type == MiningType::kHardExample) { auto sample_size = ctx->Attrs().Get("sample_size"); - PADDLE_ENFORCE_GT( - sample_size, 0, - "sample_size must greater than zero in hard_example mode"); + PADDLE_ENFORCE_GT(sample_size, 0, + platform::errors::InvalidArgument( + "sample_size must greater than zero in " + "hard_example mode. But received %d.", + sample_size)); } ctx->SetOutputDim("UpdatedMatchIndices", idx_dims); diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index 8b6bab9196..4d0c9da2ee 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -471,34 +471,54 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ROIPerspectiveTransformOp should not be null."); - PADDLE_ENFORCE( - ctx->HasInput("ROIs"), - "Input(ROIs) of ROIPerspectiveTransformOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("Out"), - "Output(Out) of ROIPerspectiveTransformOp should not be null."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + "roi_perspective_transform"); + OP_INOUT_CHECK(ctx->HasInput("ROIs"), "Input", "ROIs", + "roi_perspective_transform"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Ountput", "Out", + "roi_perspective_transform"); + auto input_dims = ctx->GetInputDim("X"); auto rois_dims = ctx->GetInputDim("ROIs"); - PADDLE_ENFORCE(input_dims.size() == 4, - "The format of input tensor is NCHW."); - PADDLE_ENFORCE(rois_dims.size() == 2, - "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" - "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]"); - PADDLE_ENFORCE(rois_dims[1] == 8, - "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" - "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]."); + + PADDLE_ENFORCE_EQ(input_dims.size(), 4, + platform::errors::InvalidArgument( + "The format of input tensor must be NCHW. But " + "received input dims is %d.", + input_dims.size())); + PADDLE_ENFORCE_EQ( + rois_dims.size(), 2, + platform::errors::InvalidArgument( + "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]. But received " + "rois dims is %d", + rois_dims.size())); + PADDLE_ENFORCE_EQ( + rois_dims[1], 8, + platform::errors::InvalidArgument( + "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]. But received %d", + rois_dims[1])); + int transformed_height = ctx->Attrs().Get("transformed_height"); int transformed_width = ctx->Attrs().Get("transformed_width"); float spatial_scale = ctx->Attrs().Get("spatial_scale"); - PADDLE_ENFORCE_GT(transformed_height, 0, - "The transformed output height must greater than 0"); - PADDLE_ENFORCE_GT(transformed_width, 0, - "The transformed output width must greater than 0"); - PADDLE_ENFORCE_GT(spatial_scale, 0.0f, - "The spatial scale must greater than 0"); + PADDLE_ENFORCE_GT( + transformed_height, 0, + platform::errors::InvalidArgument("The transformed output height must " + "greater than 0. But received %d.", + transformed_height)); + PADDLE_ENFORCE_GT( + transformed_width, 0, + platform::errors::InvalidArgument("The transformed output width must " + "greater than 0. But received %d.", + transformed_width)); + PADDLE_ENFORCE_GT( + spatial_scale, 0.0f, + platform::errors::InvalidArgument( + "The spatial scale must greater than 0. But received %f.", + spatial_scale)); std::vector out_dims_v({rois_dims[0], // num_rois input_dims[1], // channels static_cast(transformed_height), @@ -536,10 +556,11 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "The gradient of Out should not be null."); - PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), - "The gradient of X should not be null."); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@Grad", "roi_perspective_transform_grad"); + OP_INOUT_CHECK(ctx->HasOutputs(framework::GradVarName("X")), "Output", + "X@Grad", "roi_perspective_transform_grad"); + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index d85d648773..2152fb063e 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -2420,6 +2420,17 @@ def roi_perspective_transform(input, rois = fluid.data(name='rois', shape=[None, 8], lod_level=1, dtype='float32') out, mask, transform_matrix = fluid.layers.roi_perspective_transform(x, rois, 7, 7, 1.0) """ + check_variable_and_dtype(input, 'input', ['float32'], + 'roi_perspective_transform') + check_variable_and_dtype(rois, 'rois', ['float32'], + 'roi_perspective_transform') + check_type(transformed_height, 'transformed_height', int, + 'roi_perspective_transform') + check_type(transformed_width, 'transformed_width', int, + 'roi_perspective_transform') + check_type(spatial_scale, 'spatial_scale', float, + 'roi_perspective_transform') + helper = LayerHelper('roi_perspective_transform', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) @@ -3239,10 +3250,10 @@ def locality_aware_nms(bboxes, nms_top_k (int): Maximum number of detections to be kept according to the confidences after the filtering detections based on score_threshold. - nms_threshold (float): The threshold to be used in NMS. Default: 0.3 - nms_eta (float): The threshold to be used in NMS. Default: 1.0 keep_top_k (int): Number of total bboxes to be kept per image after NMS step. -1 means keeping all bboxes after NMS step. + nms_threshold (float): The threshold to be used in NMS. Default: 0.3 + nms_eta (float): The threshold to be used in NMS. Default: 1.0 normalized (bool): Whether detections are normalized. Default: True name(str): Name of the locality aware nms op, please refer to :ref:`api_guide_Name` . Default: None. @@ -3277,6 +3288,18 @@ def locality_aware_nms(bboxes, keep_top_k=200, normalized=False) """ + check_variable_and_dtype(bboxes, 'bboxes', ['float32', 'float64'], + 'locality_aware_nms') + check_variable_and_dtype(scores, 'scores', ['float32', 'float64'], + 'locality_aware_nms') + check_type(background_label, 'background_label', int, 'locality_aware_nms') + check_type(score_threshold, 'score_threshold', float, 'locality_aware_nms') + check_type(nms_top_k, 'nms_top_k', int, 'locality_aware_nms') + check_type(nms_eta, 'nms_eta', float, 'locality_aware_nms') + check_type(nms_threshold, 'nms_threshold', float, 'locality_aware_nms') + check_type(keep_top_k, 'keep_top_k', int, 'locality_aware_nms') + check_type(normalized, 'normalized', bool, 'locality_aware_nms') + shape = scores.shape assert len(shape) == 3, "dim size of scores must be 3" assert shape[ diff --git a/python/paddle/fluid/layers/distributions.py b/python/paddle/fluid/layers/distributions.py index 396ab443a4..c5cf4c7975 100644 --- a/python/paddle/fluid/layers/distributions.py +++ b/python/paddle/fluid/layers/distributions.py @@ -22,6 +22,8 @@ import math import numpy as np import warnings +from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype + __all__ = ['Uniform', 'Normal', 'Categorical', 'MultivariateNormalDiag'] @@ -175,6 +177,11 @@ class Uniform(Distribution): """ def __init__(self, low, high): + check_type(low, 'low', (float, np.ndarray, tensor.Variable, list), + 'Uniform') + check_type(high, 'high', (float, np.ndarray, tensor.Variable, list), + 'Uniform') + self.all_arg_is_float = False self.batch_size_unknown = False if self._validate_args(low, high): @@ -197,6 +204,9 @@ class Uniform(Distribution): Variable: A tensor with prepended dimensions shape.The data type is float32. """ + check_type(shape, 'shape', (list), 'sample') + check_type(seed, 'seed', (int), 'sample') + batch_shape = list((self.low + self.high).shape) if self.batch_size_unknown: output_shape = shape + batch_shape @@ -228,6 +238,9 @@ class Uniform(Distribution): Variable: log probability.The data type is same with value. """ + check_variable_and_dtype(value, 'value', ['float32', 'float64'], + 'log_prob') + lb_bool = control_flow.less_than(self.low, value) ub_bool = control_flow.less_than(value, self.high) lb = tensor.cast(lb_bool, dtype=value.dtype) @@ -271,7 +284,8 @@ class Normal(Distribution): Examples: .. code-block:: python - + + import numpy as np from paddle.fluid import layers from paddle.fluid.layers import Normal @@ -287,10 +301,6 @@ class Normal(Distribution): # Both have mean 1, but different standard deviations. dist = Normal(loc=1., scale=[11., 22.]) - # Define a batch of two scalar valued Normals. - # Both have mean 1, but different standard deviations. - dist = Normal(loc=1., scale=[11., 22.]) - # Complete example value_npdata = np.array([0.8], dtype="float32") value_tensor = layers.create_tensor(dtype="float32") @@ -310,6 +320,11 @@ class Normal(Distribution): """ def __init__(self, loc, scale): + check_type(loc, 'loc', (float, np.ndarray, tensor.Variable, list), + 'Normal') + check_type(scale, 'scale', (float, np.ndarray, tensor.Variable, list), + 'Normal') + self.batch_size_unknown = False self.all_arg_is_float = False if self._validate_args(loc, scale): @@ -332,6 +347,10 @@ class Normal(Distribution): Variable: A tensor with prepended dimensions shape.The data type is float32. """ + + check_type(shape, 'shape', (list), 'sample') + check_type(seed, 'seed', (int), 'sample') + batch_shape = list((self.loc + self.scale).shape) if self.batch_size_unknown: @@ -374,6 +393,9 @@ class Normal(Distribution): Variable: log probability.The data type is same with value. """ + check_variable_and_dtype(value, 'value', ['float32', 'float64'], + 'log_prob') + var = self.scale * self.scale log_scale = nn.log(self.scale) return -1. * ((value - self.loc) * (value - self.loc)) / ( @@ -389,7 +411,9 @@ class Normal(Distribution): Variable: kl-divergence between two normal distributions.The data type is float32. """ - assert isinstance(other, Normal), "another distribution must be Normal" + + check_type(other, 'other', Normal, 'kl_divergence') + var_ratio = self.scale / other.scale var_ratio = (var_ratio * var_ratio) t1 = (self.loc - other.loc) / other.scale @@ -451,6 +475,9 @@ class Categorical(Distribution): Args: logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32. """ + check_type(logits, 'logits', (np.ndarray, tensor.Variable, list), + 'Categorical') + if self._validate_args(logits): self.logits = logits else: @@ -466,7 +493,7 @@ class Categorical(Distribution): Variable: kl-divergence between two Categorical distributions. """ - assert isinstance(other, Categorical) + check_type(other, 'other', Categorical, 'kl_divergence') logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True) other_logits = other.logits - nn.reduce_max( @@ -569,6 +596,11 @@ class MultivariateNormalDiag(Distribution): """ def __init__(self, loc, scale): + check_type(loc, 'loc', (np.ndarray, tensor.Variable, list), + 'MultivariateNormalDiag') + check_type(scale, 'scale', (np.ndarray, tensor.Variable, list), + 'MultivariateNormalDiag') + if self._validate_args(loc, scale): self.loc = loc self.scale = scale @@ -620,7 +652,7 @@ class MultivariateNormalDiag(Distribution): Variable: kl-divergence between two Multivariate Normal distributions. The data type is float32. """ - assert isinstance(other, MultivariateNormalDiag) + check_type(other, 'other', MultivariateNormalDiag, 'kl_divergence') tr_cov_matmul = nn.reduce_sum(self._inv(other.scale) * self.scale) loc_matmul_cov = nn.matmul((other.loc - self.loc), diff --git a/python/paddle/fluid/tests/unittests/test_distributions.py b/python/paddle/fluid/tests/unittests/test_distributions.py index 3de9c10e6d..8387441718 100644 --- a/python/paddle/fluid/tests/unittests/test_distributions.py +++ b/python/paddle/fluid/tests/unittests/test_distributions.py @@ -574,5 +574,82 @@ class DistributionTest(unittest.TestCase): output_kl_np, gt_kl_np, rtol=tolerance, atol=tolerance) +class DistributionTestError(unittest.TestCase): + def test_normal_error(self): + loc = int(1) + scale = int(1) + + # type of loc and scale must be float, list, numpy.ndarray, Variable + self.assertRaises(TypeError, Normal, loc, 1.0) + self.assertRaises(TypeError, Normal, 1.0, scale) + + normal = Normal(0.0, 1.0) + + value = [1.0, 2.0] + # type of value must be variable + self.assertRaises(TypeError, normal.log_prob, value) + + shape = 1.0 + # type of shape must be list + self.assertRaises(TypeError, normal.sample, shape) + + seed = 1.0 + # type of seed must be int + self.assertRaises(TypeError, normal.sample, [2, 3], seed) + + normal_other = Uniform(1.0, 2.0) + # type of other must be an instance of Normal + self.assertRaises(TypeError, normal.kl_divergence, normal_other) + + def test_uniform_error(self): + low = int(1) + high = int(1) + + # type of loc and scale must be float, list, numpy.ndarray, Variable + self.assertRaises(TypeError, Uniform, low, 1.0) + self.assertRaises(TypeError, Uniform, 1.0, high) + + uniform = Uniform(0.0, 1.0) + + value = [1.0, 2.0] + # type of value must be variable + self.assertRaises(TypeError, uniform.log_prob, value) + + shape = 1.0 + # type of shape must be list + self.assertRaises(TypeError, uniform.sample, shape) + + seed = 1.0 + # type of seed must be int + self.assertRaises(TypeError, uniform.sample, [2, 3], seed) + + def test_categorical_error(self): + logit = 1.0 + + # type of loc and scale must be list, numpy.ndarray, Variable + self.assertRaises(TypeError, Categorical, logit) + + categorical = Categorical([-0.602, -0.602]) + + categorical_other = Normal(1.0, 2.0) + # type of other must be an instance of Normal + self.assertRaises(TypeError, categorical.kl_divergence, + categorical_other) + + def test_multivariate_normal_diag_error(self): + loc = 1.0 + scale = 1.0 + + # type of loc and scale must be list, numpy.ndarray, Variable + self.assertRaises(TypeError, MultivariateNormalDiag, loc, [1.0]) + self.assertRaises(TypeError, MultivariateNormalDiag, [1.0], scale) + + mnd = MultivariateNormalDiag([0.3, 0.5], [[0.4, 0], [0, 0.5]]) + + categorical_other = Normal(1.0, 2.0) + # type of other must be an instance of Normal + self.assertRaises(TypeError, mnd.kl_divergence, categorical_other) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py b/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py index e4845c90a6..71e2e6fe59 100644 --- a/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py @@ -321,5 +321,60 @@ class TestLocalityAwareNMSAPI(unittest.TestCase): normalized=False) +class TestLocalityAwareNMSError(unittest.TestCase): + def test_error(self): + boxes = fluid.data(name='bboxes', shape=[None, 81, 8], dtype='float32') + scores = fluid.data(name='scores', shape=[None, 1, 81], dtype='float32') + + boxes_int = fluid.data( + name='bboxes_int', shape=[None, 81, 8], dtype='int32') + scores_int = fluid.data( + name='scores_int', shape=[None, 1, 81], dtype='int32') + boxes_tmp = [1, 2] + scores_tmp = [1, 2] + + # type of boxes and scores must be variable + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes_tmp, + scores, 0.5, 400, 200) + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores_tmp, 0.5, 400, 200) + + # dtype of boxes and scores must in ['float32', 'float64'] + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes_int, + scores, 0.5, 400, 200) + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores_int, 0.5, 400, 200) + + score_threshold = int(1) + # type of score_threshold must be float + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores, score_threshold, 400, 200) + + nms_top_k = 400.5 + # type of num_top_k must be int + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores, 0.5, nms_top_k, 200) + + keep_top_k = 200.5 + # type of keep_top_k must be int + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores, 0.5, 400, keep_top_k) + + nms_threshold = int(0) + # type of nms_threshold must be int + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores, 0.5, 400, 200, nms_threshold) + + nms_eta = int(1) + # type of nms_eta must be float + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores, 0.5, 400, 200, 0.5, nms_eta) + + bg_label = 1.5 + # type of background_label must be int + self.assertRaises(TypeError, fluid.layers.locality_aware_nms, boxes, + scores, 0.5, 400, 200, 0.5, 1.0, bg_label) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py index 0a302f5efc..d4e48ac8a5 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py @@ -22,6 +22,7 @@ import paddle.compat as cpt from op_test import OpTest from math import sqrt from math import floor +from paddle import fluid def gt_e(a, b): @@ -313,6 +314,43 @@ class TestROIPoolOp(OpTest): [np.product(self.outputs['Out'].shape), 4]).astype("float32") self.check_grad(['X'], 'Out') + def test_errors(self): + x = fluid.data(name='x', shape=[100, 256, 28, 28], dtype='float32') + rois = fluid.data( + name='rois', shape=[None, 8], lod_level=1, dtype='float32') + + x_int = fluid.data( + name='x_int', shape=[100, 256, 28, 28], dtype='int32') + rois_int = fluid.data( + name='rois_int', shape=[None, 8], lod_level=1, dtype='int32') + x_tmp = [1, 2] + rois_tmp = [1, 2] + + # type of intput and rois must be variable + self.assertRaises(TypeError, fluid.layers.roi_perspective_transform, + x_tmp, rois, 7, 7) + self.assertRaises(TypeError, fluid.layers.roi_perspective_transform, x, + rois_tmp, 7, 7) + + # dtype of intput and rois must be float32 + self.assertRaises(TypeError, fluid.layers.roi_perspective_transform, + x_int, rois, 7, 7) + self.assertRaises(TypeError, fluid.layers.roi_perspective_transform, x, + rois_int, 7, 7) + + height = 7.5 + width = 7.5 + # type of transformed_height and transformed_width must be int + self.assertRaises(TypeError, fluid.layers.roi_perspective_transform, x, + rois, height, 7) + self.assertRaises(TypeError, fluid.layers.roi_perspective_transform, x, + rois, 7, width) + + scale = int(2) + # type of spatial_scale must be float + self.assertRaises(TypeError, fluid.layers.roi_perspective_transform, x, + rois, 7, 7, scale) + if __name__ == '__main__': unittest.main()