|
|
|
@ -14,18 +14,17 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""multitype_ops directory test case"""
|
|
|
|
|
import numpy as np
|
|
|
|
|
from functools import partial, reduce
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
from mindspore import dtype as mstype
|
|
|
|
|
from mindspore.ops import functional as F, composite as C
|
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorIntAutoCast(nn.Cell):
|
|
|
|
|
def __init__(self, ):
|
|
|
|
|
def __init__(self,):
|
|
|
|
|
super(TensorIntAutoCast, self).__init__()
|
|
|
|
|
self.i = 2
|
|
|
|
|
|
|
|
|
@ -35,7 +34,7 @@ class TensorIntAutoCast(nn.Cell):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorFPAutoCast(nn.Cell):
|
|
|
|
|
def __init__(self, ):
|
|
|
|
|
def __init__(self,):
|
|
|
|
|
super(TensorFPAutoCast, self).__init__()
|
|
|
|
|
self.f = 1.2
|
|
|
|
|
|
|
|
|
@ -45,7 +44,7 @@ class TensorFPAutoCast(nn.Cell):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorBoolAutoCast(nn.Cell):
|
|
|
|
|
def __init__(self, ):
|
|
|
|
|
def __init__(self,):
|
|
|
|
|
super(TensorBoolAutoCast, self).__init__()
|
|
|
|
|
self.f = True
|
|
|
|
|
|
|
|
|
@ -55,7 +54,7 @@ class TensorBoolAutoCast(nn.Cell):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorAutoCast(nn.Cell):
|
|
|
|
|
def __init__(self, ):
|
|
|
|
|
def __init__(self,):
|
|
|
|
|
super(TensorAutoCast, self).__init__()
|
|
|
|
|
|
|
|
|
|
def construct(self, t1, t2):
|
|
|
|
@ -65,7 +64,7 @@ class TensorAutoCast(nn.Cell):
|
|
|
|
|
|
|
|
|
|
def test_tensor_auto_cast():
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
t0 = Tensor([True, False], mstype.bool_)
|
|
|
|
|
Tensor([True, False], mstype.bool_)
|
|
|
|
|
t_uint8 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint8)
|
|
|
|
|
t_int8 = Tensor(np.ones([2, 1, 2, 2]), mstype.int8)
|
|
|
|
|
t_int16 = Tensor(np.ones([2, 1, 2, 2]), mstype.int16)
|
|
|
|
|