!106 support comparison ops for python
Merge pull request !106 from amongo/SupportPythonOperatorspull/106/MERGE
commit
a5d95e472e
@ -0,0 +1,53 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""greater_equal_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
|
||||
# using ".register" decorator
|
||||
greater_equal = base.MultitypeFuncGraph("greater_equal")
|
||||
|
||||
|
||||
@greater_equal.register("Number", "Number")
|
||||
def _greater_equal_scala(x, y):
|
||||
"""
|
||||
Determine whether x is greater equal than y
|
||||
|
||||
Args:
|
||||
x(Number): Number.
|
||||
y(Number): Number.
|
||||
|
||||
Returns:
|
||||
bool, if x >= y return true, x < y return false.
|
||||
"""
|
||||
return F.scalar_ge(x, y)
|
||||
|
||||
@greater_equal.register("Tensor", "Number")
|
||||
@greater_equal.register("Number", "Tensor")
|
||||
@greater_equal.register("Tensor", "Tensor")
|
||||
def _greater_equal_tensor(x, y):
|
||||
"""
|
||||
Determine whether tensor x is greater equal than tensor y elementwise
|
||||
|
||||
Args:
|
||||
x(Tensor): Tensor.
|
||||
y(Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, return value by operator P.GreaterEqual.
|
||||
"""
|
||||
return F.tensor_ge(x, y)
|
@ -0,0 +1,53 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Ungreater required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""equal_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
|
||||
# using ".register" decorator
|
||||
greater = base.MultitypeFuncGraph("greater")
|
||||
|
||||
|
||||
@greater.register("Number", "Number")
|
||||
def _greater_scala(x, y):
|
||||
"""
|
||||
Determine whether two numbers are greater.
|
||||
|
||||
Args:
|
||||
x(Number): Number.
|
||||
y(Number): Number.
|
||||
|
||||
Returns:
|
||||
bool, if x > y return true, x <= y return false.
|
||||
"""
|
||||
return F.scalar_gt(x, y)
|
||||
|
||||
@greater.register("Tensor", "Number")
|
||||
@greater.register("Number", "Tensor")
|
||||
@greater.register("Tensor", "Tensor")
|
||||
def _greater_tensor(x, y):
|
||||
"""
|
||||
Determine whether two tensor are greater by element.
|
||||
|
||||
Args:
|
||||
x(Tensor): Tensor.
|
||||
y(Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
tensor, return operation of x and y by P.Greater
|
||||
"""
|
||||
return F.tensor_gt(x, y)
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""logical_not_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
# logical_not is a metagraph object which will generate function according to input type
|
||||
# using ".register" decorator
|
||||
logical_not = base.MultitypeFuncGraph("logical_not")
|
||||
|
||||
|
||||
@logical_not.register("Number")
|
||||
def _logical_not_scala(x):
|
||||
"""
|
||||
Return logical not operation result of x
|
||||
|
||||
Args:
|
||||
x(Number): Number.
|
||||
|
||||
Returns:
|
||||
bool, Return logical not operation result of x
|
||||
"""
|
||||
return F.bool_not(x.__bool__())
|
||||
|
||||
|
||||
@logical_not.register("Tensor")
|
||||
def _logical_not_tensor(x):
|
||||
"""
|
||||
Return logical not operation result of x
|
||||
Args:
|
||||
x(Tensor): Tensor.
|
||||
Returns:
|
||||
Tensor, Return logical not operation result of x
|
||||
"""
|
||||
return F.logical_not(x)
|
@ -0,0 +1,237 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""not_equal_impl"""
|
||||
|
||||
from ...composite import base
|
||||
from ... import functional as F
|
||||
|
||||
|
||||
not_equal = base.MultitypeFuncGraph("not_equal")
|
||||
"""
|
||||
not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type
|
||||
using ".register" decorator
|
||||
"""
|
||||
|
||||
|
||||
@not_equal.register("Number", "Number")
|
||||
def _not_equal_scalar(x, y):
|
||||
"""
|
||||
Determine if two numbers is not equal.
|
||||
|
||||
Args:
|
||||
x (Number): x
|
||||
y (NUmber): y
|
||||
|
||||
Returns:
|
||||
bool, if x != y return true, x == y return false.
|
||||
"""
|
||||
return not F.scalar_eq(x, y)
|
||||
|
||||
|
||||
@not_equal.register("String", "String")
|
||||
def _not_equal_string(x, y):
|
||||
"""
|
||||
Determine if two strings are not equal.
|
||||
|
||||
Args:
|
||||
x: str
|
||||
y: str
|
||||
|
||||
Returns:
|
||||
bool, if x != y return true, x == y return false.
|
||||
"""
|
||||
return not F.string_eq(x, y)
|
||||
|
||||
|
||||
@not_equal.register("String", "None")
|
||||
def _string_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if string not equals none.
|
||||
|
||||
Args:
|
||||
x: str.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "String")
|
||||
def _none_not_equal_string(x, y):
|
||||
"""
|
||||
Determine if string not equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: str.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "None")
|
||||
def _none_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if none not equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return False.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@not_equal.register("Number", "None")
|
||||
def _scalar_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if number not equals none.
|
||||
|
||||
Args:
|
||||
x: Number.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "Number")
|
||||
def _none_not_equal_scalar(x, y):
|
||||
"""
|
||||
Determine if number not_equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: NUmber.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("Tuple", "Tuple")
|
||||
def _euqal_tuple(x, y):
|
||||
"""
|
||||
Determine if two tuples are not equal by element.
|
||||
|
||||
Args:
|
||||
x (tuple): x
|
||||
y (tuple): y
|
||||
|
||||
Returns:
|
||||
bool, if x and y are not equal by element return true, else return false.
|
||||
"""
|
||||
return not F.tuple_equal(x, y)
|
||||
|
||||
|
||||
@not_equal.register("List", "List")
|
||||
def _euqal_list(x, y):
|
||||
"""
|
||||
Determine if two lists are not equal by element.
|
||||
|
||||
Args:
|
||||
x (list): x
|
||||
y (list): y
|
||||
|
||||
Returns:
|
||||
bool, if x and y are not equal by element return true, else return false.
|
||||
"""
|
||||
return not F.list_equal(x, y)
|
||||
|
||||
|
||||
@not_equal.register("Tuple", "None")
|
||||
def _tuple_euqal_none(x, y):
|
||||
"""
|
||||
Determine if tuple element not equals none element.
|
||||
|
||||
Args:
|
||||
x: Tuple.
|
||||
y: None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "Tuple")
|
||||
def _none_not_equal_tuple(x, y):
|
||||
"""
|
||||
Determine if tuple element not equals none element.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: Tuple.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
@not_equal.register("Tensor", "Number")
|
||||
@not_equal.register("Number", "Tensor")
|
||||
@not_equal.register("Tensor", "Tensor")
|
||||
def _tensor_not_equal_tensor(x, y):
|
||||
"""
|
||||
Determine if two tensors are not_equal.
|
||||
|
||||
Args:
|
||||
x : Tensor.
|
||||
y : Tensor.
|
||||
|
||||
Returns:
|
||||
bool, if x == y return true, x != y return false.
|
||||
"""
|
||||
return F.not_equal(x, y)
|
||||
|
||||
|
||||
@not_equal.register("Tensor", "None")
|
||||
def _tensor_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if tensor not_equal none.
|
||||
|
||||
Args:
|
||||
x : Tensor.
|
||||
y : None.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "Tensor")
|
||||
def _none_not_equal_tensor(x, y):
|
||||
"""
|
||||
Determine if tensor not equal none.
|
||||
|
||||
Args:
|
||||
x : None.
|
||||
y : Tensor.
|
||||
|
||||
Returns:
|
||||
bool, return True.
|
||||
"""
|
||||
return True
|
@ -0,0 +1,26 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""uadd_impl"""
|
||||
from mindspore.ops.composite import base
|
||||
|
||||
# uadd is a metagraph object which will return operation result regarding input
|
||||
# using ".register" decorator
|
||||
uadd = base.MultitypeFuncGraph("uadd")
|
||||
|
||||
@uadd.register("Tensor")
|
||||
@uadd.register("Number")
|
||||
def _uadd_scala(x):
|
||||
return x
|
Loading…
Reference in new issue