!182 Tuning mindrecord writer performance
Merge pull request !182 from jonathan_yan/write_a
commit c73301bad9

@@ -0,0 +1,46 @@
# MindRecord generating guidelines

<!-- TOC -->

- [MindRecord generating guidelines](#mindrecord-generating-guidelines)
    - [Create work space](#create-work-space)
    - [Implement data generator](#implement-data-generator)
    - [Run data generator](#run-data-generator)

<!-- /TOC -->

## Create work space

Assume the dataset name is 'xyz'.

* Create a work space from the template

```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
cp -r template xyz
```

## Implement data generator

Edit the dictionary data generator.

* Edit the file

```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
vi xyz/mr_api.py
```
Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented (a minimal sketch follows the list):

- 'mindrecord_task_number()' returns the number of tasks. Return 1 if data rows are generated serially; return N if the generator can be split into N tasks that run in parallel.
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' ranges from 0 to N-1, where N is the return value of 'mindrecord_task_number()'.
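For reference, here is a minimal serial sketch of the two APIs. It mirrors the 'template/mr_api.py' file added later in this pull request, but the frozen argument-parsing boilerplate is omitted, so treat it as an outline rather than a drop-in file:

```python
# xyz/mr_api.py -- minimal serial sketch (argument-parsing boilerplate omitted)
mindrecord_schema = {"label": {"type": "int32"}}  # schema consumed by writer.py


def mindrecord_task_number():
    """Return the number of parallel tasks; 1 means rows are generated serially."""
    return 1


def mindrecord_dict_data(task_id):
    """Yield one dict per data row; keys must match mindrecord_schema."""
    print("task is {}".format(task_id))
    for i in range(256):
        yield {"label": i}
```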
Hints for splitting the generator into parallel tasks:

- For ImageNet, each class directory can be one task.
- For a TFRecord dataset with multiple files, each file can be one task.
- For a TFRecord dataset with only one file, it can still be split into N tasks: task_id = K means a data row is picked only if (count % N == K), as sketched after this list.
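As a hedged illustration of that modulo split, the sketch below divides a single-file dataset into N round-robin tasks. 'read_single_file_rows' is a hypothetical placeholder for the real row iterator and is not part of the code in this pull request:

```python
# Hypothetical sketch: split one input file into N round-robin tasks.
N = 8


def read_single_file_rows():
    # placeholder for the real single-file iterator (e.g. a TFRecord reader)
    for i in range(1000):
        yield {"label": i}


def mindrecord_task_number():
    return N


def mindrecord_dict_data(task_id):
    for count, row in enumerate(read_single_file_rows()):
        if count % N == task_id:  # task K keeps only rows where count % N == K
            yield row
```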
## Run data generator

* Run the Python script

```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
python writer.py --mindrecord_script imagenet [...]
```
@@ -0,0 +1,122 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two APIs must be implemented:
1. mindrecord_task_number()
   # Return the number of parallel tasks; return 1 if there is no parallelism.
2. mindrecord_dict_data(task_id)
   # Yield data for one task.
   # task_id is 0..N-1, where N is the return value of mindrecord_task_number().
"""
import argparse
import os
import pickle

######## mindrecord_schema begin ##########
mindrecord_schema = {"label": {"type": "int64"},
                     "data": {"type": "bytes"},
                     "file_name": {"type": "string"}}
######## mindrecord_schema end ##########

######## Frozen code begin ##########
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:
    ARG_LIST = pickle.load(mindrecord_argument_file_handle)
######## Frozen code end ##########

parser = argparse.ArgumentParser(description='Mind record imagenet example')
parser.add_argument('--label_file', type=str, default="", help='label file')
parser.add_argument('--image_dir', type=str, default="", help='images directory')

######## Frozen code begin ##########
args = parser.parse_args(ARG_LIST)
print(args)
######## Frozen code end ##########


def _user_defined_private_func():
    """
    Internal function that builds the task list.

    Return:
        list of labels (one task per label) and a dict mapping each label to its image directory
    """
    if not os.path.exists(args.label_file):
        raise IOError("map file {} does not exist".format(args.label_file))

    # parse the label file: each line maps a directory name to a label id
    label_dict = {}
    with open(args.label_file) as file_handle:
        line = file_handle.readline()
        while line:
            labels = line.split(" ")
            label_dict[labels[1]] = labels[0]
            line = file_handle.readline()
    # get all the dirs, e.g. n02087046, n02094114, n02109525
    dir_paths = {}
    for item in label_dict:
        real_path = os.path.join(args.image_dir, label_dict[item])
        if not os.path.isdir(real_path):
            print("{} dir does not exist".format(real_path))
            continue
        dir_paths[item] = real_path

    if not dir_paths:
        print("no valid image dir in {}".format(args.image_dir))
        return {}, {}

    dir_list = []
    for label in dir_paths:
        dir_list.append(label)
    return dir_list, dir_paths


dir_list_global, dir_paths_global = _user_defined_private_func()


def mindrecord_task_number():
    """
    Get task size.

    Return:
        number of tasks
    """
    return len(dir_list_global)


def mindrecord_dict_data(task_id):
    """
    Get data dict.

    Yields:
        data (dict): data row which is a dict.
    """
    # get the file name, label and image binary as a dict
    label = dir_list_global[task_id]
    for item in os.listdir(dir_paths_global[label]):
        file_name = os.path.join(dir_paths_global[label], item)
        if not item.endswith("JPEG") and not item.endswith(
                "jpg") and not item.endswith("jpeg"):
            print("{} file does not end with JPEG/jpg/jpeg, skip it.".format(file_name))
            continue
        data = {}
        data["file_name"] = str(file_name)
        data["label"] = int(label)

        # get the image data
        image_file = open(file_name, "rb")
        image_bytes = image_file.read()
        image_file.close()
        data["data"] = image_bytes
        yield data
@@ -0,0 +1,8 @@

#!/bin/bash
rm /tmp/imagenet/mr/*

python writer.py --mindrecord_script imagenet \
--mindrecord_file "/tmp/imagenet/mr/m" \
--mindrecord_partitions 16 \
--label_file "/tmp/imagenet/label.txt" \
--image_dir "/tmp/imagenet/jpeg"
@@ -0,0 +1,6 @@

#!/bin/bash
rm /tmp/template/*

python writer.py --mindrecord_script template \
--mindrecord_file "/tmp/template/m" \
--mindrecord_partitions 4
@@ -0,0 +1,73 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two APIs must be implemented:
1. mindrecord_task_number()
   # Return the number of parallel tasks; return 1 if there is no parallelism.
2. mindrecord_dict_data(task_id)
   # Yield data for one task.
   # task_id is 0..N-1, where N is the return value of mindrecord_task_number().
"""
import argparse
import pickle

# ## Parse arguments

with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:  # Do NOT change this line
    ARG_LIST = pickle.load(mindrecord_argument_file_handle)  # Do NOT change this line
parser = argparse.ArgumentParser(description='Mind record api template')  # Do NOT change this line

# ## Your arguments below
# parser.add_argument(...)

args = parser.parse_args(ARG_LIST)  # Do NOT change this line
print(args)  # Do NOT change this line


# ## Default mindrecord vars. Keep them commented out unless the default value has to be changed.
# mindrecord_index_fields = ['label']
# mindrecord_header_size = 1 << 24
# mindrecord_page_size = 1 << 25


# define global vars here if necessary


# ####### Your code below ##########
mindrecord_schema = {"label": {"type": "int32"}}


def mindrecord_task_number():
    """
    Get task size.

    Return:
        number of tasks
    """
    return 1


def mindrecord_dict_data(task_id):
    """
    Get data dict.

    Yields:
        data (dict): data row which is a dict.
    """
    print("task is {}".format(task_id))
    for i in range(256):
        data = {}
        data['label'] = i
        yield data
@@ -0,0 +1,149 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by data dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import os
import pickle
import time
from importlib import import_module
from multiprocessing import Pool

from mindspore.mindrecord import FileWriter


def _exec_task(task_id, parallel_writer=True):
    """
    Execute task with specified task id
    """
    print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
    imagenet_iter = mindrecord_dict_data(task_id)
    batch_size = 2048
    transform_count = 0
    while True:
        data_list = []
        try:
            # collect up to batch_size rows, then write them in one call
            for _ in range(batch_size):
                data_list.append(imagenet_iter.__next__())
                transform_count += 1
            writer.write_raw_data(data_list, parallel_writer=parallel_writer)
            print("transformed {} record...".format(transform_count))
        except StopIteration:
            # flush the final partial batch
            if data_list:
                writer.write_raw_data(data_list, parallel_writer=parallel_writer)
                print("transformed {} record...".format(transform_count))
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Mind record writer')
    parser.add_argument('--mindrecord_script', type=str, default="template",
                        help='path where script is saved')

    parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
                        help='written file name prefix')

    parser.add_argument('--mindrecord_partitions', type=int, default=1,
                        help='number of written files')

    parser.add_argument('--mindrecord_workers', type=int, default=8,
                        help='number of parallel workers')

    # known arguments are consumed here; unknown ones are forwarded to mr_api.py
    args, other_args = parser.parse_known_args()

    print(args)
    print(other_args)

    # hand the user-defined arguments over to mr_api.py through a pickle file
    with open('mr_argument.pickle', 'wb') as file_handle:
        pickle.dump(other_args, file_handle)

    try:
        mr_api = import_module(args.mindrecord_script + '.mr_api')
    except ModuleNotFoundError:
        raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))

    num_tasks = mr_api.mindrecord_task_number()

    print("Write mindrecord ...")

    mindrecord_dict_data = mr_api.mindrecord_dict_data

    # get number of files
    writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

    start_time = time.time()

    # set the header size
    try:
        header_size = mr_api.mindrecord_header_size
        writer.set_header_size(header_size)
    except AttributeError:
        print("Default header size: {}".format(1 << 24))

    # set the page size
    try:
        page_size = mr_api.mindrecord_page_size
        writer.set_page_size(page_size)
    except AttributeError:
        print("Default page size: {}".format(1 << 25))

    # get schema
    try:
        mindrecord_schema = mr_api.mindrecord_schema
    except AttributeError:
        raise RuntimeError("mindrecord_schema is not defined in mr_api.py.")

    # create the schema
    writer.add_schema(mindrecord_schema, "mindrecord_schema")

    # add the index
    try:
        index_fields = mr_api.mindrecord_index_fields
        writer.add_index(index_fields)
    except AttributeError:
        print("Default index fields: all simple fields are indexes.")

    writer.open_and_set_header()

    task_list = list(range(num_tasks))

    # set number of workers
    num_workers = args.mindrecord_workers

    if num_tasks < 1:
        num_tasks = 1

    if num_workers > num_tasks:
        num_workers = num_tasks

    if num_tasks > 1:
        with Pool(num_workers) as p:
            p.map(_exec_task, task_list)
    else:
        _exec_task(0, False)

    ret = writer.commit()

    os.remove("{}".format("mr_argument.pickle"))

    end_time = time.time()
    print("--------------------------------------------")
    print("END. Total time: {}".format(end_time - start_time))
    print("--------------------------------------------")
File diff suppressed because it is too large