parent
327d1eb5fb
commit
a47f5493d5
@ -0,0 +1,37 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""array_ops"""
|
||||
|
||||
from ..operations import _grad_ops as G
|
||||
from ..operations import _inner_ops as inner
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from .grad_base import bprop_getters
|
||||
|
||||
|
||||
@bprop_getters.register(inner.StridedSliceAICPU)
|
||||
def get_bprop_strided_slice_aicpu(self):
|
||||
"""Generate bprop for StridedSlice"""
|
||||
input_grad = G.StridedSliceGradAICPU(self.begin_mask,
|
||||
self.end_mask,
|
||||
self.ellipsis_mask,
|
||||
self.new_axis_mask,
|
||||
self.shrink_axis_mask)
|
||||
|
||||
def bprop(x, begin, end, strides, out, dout):
|
||||
dx = input_grad(dout, shape_op(x), begin, end, strides)
|
||||
return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
|
||||
|
||||
return bprop
|
@ -0,0 +1,41 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""StridedSlice op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "input", "required") \
|
||||
.input(1, "begin", "required") \
|
||||
.input(2, "end", "required") \
|
||||
.input(3, "stride", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.attr("begin_mask", "int") \
|
||||
.attr("end_mask", "int") \
|
||||
.attr("ellipsis_mask", "int") \
|
||||
.attr("new_axis_mask", "int") \
|
||||
.attr("shrink_axis_mask", "int") \
|
||||
.dtype_format(DataType.F32_NCHW,
|
||||
DataType.I32_NCHW,
|
||||
DataType.I32_NCHW,
|
||||
DataType.I32_NCHW,
|
||||
DataType.F32_NCHW) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(strided_slice_op_info)
|
||||
def _strided_slice_aicpu():
|
||||
"""StridedSlice AiCPU register"""
|
||||
return
|
@ -0,0 +1,43 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""StridedSliceGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "dy", "required") \
|
||||
.input(1, "shape", "required") \
|
||||
.input(2, "begin", "required") \
|
||||
.input(3, "end", "required") \
|
||||
.input(4, "stride", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.attr("begin_mask", "int") \
|
||||
.attr("end_mask", "int") \
|
||||
.attr("ellipsis_mask", "int") \
|
||||
.attr("new_axis_mask", "int") \
|
||||
.attr("shrink_axis_mask", "int") \
|
||||
.dtype_format(DataType.F32_NCHW,
|
||||
DataType.I32_NCHW,
|
||||
DataType.I32_NCHW,
|
||||
DataType.I32_NCHW,
|
||||
DataType.I32_NCHW,
|
||||
DataType.F32_NCHW) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(strided_slice_grad_op_info)
|
||||
def _strided_slice_grad_aicpu():
|
||||
"""StridedSliceGrad AiCPU register"""
|
||||
return
|
@ -0,0 +1,51 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, begin, end, strides):
|
||||
super(Net, self).__init__()
|
||||
self.strided_slice = inner.StridedSliceAICPU()
|
||||
self.begin = begin
|
||||
self.end = end
|
||||
self.strides = strides
|
||||
|
||||
def construct(self, input):
|
||||
return self.strided_slice(input, self.begin, self.end, self.strides)
|
||||
|
||||
|
||||
input_x = np.array([[[0, 1, 2], [3, 4, 5]],
|
||||
[[6, 7, 8], [9, 10, 11]],
|
||||
[[12, 13, 14], [15, 16, 17]]
|
||||
]).astype(np.float32)
|
||||
begin = (1, 0, 0)
|
||||
end = (2, 2, 3)
|
||||
strides = (1, 1, 2)
|
||||
|
||||
|
||||
def test_net():
|
||||
net = Net(begin, end, strides)
|
||||
tinput = Tensor(input_x)
|
||||
output = net(tinput)
|
||||
print(output.asnumpy())
|
||||
assert np.all([[[6, 8], [9, 11]]] == output.asnumpy())
|
@ -0,0 +1,53 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, shape_x, begin, end, strides):
|
||||
super(Net, self).__init__()
|
||||
self.strided_slice_grad = G.StridedSliceGradAICPU()
|
||||
self.shape_x = shape_x
|
||||
self.begin = begin
|
||||
self.end = end
|
||||
self.strides = strides
|
||||
|
||||
def construct(self, dy):
|
||||
return self.strided_slice_grad(dy, self.shape_x, self.begin, self.end, self.strides)
|
||||
|
||||
|
||||
dy = np.array([[[6, 8], [9, 11]]]).astype(np.float32)
|
||||
shape_x = (3, 2, 3)
|
||||
begin = (1, 0, 0)
|
||||
end = (2, 2, 3)
|
||||
strides = (1, 1, 2)
|
||||
|
||||
|
||||
def test_net():
|
||||
net = Net(shape_x, begin, end, strides)
|
||||
tdy = Tensor(dy)
|
||||
output = net(tdy)
|
||||
print(output.asnumpy())
|
||||
assert np.all([[[0, 0, 0], [0, 0, 0]],
|
||||
[[6, 0, 8], [9, 0, 11]],
|
||||
[[0, 0, 0], [0, 0, 0]]
|
||||
] == output.asnumpy())
|
Loading…
Reference in new issue