!6347 support not in and add check for grad_with_sens with no sense provided
	
		
	
				
					
				
			Merge pull request !6347 from zhangbuxue/support_not_in_and_add_check_for_grad_with_sens_with_no_sense_providedpull/6347/MERGE
						commit
						fd7bcd045a
					
				| @ -0,0 +1,101 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| 
 | ||||
| """Implementation for internal polymorphism `not in` operations.""" | ||||
| 
 | ||||
| from . import _constexpr_utils as const_utils | ||||
| from ... import functional as F | ||||
| from ...composite import base | ||||
| 
 | ||||
| not_in_ = base.MultitypeFuncGraph("not_in", True) | ||||
| """ | ||||
| "not_in_" is a multi type func graph object which will determine if a not in b. | ||||
| using ".register" decorator | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| @not_in_.register("Number", "Tuple") | ||||
| def _number_not_in_tuple(x, y): | ||||
|     """ | ||||
|     Determine if a number not in tuple. | ||||
| 
 | ||||
|     Args: | ||||
|        x (Number): x | ||||
|        y (tuple): y | ||||
| 
 | ||||
|     Returns: | ||||
|        bool, if x not in y return true, x in y return false. | ||||
|    """ | ||||
|     return not const_utils.scalar_in_sequence(x, y) | ||||
| 
 | ||||
| 
 | ||||
| @not_in_.register("Number", "List") | ||||
| def _number_not_in_list(x, y): | ||||
|     """ | ||||
|     Determine if a number not in list. | ||||
| 
 | ||||
|     Args: | ||||
|        x (Number): x | ||||
|        y (list): y | ||||
| 
 | ||||
|     Returns: | ||||
|        bool, if x not in y return true, x in y return false. | ||||
|    """ | ||||
|     return not const_utils.scalar_in_sequence(x, y) | ||||
| 
 | ||||
| 
 | ||||
| @not_in_.register("String", "Tuple") | ||||
| def _string_not_in_tuple(x, y): | ||||
|     """ | ||||
|     Determine if a str not in a tuple. | ||||
| 
 | ||||
|     Args: | ||||
|        x (str): x | ||||
|        y (tuple): y | ||||
| 
 | ||||
|     Returns: | ||||
|        bool, if x not in y return true, x in y return false. | ||||
|    """ | ||||
|     return not const_utils.scalar_in_sequence(x, y) | ||||
| 
 | ||||
| 
 | ||||
| @not_in_.register("String", "List") | ||||
| def _string_not_in_list(x, y): | ||||
|     """ | ||||
|     Determine if a str not in a list. | ||||
| 
 | ||||
|     Args: | ||||
|        x (str): x | ||||
|        y (list): y | ||||
| 
 | ||||
|     Returns: | ||||
|        bool, if x not in y return true, x in y return false. | ||||
|    """ | ||||
|     return not const_utils.scalar_in_sequence(x, y) | ||||
| 
 | ||||
| 
 | ||||
| @not_in_.register("String", "Dictionary") | ||||
| def _str_not_in_dict(x, y): | ||||
|     """ | ||||
|     Determine if a str not in dict. | ||||
| 
 | ||||
|     Args: | ||||
|        x: str | ||||
|        y: dict | ||||
| 
 | ||||
|     Returns: | ||||
|        bool, if x not in y return true, x in y return false. | ||||
|    """ | ||||
|     return F.not_in_dict(x, y) | ||||
| @ -0,0 +1,56 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """ test not in""" | ||||
| import numpy as np | ||||
| 
 | ||||
| import mindspore.nn as nn | ||||
| from mindspore import context, Tensor | ||||
| 
 | ||||
| context.set_context(mode=context.GRAPH_MODE) | ||||
| 
 | ||||
| 
 | ||||
| def test_number_not_in_tuple(): | ||||
|     class Net(nn.Cell): | ||||
|         def __init__(self): | ||||
|             super(Net, self).__init__() | ||||
|             self.tuple_ = (2, 3, 4) | ||||
|             self.list_ = [2, 3, 4] | ||||
|             self.dict_ = {"a": Tensor(np.ones([1, 2, 3], np.int32)), | ||||
|                           "b": Tensor(np.ones([1, 2, 3], np.int32)), | ||||
|                           "c": Tensor(np.ones([1, 2, 3], np.int32))} | ||||
|             self.number_in = 3 | ||||
|             self.number_not_in = 5 | ||||
|             self.str_in = "a" | ||||
|             self.str_not_in = "e" | ||||
| 
 | ||||
|         def construct(self): | ||||
|             ret = 0 | ||||
|             if self.number_in not in self.tuple_: | ||||
|                 ret += 1 | ||||
|             if self.number_not_in not in self.tuple_: | ||||
|                 ret += 1 | ||||
|             if self.number_in not in self.list_: | ||||
|                 ret += 3 | ||||
|             if self.number_not_in not in self.list_: | ||||
|                 ret += 3 | ||||
|             if self.str_in not in self.dict_: | ||||
|                 ret += 5 | ||||
|             if self.str_not_in not in self.dict_: | ||||
|                 ret += 5 | ||||
|             return ret | ||||
| 
 | ||||
|     net = Net() | ||||
|     output = net() | ||||
|     assert output == 9 | ||||
					Loading…
					
					
				
		Reference in new issue