develop op ScatterMax and dock ge process

pull/935/head
buxue 5 years ago
parent a5572f1517
commit ac86996746

@ -102,6 +102,7 @@ const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu";
const char kNameEluGrad[] = "EluGrad";
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
const char kNameScatterMax[] = "ScatterMax";
const char kNameNMSWithMask[] = "NMSWithMask";
const char kNameCheckValid[] = "CheckValid";
const char kNameSmoothL1Loss[] = "SmoothL1Loss";
@ -253,6 +254,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
{string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
{string(kNameCheckValid), ADPT_DESC(CheckValid)},
{string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)},

@ -530,6 +530,11 @@ INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3
ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}};
// ScatterMax
INPUT_MAP(ScatterMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ScatterMax) = {{0, OUTPUT_DESC(var)}};
// CheckValid
INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}};
ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP;

@ -136,6 +136,8 @@ DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_OP_ADAPTER(ScatterNdUpdate)
DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
DECLARE_OP_ADAPTER(ScatterMax)
DECLARE_OP_USE_OUTPUT(ScatterMax)
DECLARE_OP_ADAPTER(NMSWithMask)
DECLARE_OP_USE_OUTPUT(NMSWithMask)
DECLARE_OP_ADAPTER(Unpack)

@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape,
SameTypeShape, ScatterMax,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile,
@ -184,6 +184,7 @@ __all__ = [
'BoundingBoxDecode',
'L2Normalize',
'ScatterNd',
'ScatterMax',
'ResizeNearestNeighbor',
'Pad',
'MirrorPad',

@ -1953,7 +1953,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
Using given values to update tensor value, along with the input indices.
Args:
use_locking (bool): Whether protect the assignment by a lock. Defaule: True.
use_locking (bool): Whether protect the assignment by a lock. Default: True.
Inputs:
- **input_x** (Tensor) - The target tensor.
@ -1995,6 +1995,53 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return x_dtype
class ScatterMax(PrimitiveWithInfer):
"""
Update the value of the input tensor through the max operation.
Using given values to update tensor value through the max operation, along with the input indices,.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: True.
Inputs:
- **input_x** (Tensor) - The target tensor.
- **indices** (Tensor) - The index to do max operation whose data type should be int.
- **updates** (Tensor) - The tensor doing the maximum operation with 'input_x',
the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'.
Outputs:
Tensor, has the same shape and data type as `input_x`.
Examples:
>>> input_x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
>>> scatter_max = P.ScatterMax()
>>> output = scatter_max(input_x, indices, update)
[[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]
"""
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterMax"""
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
def infer_shape(self, x_shape, indices_shape, updates_shape):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{self.name}', the shape of update should be [] or "
f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, update_shape: {updates_shape}.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
class SpaceToDepth(PrimitiveWithInfer):
r"""
Rearrange blocks of spatial data into depth.

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save