Remove StridedSlice AICPU version

pull/6903/head
peixu_ren 5 years ago
parent f4d5068fed
commit 6c0cfea75b

@ -15,7 +15,7 @@
"""grad impl.""" """grad impl."""
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \ from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse
from .grad_base import get_bprop_fn from .grad_base import get_bprop_fn
__all__ = ['get_bprop_fn'] __all__ = ['get_bprop_fn']

@ -1,39 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""array_ops"""
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprop_getters
@bprop_getters.register(inner.StridedSliceAICPU)
def get_bprop_strided_slice_aicpu(self):
"""Generate bprop for StridedSlice"""
shape_op = P.Shape()
input_grad = G.StridedSliceGradAICPU(self.begin_mask,
self.end_mask,
self.ellipsis_mask,
self.new_axis_mask,
self.shrink_axis_mask)
def bprop(x, begin, end, strides, out, dout):
dx = input_grad(dout, shape_op(x), begin, end, strides)
return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
return bprop

@ -1418,54 +1418,6 @@ class StridedSliceGrad(PrimitiveWithInfer):
'value': None} 'value': None}
class StridedSliceGradAICPU(PrimitiveWithInfer):
"""
Performs grad of StridedSlice operation.
Args:
begin_mask (int): Start indexing the slice. Default: 0.
end_mask (int): End indexing the slice. Default: 0.
ellipsis_mask (int): An int32 mask. Default: 0.
new_axis_mask (int): An int32 mask. Default: 0.
shrink_axis_mask (int): An int32 mask. Default: 0.
Returns:
Tensor, has the same shape of input.
"""
@prim_attr_register
def __init__(self,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0):
"""Initialize StrideSliceGrad"""
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
validator.check_value_type('end_mask', end_mask, [int], self.name)
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
def __infer__(self, dy, shapex, begin, end, strides):
args = {"dy": dy['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
for idx, item in enumerate(shapex['value']):
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
for idx, item in enumerate(begin['value']):
validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
for idx, item in enumerate(end['value']):
validator.check_value_type("end[%d]" % idx, item, [int], self.name)
for idx, item in enumerate(strides['value']):
validator.check_value_type("strides[%d]" % idx, item, [int], self.name)
return {'shape': shapex['value'],
'dtype': dy['dtype'],
'value': None}
class SoftplusGrad(PrimitiveWithInfer): class SoftplusGrad(PrimitiveWithInfer):
"""Computes gradient for the Log Softmax activation.""" """Computes gradient for the Log Softmax activation."""

@ -23,137 +23,6 @@ from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce from ..operations.math_ops import _infer_shape_reduce
class StridedSliceAICPU(PrimitiveWithInfer):
r"""
Extracts a strided slice of a tensor.
Given an input tensor, this operation inserts a dimension of length 1 at the dimension.
This operation extracts a fragment of size (end-begin)/stride from the given
'input_tensor'. Starting from the position specified by the begin, the fragment
continues adding stride to the index until all dimensions are not less than end.
Note:
The stride may be negative value, which causes reverse slicing.
The shape of `begin`, `end` and `strides` must be the same.
Args:
begin_mask (int): Starting index of the slice. Default: 0.
end_mask (int): Ending index of the slice. Default: 0.
ellipsis_mask (int): An int mask. Default: 0.
new_axis_mask (int): An int mask. Default: 0.
shrink_axis_mask (int): An int mask. Default: 0.
Currently all the masks are not in used. Use default 0 only.
Inputs:
- **input_x** (Tensor) - The input Tensor.
- **begin** (tuple[int]) - A tuple which represents the location where to start. Only
constant value is allowed.
- **end** (tuple[int]) - A tuple or which represents the maximum location where to stop.
Only constant value is allowed.
- **strides** (tuple[int]) - A tuple which represents the stride continuously added
before reach the maximum location. Only constant value is allowed.
Outputs:
Tensor.
Explain with the following example.
- In the 0th dim, begin is 1, end is 2, and strides is 1,
because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
Thus, return the element with :math:`index = 1` in 0th dim, i.e., [[3, 3, 3], [4, 4, 4]].
- In the 1st dim, similarly, the interval is :math:`[0,1)`.
Based on the return value of the 0th dim, return the element with :math:`index = 0`,
i.e., [3, 3, 3].
- In the 2nd dim, similarly, the interval is :math:`[0,3)`.
Based on the return value of the 1st dim, return the element with :math:`index = 0,1,2`,
i.e., [3, 3, 3].
- Finally, the output is [3, 3, 3].
Examples
>>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
>>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
>>> slice = P.StridedSliceAICPU()
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 2))
>>> output.shape
(1, 1, 2)
>>> output
[[[3, 3]]]
"""
@prim_attr_register
def __init__(self,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0):
"""Initialize StrideSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
validator.check_value_type('end_mask', end_mask, [int], self.name)
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
def __infer__(self, x, begin, end, strides):
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
validator.check_value_type("begin", begin_v, [tuple], self.name)
validator.check_value_type("end", end_v, [tuple], self.name)
validator.check_value_type("strides", strides_v, [tuple], self.name)
x_shape = x['shape']
x_shp_len = len(x_shape)
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and "
f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.")
ret_shape = []
append_dimensions = []
shrink_pos = bin(self.shrink_axis_mask)[::-1]
new_pos = bin(self.new_axis_mask)[::-1]
for i in range(x_shp_len):
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'
if i < (len(new_pos) - 2) and new_pos[i] == '1':
ret_shape.append(1)
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
continue
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
continue
begin_idx = begin_v[i]
end_idx = end_v[i]
strides_idx = strides_v[i]
if self.begin_mask:
begin_idx = 0
if self.end_mask:
end_idx = x_shape[i]
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
if strides_idx > 0:
# If sliced forward , end_idx >= begin_idx
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
if begin_idx < 0 < end_idx:
# Turn negative begin_idx into positive values
begin_idx = x_shape[i] + begin_idx
num_elems = (end_idx - begin_idx + strides_idx - 1) // strides_idx
else:
# If sliced backwards, end_idx <= begin_idx
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.GE)
if end_idx < 0 < begin_idx:
# Turn negative end_idx into positive values
end_idx = x_shape[i] + end_idx
num_elems = (end_idx - begin_idx + strides_idx + 1) // strides_idx
ret_shape.append(num_elems)
if append_dimensions:
ret_shape += append_dimensions[::-1]
return {'shape': ret_shape,
'dtype': x['dtype'],
'value': None}
class ExtractImagePatches(PrimitiveWithInfer): class ExtractImagePatches(PrimitiveWithInfer):
""" """
Extracts patches from images. Extracts patches from images.

@ -17,7 +17,7 @@ import numpy as np
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -25,7 +25,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, begin, end, strides): def __init__(self, begin, end, strides):
super(Net, self).__init__() super(Net, self).__init__()
self.strided_slice = inner.StridedSliceAICPU() self.strided_slice = P.StridedSlice()
self.begin = begin self.begin = begin
self.end = end self.end = end
self.strides = strides self.strides = strides

@ -25,7 +25,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, shape_x, begin, end, strides): def __init__(self, shape_x, begin, end, strides):
super(Net, self).__init__() super(Net, self).__init__()
self.strided_slice_grad = G.StridedSliceGradAICPU() self.strided_slice_grad = G.StridedSliceGrad()
self.shape_x = shape_x self.shape_x = shape_x
self.begin = begin self.begin = begin
self.end = end self.end = end

Loading…
Cancel
Save