@@ -14,15 +14,15 @@
 # ============================================================================
 """Utility functions to help distribution class."""
 import numpy as np
-from mindspore.ops import _utils as utils
-from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
+from mindspore import context
 from mindspore._checkparam import Validator as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
+from mindspore.ops import _utils as utils
 from mindspore.ops import composite as C
-from mindspore import context
+from mindspore.ops import operations as P
+from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
 import mindspore.nn as nn
 import mindspore.nn.probability as msp
@@ -82,6 +82,24 @@ def convert_to_batch(t, batch_shape, required_type):
     return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
+
+
+def cast_type_for_device(dtype):
+    """
+    Use the alternative dtype supported by the device.
+    Args:
+        dtype (mindspore.dtype): input dtype.
+    Returns:
+        mindspore.dtype.
+    """
+    if context.get_context("device_target") == "GPU":
+        if dtype in mstype.uint_type or dtype == mstype.int8:
+            return mstype.int16
+        if dtype == mstype.int64:
+            return mstype.int32
+        if dtype == mstype.float64:
+            return mstype.float32
+    return dtype
 
 
 def check_scalar_from_param(params):
     """
     Check if params are all scalars.
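Note: a minimal usage sketch of the helper added above, assuming a GPU device
target and that cast_type_for_device is importable from the module this patch
edits (the import path below is assumed, and the script is illustrative only,
not part of the patch):

    from mindspore import context
    from mindspore.common import dtype as mstype
    # Path assumed from the file this patch edits.
    from mindspore.nn.probability.distribution._utils.utils import cast_type_for_device

    context.set_context(device_target="GPU")
    # On GPU, dtypes without kernel support fall back to supported ones.
    assert cast_type_for_device(mstype.int8) == mstype.int16        # uint*/int8 -> int16
    assert cast_type_for_device(mstype.int64) == mstype.int32       # int64 -> int32
    assert cast_type_for_device(mstype.float64) == mstype.float32   # float64 -> float32
    assert cast_type_for_device(mstype.float32) == mstype.float32   # already supported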
@@ -293,10 +311,10 @@ def raise_not_impl_error(name):
 def check_distribution_name(name, expected_name):
     if name is None:
         raise ValueError(
-            f"Distribution should be a constant which is not None.")
+            f"Input dist should be a constant which is not None.")
     if name != expected_name:
         raise ValueError(
-            f"Expected distribution name is {expected_name}, but got {name}.")
+            f"Expected dist input is {expected_name}, but got {name}.")
 
 
 class CheckTuple(PrimitiveWithInfer):
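Note: the reworded messages above change only the error text, not the control
flow; a short behavioral sketch, assuming check_distribution_name is imported
from this same module:

    check_distribution_name("Normal", "Normal")     # passes: names match
    check_distribution_name(None, "Normal")
    # ValueError: Input dist should be a constant which is not None.
    check_distribution_name("Bernoulli", "Normal")
    # ValueError: Expected dist input is Normal, but got Bernoulli.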