|
|
|
@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore import context
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
import mindspore.nn.probability as msp
|
|
|
|
|
|
|
|
|
@ -273,7 +274,8 @@ def check_type(data_type, value_type, name):
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def raise_none_error(name):
|
|
|
|
|
raise ValueError(f"{name} should be specified. Value cannot be None")
|
|
|
|
|
raise TypeError(f"the type {name} should be subclass of Tensor."
|
|
|
|
|
f" It should not be None since it is not specified during initialization.")
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def raise_not_impl_error(name):
|
|
|
|
@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, name):
|
|
|
|
|
if not isinstance(x['dtype'], tuple):
|
|
|
|
|
raise TypeError("Input type should be a tuple: " + name["value"])
|
|
|
|
|
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
|
|
|
|
|
|
|
|
|
|
out = {'shape': None,
|
|
|
|
|
'dtype': None,
|
|
|
|
|
'value': None}
|
|
|
|
|
'value': x["value"]}
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args):
|
|
|
|
|
return
|
|
|
|
|
def __call__(self, x, name):
|
|
|
|
|
if context.get_context("mode") == 0:
|
|
|
|
|
return x["value"]
|
|
|
|
|
#Pynative mode
|
|
|
|
|
if isinstance(x, tuple):
|
|
|
|
|
return x
|
|
|
|
|
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
|
|
|
|
|
|
|
|
|
|
class CheckTensor(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer):
|
|
|
|
|
'value': None}
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args):
|
|
|
|
|
def __call__(self, x, name):
|
|
|
|
|
return
|
|
|
|
|