@ -13,7 +13,12 @@
# limitations under the License.
from __future__ import print_function
import numpy as np
from . . fluid . framework import Variable
from . . fluid . framework import unique_name
from . . fluid . framework import _current_expected_place
from . . fluid . framework import dygraph_only
from . . fluid . initializer import Constant
from . . fluid . layers import core
from . . fluid . layer_helper import LayerHelper
@ -21,20 +26,16 @@ from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtyp
from . . fluid . framework import convert_np_dtype_to_dtype_ , in_dygraph_mode , _varbase_creator , device_guard , OpProtoHolder
from . . fluid . layers import fill_constant
from paddle . common_ops_import import *
import paddle
# TODO: define functions to get create a tensor
from . . fluid . layers import crop_tensor #DEFINE_ALIAS
from . . fluid . layers import diag #DEFINE_ALIAS
from . . fluid . layers import fill_constant #DEFINE_ALIAS
from . . fluid . layers import create_tensor #DEFINE_ALIAS
from . . fluid . layers import linspace #DEFINE_ALIAS
import paddle
__all__ = [
' create_tensor ' ,
# 'create_lod_tensor',
# 'create_random_int_lodtensor',
' to_tensor ' ,
' crop_tensor ' ,
' diag ' ,
' fill_constant ' ,
@ -54,6 +55,170 @@ __all__ = [
]
@dygraph_only
def to_tensor ( data , dtype = None , place = None , stop_gradient = True ) :
"""
Constructs a ` ` paddle . Tensor ` ` or ` ` paddle . ComplexTensor ` ` from ` ` data ` ` ,
which can be scalar , tuple , list , numpy \. ndarray , paddle \. Tensor , paddle \. ComplexTensor .
If the ` ` data ` ` is already a tensor , and ` ` dtype ` ` or ` ` place ` ` does ' t change, no copy
will be performed and return origin tensor , otherwise a new tensor will be constructed
and returned . Similarly , if the data is an numpy \. ndarray of with the same ` ` dtype ` `
and the current place is cpu , no copy will be performed .
The ` ` ComplexTensor ` ` is a unique type of paddle . If x is ` ` ComplexTensor ` ` , then
` ` x . real ` ` is the real part , and ` ` x . imag ` ` is the imaginary part .
Args :
data ( scalar | tuple | list | ndarray | Tensor | ComplexTensor ) : Initial data for the tensor .
Can be a scalar , list , tuple , numpy \. ndarray , paddle \. Tensor , paddle \. ComplexTensor .
dtype ( str , optional ) : The desired data type of returned tensor . Can be ' bool ' , ' float16 ' ,
' float32 ' , ' float64 ' , ' int8 ' , ' int16 ' , ' int32 ' , ' int64 ' , ' uint8 ' . And
' complex64 ' , ' complex128 ' only for ComplexTensor .
Default : None , infers data type from ` ` data ` ` .
place ( CPUPlace | CUDAPinnedPlace | CUDAPlace , optional ) : The place to allocate Tensor . Can be
CPUPlace , CUDAPinnedPlace , CUDAPlace . Default : None , means global place .
stop_gradient ( bool , optional ) : Whether to block the gradient propagation of Autograd . Default : True .
Returns :
Tensor : A Tensor or ComplexTensor constructed from ` ` data ` ` .
Raises :
TypeError : If the data type of ` ` data ` ` is not scalar , list , tuple , numpy . ndarray , paddle . Tensor , paddle . ComplexTensor
ValueError : If ` ` data ` ` is tuple | list , it can ' t contain nested tuple|list with different lengths , such as: [[1, 2], [3, 4, 5]]
TypeError : If ` ` dtype ` ` is not bool , float16 , float32 , float64 , int8 , int16 , int32 , int64 , uint8 , complex64 , complex128
ValueError : If ` ` place ` ` is not paddle . Place , paddle . CUDAPinnedPlace , paddle . CUDAPlace
Examples :
. . code - block : : python
import paddle
import numpy as np
paddle . enable_imperative ( )
type ( paddle . to_tensor ( 1 ) )
# <class 'paddle.Tensor'>
paddle . to_tensor ( 1 )
# Tensor: generated_tensor_0
# - place: CUDAPlace(0) # allocate on global default place CPU:0
# - shape: [1]
# - layout: NCHW
# - dtype: int64_t
# - data: [1]
x = paddle . to_tensor ( 1 )
paddle . to_tensor ( x , dtype = ' int32 ' , place = paddle . CPUPlace ( ) ) # A new tensor will be constructed due to different dtype or place
# Tensor: generated_tensor_01
# - place: CPUPlace
# - shape: [1]
# - layout: NCHW
# - dtype: int
# - data: [1]
paddle . to_tensor ( ( 1.1 , 2.2 ) , place = paddle . CUDAPinnedPlace ( ) )
# Tensor: generated_tensor_1
# - place: CUDAPinnedPlace
# - shape: [2]
# - layout: NCHW
# - dtype: double
# - data: [1.1 2.2]
paddle . to_tensor ( [ [ 0.1 , 0.2 ] , [ 0.3 , 0.4 ] ] , place = paddle . CUDAPlace ( 0 ) , stop_gradient = False )
# Tensor: generated_tensor_2
# - place: CUDAPlace(0)
# - shape: [2, 2]
# - layout: NCHW
# - dtype: double
# - data: [0.1 0.2 0.3 0.4]
type ( paddle . to_tensor ( [ [ 1 + 1 j , 2 ] , [ 3 + 2 j , 4 ] ] ) , , dtype = ' complex64 ' )
# <class 'paddle.ComplexTensor'>
paddle . to_tensor ( [ [ 1 + 1 j , 2 ] , [ 3 + 2 j , 4 ] ] , dtype = ' complex64 ' )
# ComplexTensor[real]: generated_tensor_0.real
# - place: CUDAPlace(0)
# - shape: [2, 2]
# - layout: NCHW
# - dtype: float
# - data: [1 2 3 4]
# ComplexTensor[imag]: generated_tensor_0.imag
# - place: CUDAPlace(0)
# - shape: [2, 2]
# - layout: NCHW
# - dtype: float
# - data: [1 0 2 0]
"""
if place is None :
place = _current_expected_place ( )
elif not isinstance ( place ,
( core . CPUPlace , core . CUDAPinnedPlace , core . CUDAPlace ) ) :
raise ValueError (
" ' place ' must be any of paddle.Place, paddle.CUDAPinnedPlace, paddle.CUDAPlace "
)
#Todo(zhouwei): Support allocate tensor on any other specified card
if isinstance ( place , core . CUDAPlace ) and isinstance (
_current_expected_place ( ) , core . CUDAPlace ) and place . _get_device_id (
) != _current_expected_place ( ) . _get_device_id ( ) :
place = _current_expected_place ( )
if not isinstance ( data , np . ndarray ) :
if np . isscalar ( data ) and not isinstance ( data , str ) :
data = np . array ( [ data ] )
elif isinstance ( data , ( list , tuple ) ) :
data = np . array ( data )
if data . dtype == np . object :
raise ValueError (
" \n \t Faild to convert input data to a regular ndarray : \n \t - Usually "
" this means the input data contains nested lists with different lengths. "
)
elif isinstance ( data , paddle . Tensor ) :
data . stop_gradient = stop_gradient
if not data . place . _equals ( place ) :
data = data . _copy_to ( place , False )
if dtype :
if convert_dtype ( dtype ) != convert_dtype ( data . dtype ) :
return data . astype ( convert_dtype ( dtype ) )
return data
elif isinstance ( data , paddle . ComplexTensor ) :
return data
else :
raise TypeError (
" Can ' t constructs a ' paddle.Tensor ' with data type {} , data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor|paddle.ComplexTensor " .
format ( type ( data ) ) )
if dtype :
dtype = convert_dtype ( dtype )
if dtype != data . dtype :
data = data . astype ( dtype )
if not np . iscomplexobj ( data ) :
return paddle . Tensor (
value = data ,
place = place ,
persistable = False ,
zero_copy = True ,
stop_gradient = stop_gradient )
else :
name = unique_name . generate ( ' generated_tensor ' )
real_tensor = paddle . Tensor (
value = data . real ,
place = place ,
zero_copy = True ,
name = name + " .real " ,
stop_gradient = stop_gradient )
imag_tensor = paddle . Tensor (
value = data . imag ,
place = place ,
zero_copy = True ,
name = name + " .imag " ,
stop_gradient = stop_gradient )
return paddle . ComplexTensor ( real_tensor , imag_tensor )
def full_like ( x , fill_value , dtype = None , name = None ) :
"""
: alias_main : paddle . full_like
@ -201,7 +366,7 @@ def ones_like(x, dtype=None, name=None):
paddle . disable_static ( )
x = paddle . to_ variable ( np . array ( [ 1 , 2 , 3 ] , dtype = ' float32 ' ) )
x = paddle . to_ tensor ( np . array ( [ 1 , 2 , 3 ] , dtype = ' float32 ' ) )
out1 = paddle . zeros_like ( x ) # [1., 1., 1.]
out2 = paddle . zeros_like ( x , dtype = ' int32 ' ) # [1, 1, 1]
@ -291,7 +456,7 @@ def zeros_like(x, dtype=None, name=None):
paddle . disable_static ( )
x = paddle . to_ variable ( np . array ( [ 1 , 2 , 3 ] , dtype = ' float32 ' ) )
x = paddle . to_ tensor ( np . array ( [ 1 , 2 , 3 ] , dtype = ' float32 ' ) )
out1 = paddle . zeros_like ( x ) # [0., 0., 0.]
out2 = paddle . zeros_like ( x , dtype = ' int32 ' ) # [0, 0, 0]
@ -471,7 +636,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
out3 = paddle . arange ( 4.999 , dtype = ' float32 ' )
# [0., 1., 2., 3., 4.]
start_var = paddle . to_ variable ( np . array ( [ 3 ] ) )
start_var = paddle . to_ tensor ( np . array ( [ 3 ] ) )
out4 = paddle . arange ( start_var , 7 )
# [3, 4, 5, 6]
@ -713,8 +878,8 @@ def meshgrid(*args, **kwargs):
input_3 = np . random . randint ( 0 , 100 , [ 100 , ] ) . astype ( ' int32 ' )
input_4 = np . random . randint ( 0 , 100 , [ 200 , ] ) . astype ( ' int32 ' )
tensor_3 = paddle . to_ variable ( input_3 )
tensor_4 = paddle . to_ variable ( input_4 )
tensor_3 = paddle . to_ tensor ( input_3 )
tensor_4 = paddle . to_ tensor ( input_4 )
grid_x , grid_y = paddle . tensor . meshgrid ( tensor_3 , tensor_4 )
#the shape of grid_x is (100, 200)