|
|
|
@ -1,4 +1,3 @@
|
|
|
|
|
|
|
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
@ -23,7 +22,7 @@ from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
|
|
|
|
|
def cast_to_tensor(t, dtype=mstype.float32):
|
|
|
|
|
def cast_to_tensor(t, hint_dtype=mstype.float32):
|
|
|
|
|
"""
|
|
|
|
|
Cast an user input value into a Tensor of dtype.
|
|
|
|
|
If the input t is of type Parameter, t is directly returned as a Parameter.
|
|
|
|
@ -41,25 +40,26 @@ def cast_to_tensor(t, dtype=mstype.float32):
|
|
|
|
|
if isinstance(t, Parameter):
|
|
|
|
|
return t
|
|
|
|
|
if isinstance(t, Tensor):
|
|
|
|
|
if t.dtype != hint_dtype:
|
|
|
|
|
raise TypeError(f"Input tensor should be type {hint_dtype}.")
|
|
|
|
|
#check if the Tensor in shape of Tensor(4)
|
|
|
|
|
if t.dim() == 0:
|
|
|
|
|
value = t.asnumpy()
|
|
|
|
|
return Tensor([t], dtype=dtype)
|
|
|
|
|
return Tensor([value], dtype=hint_dtype)
|
|
|
|
|
#convert the type of tensor to dtype
|
|
|
|
|
t.set_dtype(dtype)
|
|
|
|
|
return t
|
|
|
|
|
if isinstance(t, (list, np.ndarray)):
|
|
|
|
|
return Tensor(t, dtype=dtype)
|
|
|
|
|
return Tensor(t, dtype=hint_dtype)
|
|
|
|
|
if np.isscalar(t):
|
|
|
|
|
return Tensor([t], dtype=dtype)
|
|
|
|
|
return Tensor([t], dtype=hint_dtype)
|
|
|
|
|
raise RuntimeError("Input type is not supported.")
|
|
|
|
|
|
|
|
|
|
def convert_to_batch(t, batch_shape, dtype):
|
|
|
|
|
def convert_to_batch(t, batch_shape, hint_dtype):
|
|
|
|
|
"""
|
|
|
|
|
Convert a Tensor to a given batch shape.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
t (Tensor, Parameter): Tensor to be converted.
|
|
|
|
|
t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted.
|
|
|
|
|
batch_shape (tuple): desired batch shape.
|
|
|
|
|
dtype (mindspore.dtype): desired dtype.
|
|
|
|
|
|
|
|
|
@ -71,9 +71,8 @@ def convert_to_batch(t, batch_shape, dtype):
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(t, Parameter):
|
|
|
|
|
return t
|
|
|
|
|
if isinstance(t, Tensor):
|
|
|
|
|
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=dtype)
|
|
|
|
|
return Tensor(np.broadcast_to(t, batch_shape), dtype=dtype)
|
|
|
|
|
t = cast_to_tensor(t, hint_dtype)
|
|
|
|
|
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=hint_dtype)
|
|
|
|
|
|
|
|
|
|
def check_scalar_from_param(params):
|
|
|
|
|
"""
|
|
|
|
@ -85,6 +84,8 @@ def check_scalar_from_param(params):
|
|
|
|
|
Notes: String parameters are excluded.
|
|
|
|
|
"""
|
|
|
|
|
for value in params.values():
|
|
|
|
|
if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
|
|
|
|
|
return params['distribution'].is_scalar_batch
|
|
|
|
|
if isinstance(value, Parameter):
|
|
|
|
|
return False
|
|
|
|
|
if isinstance(value, (str, type(params['dtype']))):
|
|
|
|
@ -108,6 +109,8 @@ def calc_broadcast_shape_from_param(params):
|
|
|
|
|
"""
|
|
|
|
|
broadcast_shape = []
|
|
|
|
|
for value in params.values():
|
|
|
|
|
if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
|
|
|
|
|
return params['distribution'].broadcast_shape
|
|
|
|
|
if isinstance(value, (str, type(params['dtype']))):
|
|
|
|
|
continue
|
|
|
|
|
if value is None:
|
|
|
|
@ -251,3 +254,7 @@ def check_tensor_type(name, inputs, valid_type):
|
|
|
|
|
inputs = P.DType()(inputs)
|
|
|
|
|
if inputs not in valid_type:
|
|
|
|
|
raise TypeError(f"{name} dtype is invalid")
|
|
|
|
|
|
|
|
|
|
def check_type(data_type, value_type, name):
|
|
|
|
|
if not data_type in value_type:
|
|
|
|
|
raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid")
|
|
|
|
|