Adapt nn.Unfold and inner.ExtractImagePatches.

check input dims for nn.LSTM.
pull/10093/head
liuxiao93 4 years ago
parent 0db846978e
commit 46b8ab3c40

@ -72,7 +72,6 @@
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
#include "backend/optimizer/ascend/format_type/convert_cast_format.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/pass/optimize_dependence.h"
@ -240,7 +239,6 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
auto optimizer = std::make_shared<GraphOptimizer>();
auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
mixed_precision_pm->AddPass(std::make_shared<InsertCast>());
mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());

@ -1,65 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
#include <memory>
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimExtractImagePatches, Xs});
}
const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2);
MS_EXCEPTION_IF_NULL(extract);
auto in_node = extract->input(1);
MS_EXCEPTION_IF_NULL(in_node);
auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract);
auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node);
MS_EXCEPTION_IF_NULL(extract_kernel_build_info);
MS_EXCEPTION_IF_NULL(in_node_kernel_build_info);
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
in_node};
auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0});
reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0});
reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type());
reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type());
reshape_builder->SetProcessor(in_node_kernel_build_info->processor());
auto reshape = func_graph->NewCNode(reshape_inputs);
reshape->set_scope(in_node->scope());
auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)},
{{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get());
AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get());
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape);
AnfAlgo::SetNodeInput(extract, reshape, 0);
return extract;
}
} // namespace opt
} // namespace mindspore

@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass {
public:
explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true)
: PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {}
~InsertReshapeForExtractImagePatchesOp() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H

@ -563,10 +563,6 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
}
if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) {
auto shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]};
return trans::TransShapeToDevice(shape_tmp, format);
}
return trans::TransShapeToDevice(infer_shape, format);
}

@ -720,19 +720,27 @@ class Unfold(Cell):
def __init__(self, ksizes, strides, rates, padding="valid"):
super(Unfold, self).__init__()
def _check_tuple_or_list(arg_name, arg_val, prim_name):
Validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.cls_name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
f"{arg_name}_col, 1], but got {arg_val}.")
if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
f"is {arg_val[2]}")
_check_tuple_or_list("ksize", ksizes, self.cls_name)
_check_tuple_or_list("stride", strides, self.cls_name)
_check_tuple_or_list("rate", rates, self.cls_name)
ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
strides = strides[0], strides[3], strides[1], strides[2]
rates = rates[0], rates[3], rates[1], rates[2]
self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
self.transpose = P.Transpose()
self.format_NHWC = (0, 2, 3, 1)
self.format_NCHW = (0, 3, 1, 2)
self.is_ge = context.get_context("enable_ge")
def construct(self, input_x):
if self.is_ge:
x_transpose = self.transpose(input_x, self.format_NHWC)
ret = self.extract_image_patches(x_transpose)
result = self.transpose(ret, self.format_NCHW)
else:
result = self.extract_image_patches(input_x)
result = self.extract_image_patches(input_x)
return result

