change Split op to extend PrimitiveWithCheck

pull/8669/head
tom__chen 4 years ago
parent 77ba75ba75
commit 4fa6238dde

@ -132,10 +132,6 @@ class SplitGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_];
return false;
}
if (input_shape[axis_] % output_num_ != 0) {
MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_];
return false;
}
if (output_num_ != output_num) {
MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
return false;

@ -259,6 +259,8 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

@ -695,5 +695,48 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
return ret;
}
AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
ShapeVector x_shape = input_x->shape()->shape();
ShapeVector x_shape_min = input_x->shape()->min_shape();
if (x_shape_min.empty()) {
x_shape_min = x_shape;
}
ShapeVector x_shape_max = input_x->shape()->max_shape();
if (x_shape_max.empty()) {
x_shape_max = x_shape;
}
int64_t rank = SizeToLong(x_shape.size());
ValuePtr axis = primitive->GetAttr("axis");
int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
axis_value = GetPositiveAxis(axis_value, LongToSize(rank));
int64_t output_num_value = primitive->GetAttr("output_num")->cast<Int64ImmPtr>()->value();
if ((x_shape[axis_value] != Shape::SHP_ANY) && (x_shape[axis_value] % output_num_value != 0)) {
MS_LOG(EXCEPTION) << "x_shape[" << axis_value << "] = " << x_shape[axis_value]
<< " must be divisible by output_num = " << output_num_value;
}
ShapeVector output_shape = x_shape;
if (output_shape[axis_value] != Shape::SHP_ANY) {
output_shape[axis_value] = static_cast<int>(x_shape[axis_value] / output_num_value);
}
ShapeVector output_shape_min = x_shape_min;
output_shape_min[axis_value] = static_cast<int>(x_shape_min[axis_value] / output_num_value);
ShapeVector output_shape_max = x_shape_max;
output_shape_max[axis_value] = static_cast<int>(x_shape_max[axis_value] / output_num_value);
AbstractBasePtrList output_list;
for (int64_t i = 0; i < output_num_value; ++i) {
auto output = input_x->Broaden();
output->set_shape(std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max));
output_list.push_back(output);
}
return std::make_shared<AbstractTuple>(output_list);
}
} // namespace abstract
} // namespace mindspore

@ -69,6 +69,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
{prim::kPrimSplit, {InferImplSplit, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},

@ -118,6 +118,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("Dynam
inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad");
inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

@ -873,18 +873,17 @@ class UniqueWithPad(PrimitiveWithInfer):
return out
class Split(PrimitiveWithInfer):
class Split(PrimitiveWithCheck):
"""
Splits the input tensor into output_num of tensors along the given axis and output numbers.
Args:
axis (int): Index of the split position. Default: 0.
output_num (int): The number of output tensors. Default: 1.
output_num (int): The number of output tensors. Must be postive int. Default: 1.
Raises:
ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
or if the `output_num` is less than or equal to 0, or if the
dimension which to split cannot be evenly divided by `output_num`.
or if the `output_num` is less than or equal to 0.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@ -913,32 +912,15 @@ class Split(PrimitiveWithInfer):
"""Initialize Split"""
validator.check_value_type("axis", axis, [int], self.name)
validator.check_value_type("output_num", output_num, [int], self.name)
validator.check_positive_int(output_num, "output_num", self.name)
self.axis = axis
self.output_num = output_num
def __infer__(self, x):
def __check__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
validator.check_positive_int(self.output_num, "output_num", self.name)
output_valid_check = x_shape[self.axis] % self.output_num
if output_valid_check != 0:
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
f" output_num {self.output_num}")
x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
out_shapes = []
out_dtypes = []
for _ in range(self.output_num):
out_shapes.append(tuple(x_shape))
out_dtypes.append(x['dtype'])
out_shapes = tuple(out_shapes)
out_dtypes = tuple(out_dtypes)
out = {'shape': out_shapes,
'dtype': out_dtypes,
'value': None}
return out
class Rank(PrimitiveWithInfer):

@ -18,6 +18,7 @@ import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import operations as P
@ -30,6 +31,18 @@ class Net(nn.Cell):
return self.split(x)
class NetDynamic(nn.Cell):
def __init__(self, axis=0, out_nums=1):
super(NetDynamic, self).__init__()
self.conv = inner.GpuConvertToDynamicShape()
self.split = P.Split(axis, out_nums)
def construct(self, x):
x_conv = self.conv(x)
x_split = self.split(x_conv)
return x_split
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -47,6 +60,9 @@ def test_split():
assert (out.asnumpy() == x[i]).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_4d():
x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
y = np.split(x_np, 3, axis=1)
@ -56,3 +72,69 @@ def test_split_4d():
for i, out in enumerate(outputs):
assert (out.asnumpy() == y[i]).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_dynamic():
x = np.array([[[1, -1, 1], [2, -2, 2]],
[[3, -3, 3], [4, -4, 4]],
[[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
net = NetDynamic(0, 3)
x_split = net(Tensor(x))
for i, out in enumerate(x_split):
assert (out.asnumpy() == x[i]).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_dynamic_axis1():
x = np.array([[[1, -1, 1], [2, -2, 2]],
[[3, -3, 3], [4, -4, 4]],
[[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
y = np.split(x, 2, axis=1)
net = NetDynamic(1, 2)
x_split = net(Tensor(x))
for i, out in enumerate(x_split):
assert (out.asnumpy() == y[i]).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_dynamic_axis2():
x = np.array([[[1, -1, 1], [2, -2, 2]],
[[3, -3, 3], [4, -4, 4]],
[[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
y = np.split(x, 3, axis=2)
net = NetDynamic(2, 3)
x_split = net(Tensor(x))
for i, out in enumerate(x_split):
assert (out.asnumpy() == y[i]).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_invalid_input():
with pytest.raises(TypeError):
_ = Net(0.1, 3)
with pytest.raises(TypeError):
_ = Net(0, 3.0)
with pytest.raises(ValueError):
_ = Net(0, -3)
x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
split_net = Net(2, 2)
with pytest.raises(ValueError):
_ = split_net(Tensor(x))
with pytest.raises(TypeError):
_ = split_net(x)

Loading…
Cancel
Save