You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.2 KiB
102 lines
3.2 KiB
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "paddle/fluid/framework/dlpack_tensor.h"
|
|
#include <glog/logging.h>
|
|
#include <gtest/gtest.h>
|
|
#include <vector>
|
|
|
|
namespace paddle {
|
|
namespace framework {
|
|
|
|
namespace { // NOLINT
|
|
template <typename T>
|
|
constexpr uint8_t GetDLDataTypeCode() {
|
|
return std::is_same<platform::float16, T>::value ||
|
|
std::is_floating_point<T>::value
|
|
? static_cast<uint8_t>(kDLFloat)
|
|
: (std::is_unsigned<T>::value
|
|
? static_cast<uint8_t>(kDLUInt)
|
|
: (std::is_integral<T>::value ? static_cast<uint8_t>(kDLInt)
|
|
: static_cast<uint8_t>(-1)));
|
|
}
|
|
} // NOLINT
|
|
|
|
template <typename T>
|
|
void TestMain(const platform::Place &place, uint16_t lanes) {
|
|
DDim dims{4, 5, 6, 7};
|
|
Tensor tensor;
|
|
tensor.Resize(dims);
|
|
void *p = tensor.mutable_data<T>(place);
|
|
|
|
DLPackTensor dlpack_tensor(tensor, lanes);
|
|
::DLTensor &dl_tensor = dlpack_tensor;
|
|
|
|
CHECK_EQ(p, dl_tensor.data);
|
|
if (platform::is_cpu_place(place)) {
|
|
CHECK_EQ(kDLCPU, dl_tensor.ctx.device_type);
|
|
CHECK_EQ(0, dl_tensor.ctx.device_id);
|
|
} else if (platform::is_gpu_place(place)) {
|
|
CHECK_EQ(kDLGPU, dl_tensor.ctx.device_type);
|
|
CHECK_EQ(boost::get<platform::CUDAPlace>(place).device,
|
|
dl_tensor.ctx.device_id);
|
|
} else if (platform::is_cuda_pinned_place(place)) {
|
|
CHECK_EQ(kDLCPUPinned, dl_tensor.ctx.device_type);
|
|
CHECK_EQ(0, dl_tensor.ctx.device_id);
|
|
} else {
|
|
CHECK_EQ(false, true);
|
|
}
|
|
|
|
CHECK_EQ(dims.size(), dl_tensor.ndim);
|
|
for (auto i = 0; i < dims.size(); ++i) {
|
|
CHECK_EQ(dims[i], dl_tensor.shape[i]);
|
|
}
|
|
|
|
CHECK_EQ(dl_tensor.strides == nullptr, true);
|
|
CHECK_EQ(static_cast<uint64_t>(0), dl_tensor.byte_offset);
|
|
|
|
CHECK_EQ(lanes, dl_tensor.dtype.lanes);
|
|
CHECK_EQ(sizeof(T) * 8, dl_tensor.dtype.bits);
|
|
|
|
CHECK_EQ(GetDLDataTypeCode<T>(), dl_tensor.dtype.code);
|
|
}
|
|
|
|
template <typename T>
|
|
void TestMainLoop() {
|
|
#ifdef PADDLE_WITH_CUDA
|
|
std::vector<platform::Place> places{platform::CPUPlace(),
|
|
platform::CUDAPlace(0),
|
|
platform::CUDAPinnedPlace()};
|
|
if (platform::GetCUDADeviceCount() > 1) {
|
|
places.emplace_back(platform::CUDAPlace(1));
|
|
}
|
|
#else
|
|
std::vector<platform::Place> places{platform::CPUPlace()};
|
|
#endif
|
|
std::vector<uint16_t> lanes{1, 2};
|
|
for (auto &p : places) {
|
|
for (auto &l : lanes) {
|
|
TestMain<T>(p, l);
|
|
}
|
|
}
|
|
}
|
|
TEST(dlpack, test_all) {
|
|
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
|
|
|
|
_ForEachDataType_(TestCallback);
|
|
}
|
|
|
|
} // namespace framework
|
|
} // namespace paddle
|