@ -41,6 +41,11 @@ def _create_sequence_length(shape):
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
@constexpr
def _check_input_3d(input_shape, param_name, func_name):
if len(input_shape) != 3:
raise ValueError(f"{func_name} {param_name} should be 3d, but got shape {input_shape}")
class LSTM(Cell):
r"""
Stacked LSTM (Long Short-Term Memory) layers.
@ -237,6 +242,8 @@ class LSTM(Cell):
x = self.transpose(x, (1, 0, 2))
h, c = hx
if self.is_ascend:
_check_input_3d(F.shape(h), "h of hx", self.cls_name)
_check_input_3d(F.shape(c), "c of hx", self.cls_name)
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)

@ -122,7 +122,7 @@ def get_bprop_extract_image_patches(self):
cast = P.Cast()
matmul = P.MatMul()
_, ksizes_row, ksizes_col, _ = self.ksizes
_, _, ksizes_row, ksizes_col = self.ksizes
def bprop(x, out, dout):
x_shape = get_shape(x)
@ -155,39 +155,6 @@ def get_bprop_extract_image_patches(self):
dx = transpose(dx, (2, 3, 0, 1))
return (dx,)
def bprop_ge(x, out, dout):
x_shape = get_shape(x)
x_batch, x_row, x_col, x_depth = x_shape
x_indices_num = x_row * x_col + 1
x_idx = F.tuple_to_array(range(1, x_indices_num))
x_idx = reshape(x_idx, (1, x_row, x_col, 1))
x_idx_patch = extract_image_patches(x_idx)
out_shape = get_shape(out)
_, out_row, out_col, _ = out_shape
out_indices_num = out_row * out_col * ksizes_row * ksizes_col
out_idx = F.tuple_to_array(range(out_indices_num))
out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
idx_tensor = reshape(idx_tensor, (-1, 2))
sp_shape = (x_indices_num, out_indices_num)
sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
grad = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
grad = transpose(grad, (1, 2, 3, 4, 0, 5))
grad = reshape(grad, (-1, x_batch * x_depth))
jac = matmul(sp_tensor, grad)
dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
dx = transpose(dx, (2, 0, 1, 3))
return (dx,)
if context.get_context("enable_ge"):
return bprop_ge
return bprop

@ -31,11 +31,11 @@ class ExtractImagePatches(PrimitiveWithInfer):
Args:
ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers,
and the format is [1, ksize_row, ksize_col, 1].
and the format is [1, 1, ksize_row, ksize_col].
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
must be a tuple or list of int, and the format is [1, 1, stride_row, stride_col].
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
pixel positions, must be a tuple or a list of integers, and the format is [1, 1, rate_row, rate_col].
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
not case sensitive. Default: "valid".
@ -58,30 +58,28 @@ class ExtractImagePatches(PrimitiveWithInfer):
def _check_tuple_or_list(arg_name, arg_val, prim_name):
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[1] != 1:
raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
f"{arg_name}_col, 1], but got {arg_val}.")
if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
if not isinstance(arg_val[2], int) or not isinstance(arg_val[3], int) or arg_val[2] < 1 or arg_val[3] < 1:
raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
f"is {arg_val[2]}")
f"positive integer number, but got {arg_name}_row is {arg_val[2]}, {arg_name}_col "
f"is {arg_val[3]}")
_check_tuple_or_list("ksize", ksizes, self.name)
_check_tuple_or_list("stride", strides, self.name)
_check_tuple_or_list("rate", rates, self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)
self.add_prim_attr("io_format", "NHWC")
self.add_prim_attr("io_format", "NCHW")
self.is_ge = context.get_context("enable_ge")
def infer_shape(self, input_x):
"""infer shape"""
in_batch, in_depth, in_row, in_col = input_x
if self.is_ge:
in_batch, in_row, in_col, in_depth = input_x
_, ksize_row, ksize_col, _ = self.ksizes
_, stride_row, stride_col, _ = self.strides
_, rate_row, rate_col, _ = self.rates
_, _, ksize_row, ksize_col = self.ksizes
_, _, stride_row, stride_col = self.strides
_, _, rate_row, rate_col = self.rates
if len(input_x) != 4:
raise ValueError("The `input_x` should be a 4-D tensor, "
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
@ -99,8 +97,6 @@ class ExtractImagePatches(PrimitiveWithInfer):
out_col = (in_col - 1) // stride_col + 1
out_shape = [out_batch, out_depth, out_row, out_col]
if self.is_ge:
out_shape = [out_batch, out_row, out_col, out_depth]
return out_shape
def infer_dtype(self, input_x):

@ -6405,7 +6405,7 @@ class DynamicRNN(PrimitiveWithInfer):
>>> b = Tensor(np.random.rand(128).astype(np.float16))
>>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
>>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
>>> dynamic_rnn = ops.DynamicRNNN()
>>> dynamic_rnn = ops.DynamicRNN()
>>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
>>> print(output[0].shape)
(2, 16, 32)

Loading…
Cancel
Save