|
|
|
@ -122,6 +122,8 @@ class TensorDataNumpy : public TensorData {
|
|
|
|
|
public:
|
|
|
|
|
explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
|
|
|
|
|
|
|
|
|
|
~TensorDataNumpy() override = default;
|
|
|
|
|
|
|
|
|
|
/// Total number of elements.
|
|
|
|
|
ssize_t size() const override { return buffer_.size; }
|
|
|
|
|
|
|
|
|
@ -160,7 +162,7 @@ class TensorDataNumpy : public TensorData {
|
|
|
|
|
return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// The internal buffer.
|
|
|
|
|
py::buffer_info buffer_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -258,7 +260,7 @@ py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
|
|
|
|
|
|
|
|
|
|
py::array TensorPy::AsNumpy(const Tensor &tensor) {
|
|
|
|
|
auto data_numpy = dynamic_cast<const TensorDataNumpy *>(&tensor.data());
|
|
|
|
|
if (data_numpy) {
|
|
|
|
|
if (data_numpy != nullptr) {
|
|
|
|
|
// Return internal numpy array if tensor data is implemented base on it.
|
|
|
|
|
return data_numpy->py_array();
|
|
|
|
|
}
|
|
|
|
|