|
|
|
@ -13,16 +13,16 @@
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
|
|
|
|
|
|
|
|
|
#include <iomanip>
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "utils/ms_utils.h"
|
|
|
|
|
#include "minddata/dataset/core/config_manager.h"
|
|
|
|
|
#include "minddata/dataset/engine/data_buffer.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/opt/pass.h"
|
|
|
|
|
#include "minddata/dataset/engine/db_connector.h"
|
|
|
|
|
#include "minddata/dataset/engine/execution_tree.h"
|
|
|
|
|
#include "minddata/dataset/engine/opt/pass.h"
|
|
|
|
|
#include "utils/ms_utils.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
@ -101,7 +101,7 @@ Status ConcatOp::operator()() {
|
|
|
|
|
// 3. Put the data into output_connector
|
|
|
|
|
if (!children_flag_and_nums_.empty()) is_not_mappable = children_flag_and_nums_[i].first;
|
|
|
|
|
while (!buf->eoe() && !buf->eof()) {
|
|
|
|
|
// if dataset is no mappable or generator dataset which source is yeild(cannot get the number of samples in
|
|
|
|
|
// if dataset is not mappable or generator dataset which source is yield (cannot get the number of samples in
|
|
|
|
|
// python layer), we use filtering to get data
|
|
|
|
|
if (sample_number % num_shard == shard_index && (is_not_mappable || !children_flag_and_nums_[i].second)) {
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
|
|
|
@ -125,7 +125,7 @@ Status ConcatOp::operator()() {
|
|
|
|
|
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if dataset is mappable,We do't use filtering to pick data.
|
|
|
|
|
// if dataset is mappable, we don't use filtering to pick data,
|
|
|
|
|
// so sample_number plus the length of the entire dataset
|
|
|
|
|
if (!is_not_mappable && children_flag_and_nums_[i].second) {
|
|
|
|
|
sample_number += children_flag_and_nums_[i].second;
|
|
|
|
@ -142,7 +142,7 @@ Status ConcatOp::operator()() {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
|
|
|
|
|
"Something went wrong, eof count does not match the number of children.");
|
|
|
|
|
// 5. Add eof buffer in the end manually
|
|
|
|
|
MS_LOG(DEBUG) << "Add the eof buffer manualy in the end.";
|
|
|
|
|
MS_LOG(DEBUG) << "Add the eof buffer manually in the end.";
|
|
|
|
|
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
|
|
|
|
return Status::OK();
|
|
|
|
@ -150,7 +150,7 @@ Status ConcatOp::operator()() {
|
|
|
|
|
|
|
|
|
|
Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
|
|
|
|
|
TensorRow new_row;
|
|
|
|
|
buf->GetRow(0, &new_row);
|
|
|
|
|
RETURN_IF_NOT_OK(buf->GetRow(0, &new_row));
|
|
|
|
|
|
|
|
|
|
if (id == 0) {
|
|
|
|
|
// Obtain the data type and data rank in child[0]
|
|
|
|
|