!182 Tuning mindrecord writer performance

Merge pull request !182 from jonathan_yan/write_a

commit c73301bad9

@@ -0,0 +1,46 @@
# MindRecord generating guidelines

<!-- TOC -->

- [MindRecord generating guidelines](#mindrecord-generating-guidelines)
    - [Create work space](#create-work-space)
    - [Implement data generator](#implement-data-generator)
    - [Run data generator](#run-data-generator)

<!-- /TOC -->

## Create work space

Assume the dataset name is 'xyz'.

* Create the work space from the template

    ```shell
    cd ${your_mindspore_home}/example/convert_to_mindrecord
    cp -r template xyz
    ```

## Implement data generator

Edit the dictionary data generator.

* Edit the file

    ```shell
    cd ${your_mindspore_home}/example/convert_to_mindrecord
    vi xyz/mr_api.py
    ```

Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented:

- 'mindrecord_task_number()' returns the number of tasks. Return 1 if the data rows are generated serially. Return N if the generator can be split into N parallel tasks.
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' ranges from 0 to N-1, where N is the return value of mindrecord_task_number().

Tips for splitting the generator into parallel tasks (see the sketch below):

- For ImageNet, one directory can be a task.
- For TFRecord with multiple files, each file can be a task.
- For TFRecord with a single file, it can still be split into N tasks: task_id=K picks a data row only if (count % N == K).
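
For the single-file case, a minimal sketch of 'xyz/mr_api.py' is shown below. It only illustrates the count % N == K split; the input file name, the 'line' schema field, and the task count of 4 are illustrative assumptions, and the frozen argument-parsing block from the template is omitted for brevity.

```python
# Hypothetical xyz/mr_api.py sketch: split one input file into N parallel
# tasks using the "count % N == K" rule. File name, schema field and task
# count are assumptions for illustration only.
NUM_TASKS = 4
INPUT_FILE = "/tmp/xyz/data.txt"   # assumed input, one sample per line

mindrecord_schema = {"line": {"type": "string"}}


def mindrecord_task_number():
    """Return the number of parallel tasks (N)."""
    return NUM_TASKS


def mindrecord_dict_data(task_id):
    """Yield the rows of task 'task_id': row 'count' belongs to task 'count % N'."""
    with open(INPUT_FILE) as handle:
        for count, line in enumerate(handle):
            if count % NUM_TASKS == task_id:
                yield {"line": line.strip()}
```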

## Run data generator

* Run the python script

    ```shell
    cd ${your_mindspore_home}/example/convert_to_mindrecord
    python writer.py --mindrecord_script imagenet [...]
    ```
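
After writer.py finishes, the generated files can be spot-checked by reading them back with MindSpore's dataset API. The snippet below is a sketch, not part of this PR: the '/tmp/imagenet/mr/m0' partition name comes from the ImageNet run script further down, and the 'dataset_file'/'columns_list' parameter names follow the MindSpore releases of this period and may differ in newer versions.

```python
# Read-back sanity check (sketch). File path and column names are
# assumptions based on the imagenet example in this PR.
import mindspore.dataset as ds

data_set = ds.MindDataset(dataset_file="/tmp/imagenet/mr/m0",
                          columns_list=["file_name", "label"])
print("record count:", data_set.get_dataset_size())
for item in data_set.create_dict_iterator():
    print(item["file_name"], item["label"])  # inspect the first record only
    break
```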

@@ -0,0 +1,122 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two APIs must be implemented:
  1. mindrecord_task_number()
       # Return the number of parallel tasks; return 1 if there is no parallelism.
  2. mindrecord_dict_data(task_id)
       # Yield data for one task.
       # task_id ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
"""
import argparse
import os
import pickle

######## mindrecord_schema begin ##########
mindrecord_schema = {"label": {"type": "int64"},
                     "data": {"type": "bytes"},
                     "file_name": {"type": "string"}}
######## mindrecord_schema end ##########

######## Frozen code begin ##########
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:
    ARG_LIST = pickle.load(mindrecord_argument_file_handle)
######## Frozen code end ##########

parser = argparse.ArgumentParser(description='Mind record imagenet example')
parser.add_argument('--label_file', type=str, default="", help='label file')
parser.add_argument('--image_dir', type=str, default="", help='images directory')

######## Frozen code begin ##########
args = parser.parse_args(ARG_LIST)
print(args)
######## Frozen code end ##########


def _user_defined_private_func():
    """
    Internal function that builds the task list.

    Return:
       dir_list (list): directory labels, one per task
       dir_paths (dict): label -> image directory path
    """
    if not os.path.exists(args.label_file):
        raise IOError("map file {} does not exist".format(args.label_file))

    label_dict = {}
    with open(args.label_file) as file_handle:
        line = file_handle.readline()
        while line:
            # each line maps a directory name to a numeric label
            labels = line.split(" ")
            label_dict[labels[1]] = labels[0]
            line = file_handle.readline()
    # get all the directories, e.g. n02087046, n02094114, n02109525
    dir_paths = {}
    for item in label_dict:
        real_path = os.path.join(args.image_dir, label_dict[item])
        if not os.path.isdir(real_path):
            print("directory {} does not exist".format(real_path))
            continue
        dir_paths[item] = real_path

    if not dir_paths:
        print("no valid image directory in {}".format(args.image_dir))
        return [], {}

    dir_list = list(dir_paths.keys())
    return dir_list, dir_paths


dir_list_global, dir_paths_global = _user_defined_private_func()

def mindrecord_task_number():
    """
    Get task size.

    Return:
       number of tasks (one per image directory)
    """
    return len(dir_list_global)


def mindrecord_dict_data(task_id):
    """
    Get data dict.

    Yields:
        data (dict): data row which is dict.
    """

    # get the filename, label and image binary as a dict
    label = dir_list_global[task_id]
    for item in os.listdir(dir_paths_global[label]):
        file_name = os.path.join(dir_paths_global[label], item)
        if not item.endswith(("JPEG", "jpg", "jpeg")):
            print("{} does not have a JPEG/jpg suffix, skipping it.".format(file_name))
            continue
        data = {}
        data["file_name"] = str(file_name)
        data["label"] = int(label)

        # read the raw image bytes
        with open(file_name, "rb") as image_file:
            data["data"] = image_file.read()
        yield data

@@ -0,0 +1,8 @@
#!/bin/bash
rm /tmp/imagenet/mr/*

python writer.py --mindrecord_script imagenet \
--mindrecord_file "/tmp/imagenet/mr/m" \
--mindrecord_partitions 16 \
--label_file "/tmp/imagenet/label.txt" \
--image_dir "/tmp/imagenet/jpeg"

@@ -0,0 +1,6 @@
#!/bin/bash
rm /tmp/template/*

python writer.py --mindrecord_script template \
--mindrecord_file "/tmp/template/m" \
--mindrecord_partitions 4

@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two APIs must be implemented:
  1. mindrecord_task_number()
       # Return the number of parallel tasks; return 1 if there is no parallelism.
  2. mindrecord_dict_data(task_id)
       # Yield data for one task.
       # task_id ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
"""
import argparse
import pickle

# ## Parse arguments

with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:  # Do NOT change this line
    ARG_LIST = pickle.load(mindrecord_argument_file_handle)                # Do NOT change this line
parser = argparse.ArgumentParser(description='Mind record api template')   # Do NOT change this line

# ## Your arguments below
# parser.add_argument(...)

args = parser.parse_args(ARG_LIST)  # Do NOT change this line
print(args)                         # Do NOT change this line


# ## Default mindrecord settings. Leave them commented out unless a default value has to be changed.
# mindrecord_index_fields = ['label']
# mindrecord_header_size = 1 << 24
# mindrecord_page_size = 1 << 25


# define global vars here if necessary


# ####### Your code below ##########
mindrecord_schema = {"label": {"type": "int32"}}

def mindrecord_task_number():
    """
    Get task size.

    Return:
       number of tasks
    """
    return 1


def mindrecord_dict_data(task_id):
    """
    Get data dict.

    Yields:
        data (dict): data row which is dict.
    """
    print("task is {}".format(task_id))
    for i in range(256):
        data = {}
        data['label'] = i
        yield data

@@ -0,0 +1,149 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by data dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import os
import pickle
import time
from importlib import import_module
from multiprocessing import Pool

from mindspore.mindrecord import FileWriter


def _exec_task(task_id, parallel_writer=True):
    """
    Execute the task with the specified task id.

    'writer' and 'mindrecord_dict_data' are module-level globals set in
    __main__; worker processes inherit them when multiprocessing forks
    (the default on Linux).
    """
    print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
    imagenet_iter = mindrecord_dict_data(task_id)
    batch_size = 2048
    transform_count = 0
    while True:
        data_list = []
        try:
            # accumulate a batch of rows before each write_raw_data call
            for _ in range(batch_size):
                data_list.append(next(imagenet_iter))
                transform_count += 1
            writer.write_raw_data(data_list, parallel_writer=parallel_writer)
            print("transformed {} record...".format(transform_count))
        except StopIteration:
            # flush the last, possibly partial, batch
            if data_list:
                writer.write_raw_data(data_list, parallel_writer=parallel_writer)
                print("transformed {} record...".format(transform_count))
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Mind record writer')
    parser.add_argument('--mindrecord_script', type=str, default="template",
                        help='path where script is saved')

    parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
                        help='written file name prefix')

    parser.add_argument('--mindrecord_partitions', type=int, default=1,
                        help='number of written files')

    parser.add_argument('--mindrecord_workers', type=int, default=8,
                        help='number of parallel workers')

    # unknown arguments are forwarded to the user-defined mr_api.py
    args, other_args = parser.parse_known_args()

    print(args)
    print(other_args)

    with open('mr_argument.pickle', 'wb') as file_handle:
        pickle.dump(other_args, file_handle)

    try:
        mr_api = import_module(args.mindrecord_script + '.mr_api')
    except ModuleNotFoundError:
        raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))

    num_tasks = mr_api.mindrecord_task_number()

    print("Write mindrecord ...")

    mindrecord_dict_data = mr_api.mindrecord_dict_data

    # create the file writer with the requested number of partitions
    writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

    start_time = time.time()

    # set the header size
    try:
        header_size = mr_api.mindrecord_header_size
        writer.set_header_size(header_size)
    except AttributeError:
        print("Default header size: {}".format(1 << 24))

    # set the page size
    try:
        page_size = mr_api.mindrecord_page_size
        writer.set_page_size(page_size)
    except AttributeError:
        print("Default page size: {}".format(1 << 25))

    # get schema
    try:
        mindrecord_schema = mr_api.mindrecord_schema
    except AttributeError:
        raise RuntimeError("mindrecord_schema is not defined in mr_api.py.")

    # create the schema
    writer.add_schema(mindrecord_schema, "mindrecord_schema")

    # add the index
    try:
        index_fields = mr_api.mindrecord_index_fields
        writer.add_index(index_fields)
    except AttributeError:
        print("Default index fields: all simple fields are indexes.")

    writer.open_and_set_header()

    task_list = list(range(num_tasks))

    # set number of workers
    num_workers = args.mindrecord_workers

    if num_tasks < 1:
        num_tasks = 1

    if num_workers > num_tasks:
        num_workers = num_tasks

    if num_tasks > 1:
        with Pool(num_workers) as p:
            p.map(_exec_task, task_list)
    else:
        _exec_task(0, False)

    ret = writer.commit()

    os.remove("{}".format("mr_argument.pickle"))

    end_time = time.time()
    print("--------------------------------------------")
    print("END. Total time: {}".format(end_time - start_time))
    print("--------------------------------------------")
				
			||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								
					Loading…
					
					
				
		Reference in new issue