|
|
|
@ -257,40 +257,29 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
|
|
|
|
|
EXPECT_GT(size, 0UL);
|
|
|
|
|
EXPECT_EQ(size, ref_size);
|
|
|
|
|
EXPECT_EQ(out.dtype, ref_out.dtype);
|
|
|
|
|
switch (out.dtype) {
|
|
|
|
|
case PaddleDType::INT64: {
|
|
|
|
|
int64_t *pdata = static_cast<int64_t *>(out.data.data());
|
|
|
|
|
int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data());
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
EXPECT_EQ(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case PaddleDType::FLOAT32: {
|
|
|
|
|
float *pdata = static_cast<float *>(out.data.data());
|
|
|
|
|
float *pdata_ref = static_cast<float *>(ref_out.data.data());
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
CheckError(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case PaddleDType::INT32: {
|
|
|
|
|
int32_t *pdata = static_cast<int32_t *>(out.data.data());
|
|
|
|
|
int32_t *pdata_ref = static_cast<int32_t *>(ref_out.data.data());
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
EXPECT_EQ(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case PaddleDType::UINT8: {
|
|
|
|
|
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
|
|
|
|
|
uint8_t *pdata_ref = static_cast<uint8_t *>(ref_out.data.data());
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
EXPECT_EQ(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
#define COMPARE(paddle_type, type, func) \
|
|
|
|
|
case paddle_type: { \
|
|
|
|
|
type *pdata = static_cast<type *>(out.data.data()); \
|
|
|
|
|
type *pdata_ref = static_cast<type *>(ref_out.data.data()); \
|
|
|
|
|
for (size_t j = 0; j < size; ++j) { \
|
|
|
|
|
func(pdata_ref[j], pdata[j]); \
|
|
|
|
|
} \
|
|
|
|
|
break; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch (out.dtype) {
|
|
|
|
|
COMPARE(PaddleDType::INT64, int64_t, EXPECT_EQ);
|
|
|
|
|
COMPARE(PaddleDType::FLOAT32, float, CheckError);
|
|
|
|
|
COMPARE(PaddleDType::INT32, int32_t, EXPECT_EQ);
|
|
|
|
|
COMPARE(PaddleDType::UINT8, uint8_t, EXPECT_EQ);
|
|
|
|
|
COMPARE(PaddleDType::INT8, int8_t, EXPECT_EQ);
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"VarMessageToVarType: Unsupported dtype %d",
|
|
|
|
|
static_cast<int>(out.dtype)));
|
|
|
|
|
}
|
|
|
|
|
#undef COMPARE
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -306,44 +295,30 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
|
|
|
|
|
EXPECT_GT(size, 0UL);
|
|
|
|
|
int ref_size = 0; // this is the number of elements not memory size
|
|
|
|
|
PaddlePlace place;
|
|
|
|
|
switch (out.dtype) {
|
|
|
|
|
case PaddleDType::INT64: {
|
|
|
|
|
int64_t *pdata = static_cast<int64_t *>(out.data.data());
|
|
|
|
|
int64_t *pdata_ref = ref_out.data<int64_t>(&place, &ref_size);
|
|
|
|
|
EXPECT_EQ(size, static_cast<size_t>(ref_size));
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
EXPECT_EQ(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case PaddleDType::FLOAT32: {
|
|
|
|
|
float *pdata = static_cast<float *>(out.data.data());
|
|
|
|
|
float *pdata_ref = ref_out.data<float>(&place, &ref_size);
|
|
|
|
|
EXPECT_EQ(size, static_cast<size_t>(ref_size));
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
CheckError(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case PaddleDType::INT32: {
|
|
|
|
|
int32_t *pdata = static_cast<int32_t *>(out.data.data());
|
|
|
|
|
int32_t *pdata_ref = ref_out.data<int32_t>(&place, &ref_size);
|
|
|
|
|
EXPECT_EQ(size, static_cast<size_t>(ref_size));
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
EXPECT_EQ(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case PaddleDType::UINT8: {
|
|
|
|
|
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
|
|
|
|
|
uint8_t *pdata_ref = ref_out.data<uint8_t>(&place, &ref_size);
|
|
|
|
|
EXPECT_EQ(size, static_cast<size_t>(ref_size));
|
|
|
|
|
for (size_t j = 0; j < size; ++j) {
|
|
|
|
|
EXPECT_EQ(pdata_ref[j], pdata[j]);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
#define COMPARE(paddle_type, type, func) \
|
|
|
|
|
case paddle_type: { \
|
|
|
|
|
type *pdata = static_cast<type *>(out.data.data()); \
|
|
|
|
|
type *pdata_ref = ref_out.data<type>(&place, &ref_size); \
|
|
|
|
|
EXPECT_EQ(size, static_cast<size_t>(ref_size)); \
|
|
|
|
|
for (size_t j = 0; j < size; ++j) { \
|
|
|
|
|
func(pdata_ref[j], pdata[j]); \
|
|
|
|
|
} \
|
|
|
|
|
break; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch (out.dtype) {
|
|
|
|
|
COMPARE(PaddleDType::INT64, int64_t, EXPECT_EQ);
|
|
|
|
|
COMPARE(PaddleDType::FLOAT32, float, CheckError);
|
|
|
|
|
COMPARE(PaddleDType::INT32, int32_t, EXPECT_EQ);
|
|
|
|
|
COMPARE(PaddleDType::UINT8, uint8_t, EXPECT_EQ);
|
|
|
|
|
COMPARE(PaddleDType::INT8, int8_t, EXPECT_EQ);
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"VarMessageToVarType: Unsupported dtype %d",
|
|
|
|
|
static_cast<int>(out.dtype)));
|
|
|
|
|
}
|
|
|
|
|
#undef COMPARE
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|