!732 dock AcoshGrad for GE and AvgPool AvgPoolGrad for VM
Merge pull request !732 from zhangbuxue/dock_AcoshGrad_AvgPool_AvgPoolGrad_for_vmpull/732/MERGE
commit
9ffb3993e8
@ -0,0 +1,39 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AvgPool op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
avg_pool_op_info = TBERegOp("AvgPool") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("avg_pool.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("avg_pool") \
|
||||
.partial_flag(True) \
|
||||
.attr("ksize", "required", "listInt", "all") \
|
||||
.attr("strides", "required", "listInt", "all") \
|
||||
.attr("padding", "required", "str", "all") \
|
||||
.attr("data_format", "optional", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(avg_pool_op_info)
|
||||
def _avg_pool_tbe():
|
||||
"""AvgPool TBE register"""
|
||||
return
|
@ -0,0 +1,42 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AvgPoolGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
avg_pool_grad_op_info = TBERegOp("AvgPoolGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("avg_pool_grad_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("avg_pool_grad_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("x_origin", "required", "listInt", "all") \
|
||||
.attr("ksize", "required", "listInt", "all") \
|
||||
.attr("strides", "required", "listInt", "all") \
|
||||
.attr("padding", "required", "str", "all") \
|
||||
.attr("data_format", "optional", "str", "all") \
|
||||
.input(0, "input_grad", False, "required", "all") \
|
||||
.input(1, "mean_matrix", False, "optional", "all") \
|
||||
.input(2, "kernel_matrix", False, "optional", "all") \
|
||||
.output(0, "out_grad", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(avg_pool_grad_op_info)
|
||||
def _avg_pool_grad_tbe():
|
||||
"""AvgPoolGrad TBE register"""
|
||||
return
|
@ -0,0 +1,98 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Inner operators."""
|
||||
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
|
||||
class ExtractImagePatches(PrimitiveWithInfer):
|
||||
"""
|
||||
Extract patches from images.
|
||||
The input tensor must be a 4-D tensor and the data format is NHWC.
|
||||
|
||||
Args:
|
||||
ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
|
||||
and the format is [1, ksize_row, ksize_col, 1].
|
||||
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
|
||||
should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
|
||||
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
|
||||
pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
|
||||
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
|
||||
not case sensitive. Default: "valid".
|
||||
|
||||
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
|
||||
|
||||
- valid: Means that the patch area taken must be completely contained in the original image.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
|
||||
data type is int8, float16, uint8.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4-D tensor whose data type is same as 'input_x',
|
||||
and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, ksizes, strides, rates, padding="valid"):
|
||||
"""init"""
|
||||
def _check_tuple_or_list(arg_name, arg_val, prim_name):
|
||||
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
|
||||
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
|
||||
raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
|
||||
f"{arg_name}_col, 1], but got {arg_val}.")
|
||||
if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
|
||||
raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
|
||||
f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
|
||||
f"is {arg_val[2]}")
|
||||
|
||||
_check_tuple_or_list("ksize", ksizes, self.name)
|
||||
_check_tuple_or_list("stride", strides, self.name)
|
||||
_check_tuple_or_list("rate", rates, self.name)
|
||||
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
|
||||
self.add_prim_attr("padding", self.padding)
|
||||
|
||||
def infer_shape(self, input_x):
|
||||
"""infer shape"""
|
||||
in_batch, in_row, in_col, in_depth = input_x
|
||||
_, ksize_row, ksize_col, _ = self.ksizes
|
||||
_, stride_row, stride_col, _ = self.strides
|
||||
_, rate_row, rate_col, _ = self.rates
|
||||
if len(input_x) != 4:
|
||||
raise ValueError("The `input_x` should be a 4-D tensor, "
|
||||
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
|
||||
|
||||
out_batch = in_batch
|
||||
out_depth = ksize_row * ksize_col * in_depth
|
||||
|
||||
if self.padding == "VALID":
|
||||
out_row = \
|
||||
(in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
|
||||
out_col = \
|
||||
(in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
|
||||
else:
|
||||
out_row = (in_row - 1) // stride_row + 1
|
||||
out_col = (in_col - 1) // stride_col + 1
|
||||
|
||||
out_shape = [out_batch, out_row, out_col, out_depth]
|
||||
return out_shape
|
||||
|
||||
def infer_dtype(self, input_x):
|
||||
"""infer dtype"""
|
||||
validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
|
||||
return input_x
|
Loading…
Reference in new issue