From b9c9046b93589df9580c9647b5ea56033c5bbe06 Mon Sep 17 00:00:00 2001 From: buxue Date: Fri, 11 Sep 2020 17:13:36 +0800 Subject: [PATCH] support function as condition of if --- mindspore/_extends/parse/standard_method.py | 5 +++ mindspore/ccsrc/pipeline/jit/resource.cc | 4 ++ .../python/pipeline/parse/test_if_function.py | 39 +++++++++++++++++++ tests/ut/python/runtest.sh | 14 +++++-- 4 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 tests/ut/python/pipeline/parse/test_if_function.py diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index c99024dcce..6530ec6141 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -275,6 +275,11 @@ def none_bool(x): return False +def func_bool(x): + """Implementation of `func_bool`.""" + return True + + def float_floordiv(x, y): """Implementation of `float_floordiv`.""" return floor(x / y) diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index aff91ae22a..ede5c59bae 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -37,6 +37,10 @@ BuiltInTypeMap &GetMethodMap() { { {"__bool__", std::string("none_bool")} // C.none_bool }}, + {kObjectTypeFunction, + { + {"__bool__", std::string("func_bool")} // C.str_bool + }}, {kNumberTypeBool, { {"__and__", prim::kPrimBoolAnd}, // P.bool_and diff --git a/tests/ut/python/pipeline/parse/test_if_function.py b/tests/ut/python/pipeline/parse/test_if_function.py new file mode 100644 index 0000000000..3ba3eb64ed --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_if_function.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test if function""" +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + +def test_if_function(): + class Net(nn.Cell): + def __init__(self, func): + super(Net, self).__init__() + self.func = func + + def construct(self, x, y): + if self.func: + return self.func(x, y) + return x - y + def add(x, y): + return x + y + net = Net(add) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + net(x, y) diff --git a/tests/ut/python/runtest.sh b/tests/ut/python/runtest.sh index 10a6fafbe1..763d9d2a5e 100755 --- a/tests/ut/python/runtest.sh +++ b/tests/ut/python/runtest.sh @@ -36,28 +36,33 @@ if [ $# -eq 1 ] && ([ "$1" == "stage1" ] || [ "$1" == "stage2" ] || [ "$1" == elif [ $1 == "stage2" ]; then echo "run python parallel\train\ops ut" - pytest -n 4 --dist=loadfile -v $CURRPATH/parallel $CURRPATH/train $CURRPATH/ops + pytest -n 4 --dist=loadfile -v $CURRPATH/parallel $CURRPATH/train + RET=$? + if [ ${RET} -ne 0 ]; then + exit ${RET} + fi + + pytest -n 2 --dist=loadfile -v $CURRPATH/ops elif [ $1 == "stage3" ]; then echo "run other ut" pytest --ignore=$CURRPATH/dataset --ignore=$CURRPATH/parallel --ignore=$CURRPATH/train --ignore=$CURRPATH/ops --ignore=$CURRPATH/pynative_mode $IGNORE_EXEC $CURRPATH - RET=$? if [ ${RET} -ne 0 ]; then exit ${RET} fi + pytest $CURRPATH/pynative_mode fi else echo "run all python ut" pytest $CURRPATH/dataset - RET=$? if [ ${RET} -ne 0 ]; then exit ${RET} fi - pytest -n 4 --dist=loadfile -v $CURRPATH/parallel $CURRPATH/train $CURRPATH/ops + pytest -n 4 --dist=loadfile -v $CURRPATH/parallel $CURRPATH/train $CURRPATH/ops RET=$? if [ ${RET} -ne 0 ]; then exit ${RET} @@ -68,6 +73,7 @@ else if [ ${RET} -ne 0 ]; then exit ${RET} fi + pytest $CURRPATH/pynative_mode fi