|
|
|
@ -35,7 +35,7 @@ TEST(Gather, GatherData) {
|
|
|
|
|
p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace());
|
|
|
|
|
p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < 12; ++i) p_src[i] = i;
|
|
|
|
|
for (int i = 0; i < 12; ++i) p_src[i] = i;
|
|
|
|
|
p_index[0] = 1;
|
|
|
|
|
p_index[1] = 0;
|
|
|
|
|
|
|
|
|
@ -43,6 +43,6 @@ TEST(Gather, GatherData) {
|
|
|
|
|
|
|
|
|
|
Gather<int>(CPUPlace(), src, index, output);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
|
|
|
|
|
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
|
|
|
|
|
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
|
|
|
|
|
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
|
|
|
|
|
}
|
|
|
|
|