fix pyfunc pickle issue

pull/13061/head
xiefangqi 4 years ago
parent c9ce0d371a
commit 41f3e02e87

@ -121,9 +121,6 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
Status PyFuncOp::to_json(nlohmann::json *out_json) { Status PyFuncOp::to_json(nlohmann::json *out_json) {
nlohmann::json args; nlohmann::json args;
auto package = pybind11::module::import("pickle");
auto module = package.attr("dumps");
args["tensor_op_params"] = module(py_func_ptr_, 0).cast<std::string>();
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>(); args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
args["is_python_front_end_op"] = true; args["is_python_front_end_op"] = true;
*out_json = args; *out_json = args;

@ -17,7 +17,6 @@ Functions to support dataset serialize and deserialize.
""" """
import json import json
import os import os
import pickle
import sys import sys
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@ -30,6 +29,9 @@ def serialize(dataset, json_filepath=""):
""" """
Serialize dataset pipeline into a json file. Serialize dataset pipeline into a json file.
Currently, some Python objects cannot be serialized.
For a Python function used in a map operator, de.serialize will only return its function name.
Args: Args:
dataset (Dataset): the starting node. dataset (Dataset): the starting node.
json_filepath (str): a filepath where a serialized json file will be generated. json_filepath (str): a filepath where a serialized json file will be generated.
@ -56,6 +58,8 @@ def deserialize(input_dict=None, json_filepath=None):
""" """
Construct a de pipeline from a json file produced by de.serialize(). Construct a de pipeline from a json file produced by de.serialize().
Currently, deserialization of Python functions used in a map operator is not supported.
Args: Args:
input_dict (dict): a Python dictionary containing a serialized dataset graph input_dict (dict): a Python dictionary containing a serialized dataset graph
json_filepath (str): a path to the json file. json_filepath (str): a path to the json file.
@ -349,42 +353,42 @@ def construct_tensor_ops(operations):
op_params = op.get('tensor_op_params') op_params = op.get('tensor_op_params')
if op.get('is_python_front_end_op'): # check if it's a py_transform op if op.get('is_python_front_end_op'): # check if it's a py_transform op
result.append(pickle.loads(op_params.encode())) raise NotImplementedError("python function is not yet supported by de.deserialize().")
if op_name == "HwcToChw": op_name = "HWC2CHW"
if op_name == "UniformAug": op_name = "UniformAugment"
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
if hasattr(op_module_vis, op_name):
op_class = getattr(op_module_vis, op_name, None)
elif hasattr(op_module_trans, op_name[:-2]):
op_name = op_name[:-2] # to remove op from the back of the name
op_class = getattr(op_module_trans, op_name, None)
else:
raise RuntimeError(op_name + " is not yet supported by deserialize().")
if op_params is None: # If no parameter is specified, call it directly
result.append(op_class())
else: else:
if op_name == "HwcToChw": op_name = "HWC2CHW" # Input parameter type cast
if op_name == "UniformAug": op_name = "UniformAugment" for key, val in op_params.items():
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"] if key in ['center', 'fill_value']:
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"] op_params[key] = tuple(val)
elif key in ['interpolation', 'resample']:
if hasattr(op_module_vis, op_name): op_params[key] = Inter(to_interpolation_mode(val))
op_class = getattr(op_module_vis, op_name, None) elif key in ['padding_mode']:
elif hasattr(op_module_trans, op_name[:-2]): op_params[key] = Border(to_border_mode(val))
op_name = op_name[:-2] # to remove op from the back of the name elif key in ['data_type']:
op_class = getattr(op_module_trans, op_name, None) op_params[key] = to_mstype(val)
else: elif key in ['image_batch_format']:
raise RuntimeError(op_name + " is not yet supported by deserialize().") op_params[key] = to_image_batch_format(val)
elif key in ['policy']:
if op_params is None: # If no parameter is specified, call it directly op_params[key] = to_policy(val)
result.append(op_class()) elif key in ['transform', 'transforms']:
else: op_params[key] = construct_tensor_ops(val)
# Input parameter type cast
for key, val in op_params.items(): result.append(op_class(**op_params))
if key in ['center', 'fill_value']:
op_params[key] = tuple(val)
elif key in ['interpolation', 'resample']:
op_params[key] = Inter(to_interpolation_mode(val))
elif key in ['padding_mode']:
op_params[key] = Border(to_border_mode(val))
elif key in ['data_type']:
op_params[key] = to_mstype(val)
elif key in ['image_batch_format']:
op_params[key] = to_image_batch_format(val)
elif key in ['policy']:
op_params[key] = to_policy(val)
elif key in ['transform', 'transforms']:
op_params[key] = construct_tensor_ops(val)
result.append(op_class(**op_params))
return result return result

@ -375,7 +375,13 @@ def test_serdes_pyvision(remove_json_files=True):
py_vision.ToTensor() py_vision.ToTensor()
] ]
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"]) data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files) # Currently python function deserialization fails because of pickle, so we treat this testcase
# as an exception testcase.
try:
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
assert False
except NotImplementedError as e:
assert "python function is not yet supported" in str(e)
def test_serdes_uniform_augment(remove_json_files=True): def test_serdes_uniform_augment(remove_json_files=True):

Loading…
Cancel
Save