@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) {
// TensorDataNumpy implements TensorData using numpy array.
class TensorDataNumpy : public TensorData {
public :
explicit TensorDataNumpy ( const py : : array & input ) : data_ ( input ) {
if ( ! IsCContiguous ( data_ ) ) {
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
auto np = py : : module : : import ( " numpy " ) ;
auto convert = np . attr ( " ascontiguousarray " ) ;
data_ = convert ( data_ ) ;
}
}
explicit TensorDataNumpy ( py : : buffer_info & & buffer ) : buffer_ ( std : : move ( buffer ) ) { }
/// Total number of elements.
ssize_t size ( ) const override { return data_. size ( ) ; }
ssize_t size ( ) const override { return buffer_ . size ; }
/// Byte size of a single element.
ssize_t itemsize ( ) const override { return data_. itemsize ( ) ; }
ssize_t itemsize ( ) const override { return buffer_ . itemsize ; }
/// Total number of bytes.
ssize_t nbytes ( ) const override { return data_. nbytes ( ) ; }
ssize_t nbytes ( ) const override { return buffer_. itemsize * buffer_ . size ; }
/// Number of dimensions.
ssize_t ndim ( ) const override { return data_. ndim ( ) ; }
ssize_t ndim ( ) const override { return buffer_. ndim ; }
/// Data pointer.
void * data ( ) override { return data_ . request ( ) . ptr ; }
const void * const_data ( ) const override { return data_ . request ( ) . ptr ; }
/// Is data equals.
bool equals ( const TensorData & other ) const override {
auto ptr = dynamic_cast < const TensorDataNumpy * > ( & other ) ;
if ( ptr = = nullptr ) {
// Not same type, compare data byte by byte.
return TensorData : : equals ( other ) ;
}
return NumpyEquals ( * ptr ) ;
}
void * data ( ) override { return buffer_ . ptr ; }
bool NumpyEquals ( const TensorDataNumpy & other ) const {
auto all_data_equal = [ & other , this ] ( ) - > bool {
auto np = py : : module : : import ( " numpy " ) ;
auto equal = np . attr ( " equal " ) ( data_ , other . data_ ) ;
auto all_equal = np . attr ( " all " ) ( equal ) ;
return all_equal . cast < bool > ( ) ;
} ;
return this = = & other | | data_ . is ( other . data_ ) | | all_data_equal ( ) ;
}
const void * const_data ( ) const override { return buffer_ . ptr ; }
/// To string.
std : : string ToString ( const TypeId type , const ShapeVector & shape , bool use_comma ) const override {
@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData {
kwargs [ " separator " ] = " , " ;
auto np = py : : module : : import ( " numpy " ) ;
auto array2string = np . attr ( " array2string " ) ;
return py : : str ( array2string ( data_ , * * kwargs ) ) ;
return py : : str ( array2string ( py_array( ) , * * kwargs ) ) ;
}
// without comma.
return py : : str ( data_ ) ;
return py : : str ( py_array( ) ) ;
}
/// py::array object.
py : : array py_array ( ) const { return data_ ; }
py : : array py_array ( ) const {
// Use dummy owner to avoid copy data.
py : : str dummyOwner ;
return py : : array ( py : : dtype ( buffer_ ) , buffer_ . shape , buffer_ . strides , buffer_ . ptr , dummyOwner ) ;
}
private :
mutable py : : array data_ ;
py : : buffer_info buffer _;
} ;
TensorPtr TensorPy : : MakeTensor ( const py : : array & input , const TypePtr & type_ptr ) {
@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
/// Creates a Tensor from a numpy array without copy
TensorPtr TensorPy : : MakeTensorNoCopy ( const py : : array & input ) {
// Check format.
if ( ! IsCContiguous ( input ) ) {
MS_LOG ( EXCEPTION ) < < " Array should be C contiguous. " ;
}
// Get input buffer info.
py : : buffer_info buf = input . request ( ) ;
// Get tensor dtype and check it.
@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
// Get tensor shape.
ShapeVector shape ( buf . shape . begin ( ) , buf . shape . end ( ) ) ;
// Make a tensor with shared data with numpy array.
auto tensor_data = std : : make_shared < TensorDataNumpy > ( input ) ;
auto tensor_data = std : : make_shared < TensorDataNumpy > ( std: : move ( buf ) ) ;
return std : : make_shared < Tensor > ( dtype , shape , tensor_data ) ;
}