!207 Fix potential overflow in batch by changing int32 to int64

Merge pull request !207 from ZiruiWu/batch_info_overflow
pull/207/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d9cd681c30

@ -406,7 +406,7 @@ void bindSamplerOps(py::module *m) {
void bindInfoObjects(py::module *m) { void bindInfoObjects(py::module *m) {
(void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo") (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
.def(py::init<int32_t, int32_t, int32_t>()) .def(py::init<int64_t, int64_t, int64_t>())
.def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
} }

@ -57,7 +57,7 @@ BatchOp::BatchOp(int32_t batch_size, bool drop, int32_t op_queue_size, int32_t n
Status BatchOp::operator()() { Status BatchOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();
int32_t epoch_num = 0, batch_num = 0, cnt = 0; int64_t epoch_num = 0, batch_num = 0, cnt = 0;
TensorRow new_row; TensorRow new_row;
std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>(); std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>();
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0); child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);

@ -124,17 +124,17 @@ class BatchOp : public ParallelOp {
// This struct is used for both internal control and python callback. // This struct is used for both internal control and python callback.
// This struct is bound to python with read-only access. // This struct is bound to python with read-only access.
struct CBatchInfo { struct CBatchInfo {
CBatchInfo(int32_t ep, int32_t bat, int32_t cur, batchCtrl ctrl) CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl)
: epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {} : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {}
CBatchInfo(int32_t ep, int32_t bat, int32_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {} CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {}
CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {} CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {}
explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {} explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {}
int32_t epoch_num_; // i-th epoch. i starts from 0 int64_t epoch_num_; // i-th epoch. i starts from 0
int32_t batch_num_; // i-th batch since the start of current epoch. i starts from 0 int64_t batch_num_; // i-th batch since the start of current epoch. i starts from 0
int32_t total_batch_num_; // i-th batch since the start of first epoch. i starts from 0 int64_t total_batch_num_; // i-th batch since the start of first epoch. i starts from 0
batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3 batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3
const int32_t get_batch_num() const { return batch_num_; } const int64_t get_batch_num() const { return batch_num_; }
const int32_t get_epoch_num() const { return epoch_num_; } const int64_t get_epoch_num() const { return epoch_num_; }
}; };
// BatchOp constructor // BatchOp constructor

@ -201,8 +201,8 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
Status Reset() override; Status Reset() override;
bool decode_; bool decode_;
uint64_t row_cnt_; int64_t row_cnt_;
uint64_t buf_cnt_; int64_t buf_cnt_;
int64_t num_rows_; int64_t num_rows_;
int64_t num_samples_; int64_t num_samples_;
std::string folder_path_; std::string folder_path_;

Loading…
Cancel
Save