From d74d2608cb2067c825610ca219dbee1f6ebe5b05 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Fri, 10 Jul 2020 09:54:56 +0800 Subject: [PATCH] Add attr in ROIAlign. --- mindspore/ccsrc/transform/op_declare.cc | 3 ++- mindspore/ops/_op_impl/tbe/roi_align.py | 2 +- mindspore/ops/operations/nn_ops.py | 6 +++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index fd8ce624a9..a85a681836 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -610,7 +610,8 @@ OUTPUT_MAP(ROIAlign) = {{0, OUTPUT_DESC(y)}}; ATTR_MAP(ROIAlign) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits())}, {"pooled_width", ATTR_DESC(pooled_width, AnyTraits())}, {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits())}, - {"sample_num", ATTR_DESC(sample_num, AnyTraits())}}; + {"sample_num", ATTR_DESC(sample_num, AnyTraits())}, + {"roi_end_mode", ATTR_DESC(roi_end_mode, AnyTraits())}}; // ROIAlignGrad INPUT_MAP(ROIAlignGrad) = {{1, INPUT_DESC(ydiff)}, {2, INPUT_DESC(rois)}}; diff --git a/mindspore/ops/_op_impl/tbe/roi_align.py b/mindspore/ops/_op_impl/tbe/roi_align.py index bc4eed80ce..d392651217 100644 --- a/mindspore/ops/_op_impl/tbe/roi_align.py +++ b/mindspore/ops/_op_impl/tbe/roi_align.py @@ -27,7 +27,7 @@ roi_align_op_info = TBERegOp("ROIAlign") \ .attr("pooled_height", "required", "int", "all") \ .attr("pooled_width", "required", "int", "all") \ .attr("sample_num", "optional", "int", "all", "2") \ - .attr("roi_end_mode", "optional", "0,1", "1") \ + .attr("roi_end_mode", "optional", "int", "0,1", "1") \ .input(0, "features", False, "required", "all") \ .input(1, "rois", False, "required", "all") \ .input(2, "rois_n", False, "optional", "all") \ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 80b877765c..0d2499c0a3 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2695,6 +2695,7 @@ class ROIAlign(PrimitiveWithInfer): feature map coordinates. Suppose the height of a RoI is `ori_h` in the raw image and `fea_h` in the input feature map, the `spatial_scale` should be `fea_h / ori_h`. sample_num (int): Number of sampling points. Default: 2. + roi_end_mode (int): Number must be 0 or 1. Default: 1. Inputs: - **features** (Tensor) - The input features, whose shape should be `(N, C, H, W)`. @@ -2717,16 +2718,19 @@ class ROIAlign(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2): + def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2, roi_end_mode=1): """init ROIAlign""" validator.check_value_type("pooled_height", pooled_height, [int], self.name) validator.check_value_type("pooled_width", pooled_width, [int], self.name) validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) validator.check_value_type("sample_num", sample_num, [int], self.name) + validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name) + validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name) self.pooled_height = pooled_height self.pooled_width = pooled_width self.spatial_scale = spatial_scale self.sample_num = sample_num + self.roi_end_mode = roi_end_mode def infer_shape(self, inputs_shape, rois_shape): return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width]