!6046 check column name for bucket_batch_by_length

Merge pull request !6046 from yanghaitao/yht_bucket_batch_by_length
pull/6046/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit cd988cb694

@ -224,5 +224,19 @@ Status BucketBatchByLengthOp::Reset() {
return Status::OK();
}
// Computing the assignment of the column name map and check compute input columns.
Status BucketBatchByLengthOp::ComputeColMap() {
RETURN_IF_NOT_OK(DatasetOp::ComputeColMap());
for (const auto &inCol : length_dependent_columns_) {
bool found = column_name_id_map_.find(inCol) != column_name_id_map_.end() ? true : false;
if (!found) {
std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -135,6 +135,8 @@ class BucketBatchByLengthOp : public PipelineOp {
Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size);
Status ComputeColMap() override;
std::vector<std::string> length_dependent_columns_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;

Loading…
Cancel
Save