|
|
|
@ -73,7 +73,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
|
|
|
|
|
for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) {
|
|
|
|
|
if (size_device_vec[src_idx][0] <= expected_device_size) {
|
|
|
|
|
++src_idx;
|
|
|
|
|
PADDLE_ENFORCE_LT(src_idx, device_num - empty_num);
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
src_idx, device_num - empty_num,
|
|
|
|
|
"In current srategy an empty tensor should not be copy source.");
|
|
|
|
|
}
|
|
|
|
|
size_device_vec[src_idx][0] -= expected_device_size;
|
|
|
|
|
size_device_vec[dst_idx][0] += expected_device_size;
|
|
|
|
@ -113,7 +115,9 @@ void DataBalanceOpHandle::RunImpl() {
|
|
|
|
|
if (data_idx == 0) {
|
|
|
|
|
device_sizes.emplace_back(ins_size);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ins_size, device_sizes.at(place_idx));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ins_size, device_sizes.at(place_idx),
|
|
|
|
|
"All data on the same device shall have the same batch size.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
const auto &balance_plan = GetBalancePlan(device_sizes);
|
|
|
|
|