|
|
|
@ -25,7 +25,7 @@
|
|
|
|
|
|
|
|
|
|
using namespace mindspore::dataset;
|
|
|
|
|
|
|
|
|
|
class MindDataTestTensorDE : public UT::Common {
|
|
|
|
|
class MindDataTestTensorDE : public mindspore::Common {
|
|
|
|
|
public:
|
|
|
|
|
MindDataTestTensorDE() {}
|
|
|
|
|
};
|
|
|
|
@ -42,7 +42,7 @@ TEST_F(MindDataTestTensorDE, MSTensorConvertToLiteTensor) {
|
|
|
|
|
std::shared_ptr<mindspore::tensor::MSTensor> lite_ms_tensor = std::shared_ptr<mindspore::tensor::MSTensor>(
|
|
|
|
|
std::dynamic_pointer_cast<mindspore::tensor::DETensor>(ms_tensor)->ConvertToLiteTensor());
|
|
|
|
|
// check if the lite_ms_tensor is the derived LiteTensor
|
|
|
|
|
mindspore::tensor::LiteTensor * lite_tensor = static_cast<mindspore::tensor::LiteTensor *>(lite_ms_tensor.get());
|
|
|
|
|
mindspore::lite::tensor::LiteTensor * lite_tensor = static_cast<mindspore::lite::tensor::LiteTensor *>(lite_ms_tensor.get());
|
|
|
|
|
ASSERT_EQ(lite_tensor != nullptr, true);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -77,7 +77,7 @@ TEST_F(MindDataTestTensorDE, MSTensorDataType) {
|
|
|
|
|
TEST_F(MindDataTestTensorDE, MSTensorMutableData) {
|
|
|
|
|
std::vector<float> x = {2.5, 2.5, 2.5, 2.5};
|
|
|
|
|
std::shared_ptr<Tensor> t;
|
|
|
|
|
Tensor::CreateTensor(&t, x, TensorShape({2, 2}));
|
|
|
|
|
Tensor::CreateFromVector(x, TensorShape({2, 2}), &t);
|
|
|
|
|
auto ms_tensor = std::shared_ptr<mindspore::tensor::MSTensor>(new mindspore::tensor::DETensor(t));
|
|
|
|
|
float *data = static_cast<float*>(ms_tensor->MutableData());
|
|
|
|
|
std::vector<float> tensor_vec(data, data + ms_tensor->ElementsNum());
|
|
|
|
@ -88,7 +88,7 @@ TEST_F(MindDataTestTensorDE, MSTensorMutableData) {
|
|
|
|
|
TEST_F(MindDataTestTensorDE, MSTensorHash) {
|
|
|
|
|
std::vector<float> x = {2.5, 2.5, 2.5, 2.5};
|
|
|
|
|
std::shared_ptr<Tensor> t;
|
|
|
|
|
Tensor::CreateTensor(&t, x, TensorShape({2, 2}));
|
|
|
|
|
Tensor::CreateFromVector(x, TensorShape({2, 2}), &t);
|
|
|
|
|
auto ms_tensor = std::shared_ptr<mindspore::tensor::MSTensor>(new mindspore::tensor::DETensor(t));
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); // arm64
|