parent a4048e192c
commit 57f3732ac3

@@ -0,0 +1,98 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <algorithm>
#include "dataset/engine/opt/pre/global_shuffle.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"

namespace mindspore {
namespace dataset {

Status GlobalShufflePass::RunOnTree(ExecutionTree *tree, bool *modified) {
  std::vector<std::shared_ptr<TFReaderOp>> tf_readers;
  std::vector<std::shared_ptr<TextFileOp>> text_files;
  std::vector<std::shared_ptr<ClueOp>> clues;

  // Pass 1: collect all source nodes that require a global shuffle.
  for (auto &op : *tree) {
    if (auto ptr = std::dynamic_pointer_cast<TFReaderOp>(op.shared_from_this())) {
      if (ptr->RequireGlobalShuffle()) {
        tf_readers.push_back(ptr);
        continue;
      }
    }
    if (auto ptr = std::dynamic_pointer_cast<TextFileOp>(op.shared_from_this())) {
      if (ptr->RequireGlobalShuffle()) {
        text_files.push_back(ptr);
        continue;
      }
    }
    if (auto ptr = std::dynamic_pointer_cast<ClueOp>(op.shared_from_this())) {
      if (ptr->RequireGlobalShuffle()) {
        clues.push_back(ptr);
        continue;
      }
    }
  }

  // Pass 2: insert a ShuffleOp above each collected source node.
  // The following blocks could be implemented with a template if CountTotalRows
  // were unified across all source nodes.
  for (auto node : tf_readers) {
    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
    int64_t total_rows = 0;
    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&total_rows, node->FileNames(), 8, true));
    // Shuffle buffer holds roughly four files' worth of rows, with a floor of 10000.
    int32_t avg_rows_per_file = total_rows / (node->FileNames().size());
    builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
    std::shared_ptr<ShuffleOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
    *modified = true;
  }

  for (auto node : text_files) {
    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
    int64_t total_rows = 0;
    RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(node->FileNames(), &total_rows));
    int32_t avg_rows_per_file = total_rows / (node->FileNames().size());
    builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
    std::shared_ptr<ShuffleOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
    *modified = true;
  }

  for (auto node : clues) {
    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
    int64_t total_rows = 0;
    RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(node->FileNames(), &total_rows));
    int32_t avg_rows_per_file = total_rows / (node->FileNames().size());
    builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
    std::shared_ptr<ShuffleOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
    *modified = true;
  }

  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
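
The shuffle-size heuristic used in all three loops above is simple enough to check by hand; a minimal Python sketch (the standalone function is illustrative, not part of the codebase):

# Sketch of the ShuffleOp buffer-size heuristic (illustrative only):
# buffer four files' worth of rows, but never fewer than 10000.
def shuffle_size(total_rows, num_files):
    avg_rows_per_file = total_rows // num_files
    return max(avg_rows_per_file * 4, 10000)

assert shuffle_size(80000, 16) == 20000  # 16 files, 5000 rows each -> 20000
assert shuffle_size(4000, 4) == 10000    # small files -> the 10000 floor applies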
@@ -0,0 +1,35 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
#define DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H

#include <memory>
#include "dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {
// GlobalShufflePass inserts a ShuffleOp when a leaf node requires a global shuffle.
// Example:
// Input tree:  TFReader(GLOBAL_SHUFFLE) -> Batch
// Output tree: TFReader -> Shuffle -> Batch
class GlobalShufflePass : public TreePass {
  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
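
From the Python API, this pass is armed by constructing a source dataset with shuffle=Shuffle.GLOBAL, as the tests at the end of this commit do. A minimal sketch (the file and schema paths are placeholders):

import mindspore.dataset as ds

# Requesting a global shuffle on the source lets GlobalShufflePass rewrite the
# tree into TFReader -> Shuffle -> Batch; no explicit .shuffle() call is needed.
data = ds.TFRecordDataset(["part0.tfrecord", "part1.tfrecord"],
                          schema="schema.json", shuffle=ds.Shuffle.GLOBAL)
data = data.batch(32)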
@@ -0,0 +1,51 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <vector>
#include "dataset/engine/opt/pre/map_column_reorder.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"

namespace mindspore {
namespace dataset {

Status MapColumnReorder::RunOnTree(ExecutionTree *tree, bool *modified) {
  std::vector<std::shared_ptr<MapOp>> to_process;

  // Pass 1: collect every MapOp that carries a column order.
  for (auto &op : *tree) {
    if (auto map_op = std::dynamic_pointer_cast<MapOp>(op.shared_from_this())) {
      if (!map_op->ColumnsOrder().empty()) {
        to_process.push_back(map_op);
      }
    }
  }

  // Pass 2: insert a ProjectOp above each collected MapOp.
  for (auto node : to_process) {
    std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(node->ColumnsOrder());
    std::shared_ptr<ProjectOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
    *modified = true;
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,35 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
#define DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H

#include <memory>
#include "dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {
// MapColumnReorder inserts a ProjectOp when a MapOp requires a full reorder of its output columns.
// Example:
// Input tree:  TFReader -> MapOp(with col_order) -> Batch
// Output tree: TFReader -> MapOp -> ProjectOp(col_order) -> Batch
class MapColumnReorder : public TreePass {
  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
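
On the Python side, passing columns_order to map() is what triggers this pass; the ProjectOp is inserted behind the scenes. A minimal sketch, mirroring the first test below:

import numpy as np
import mindspore.dataset as ds

def gen():
    yield (np.array([0]), np.array([1]))

# columns_order asks for a full column reorder; MapColumnReorder rewrites the
# tree to MapOp -> ProjectOp(["col1", "out"]) so rows come out as (col1, out).
data = ds.GeneratorDataset(gen, ["col0", "col1"])
data = data.map(input_columns="col0", output_columns="out",
                columns_order=["col1", "out"], operations=(lambda x: x))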
@@ -0,0 +1,90 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.dataset as ds


def test_map_reorder_pass_0():
    def generator_mc(maxid=1):
        for _ in range(maxid):
            yield (np.array([0]), np.array([1]))

    # Generator -> Map
    data0 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])

    data0 = data0.map(input_columns="col0", output_columns="out", columns_order=["col1", "out"],
                      operations=(lambda x: x))

    for item in data0.create_tuple_iterator():  # each item is a list of column tensors
        assert item == [np.array(1), np.array(0)]


def test_map_reorder_pass_1():
    def generator_mc(maxid=1):
        for _ in range(maxid):
            yield (np.array([0]), np.array([1]), np.array([2]))

    # Three maps and a zip
    data0 = ds.GeneratorDataset(generator_mc, ["a0", "a1", "a2"])
    data0 = data0.map(input_columns="a0", columns_order=["a2", "a1", "a0"], operations=(lambda x: x))
    data1 = ds.GeneratorDataset(generator_mc, ["b0", "b1", "b2"])
    data1 = data1.map(input_columns="b0", columns_order=["b1", "b2", "b0"], operations=(lambda x: x))
    data2 = ds.zip((data0, data1))
    data2 = data2.map(input_columns="a0", columns_order=["b2", "a2", "b1", "a1", "b0", "a0"], operations=(lambda x: x))

    for item in data2.create_tuple_iterator():
        assert item == [np.array(2), np.array(2), np.array(1), np.array(1), np.array(0), np.array(0)]


def test_global_shuffle_pass():
    FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
    SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"

    # A GLOBAL shuffle should match a FILES shuffle followed by an explicit row shuffle.
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)

    ds.config.set_seed(1)
    DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
    data1 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)

    ds.config.set_seed(1)
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data1 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)


if __name__ == "__main__":
    test_map_reorder_pass_0()
    test_map_reorder_pass_1()
    test_global_shuffle_pass()