fix SplitLodTensor when batch_size = 0, test=develop (#19866)

expand_as_op_1
Leo Chen 5 years ago committed by Zeng Jinle
parent b125e327aa
commit 578a2f5da3

@ -283,6 +283,21 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
std::vector<LoDTensor> results;
results.reserve(result_size);
// if result_size(batch_size) is 0, just return #places.size() copys of empty
// tensors.
if (result_size == 0) {
for (size_t i = 0; i < places.size(); ++i) {
LoDTensor dst;
dst.Resize(dims());
dst.mutable_data(places[i], type());
if (!lod().empty()) {
dst.set_lod(lod());
}
results.emplace_back(dst);
}
return results;
}
int step_width = static_cast<int>(batch_size / result_size);
for (size_t i = 0; i < result_size; ++i) {
int begin = static_cast<int>(i * step_width);

@ -155,6 +155,26 @@ TEST(LoD, SplitLoDTensor) {
EXPECT_EQ(lods[1].lod(), lod1);
}
TEST(LoD, SplitLoDTensorWithZeroBatchSize) {
LoD lod;
lod.push_back(std::vector<size_t>({0}));
platform::CPUPlace place;
LoDTensor lod_tensor;
lod_tensor.Resize({0, 5});
lod_tensor.mutable_data<float>(place);
lod_tensor.set_lod(lod);
std::vector<platform::Place> places{platform::CPUPlace(),
platform::CPUPlace()};
LoD lod_res;
lod_res.push_back(std::vector<size_t>({0}));
auto lods = lod_tensor.SplitLoDTensor(places);
EXPECT_EQ(lods[0].lod(), lod_res);
EXPECT_EQ(lods[1].lod(), lod_res);
}
TEST(LoD, MergeLoDTensor) {
LoD lod;
lod.push_back(std::vector<size_t>({0, 2, 4, 5, 6}));

Loading…
Cancel
Save