fix pyfunc pickle issue

pull/13061/head
xiefangqi 4 years ago
parent c9ce0d371a
commit 41f3e02e87

@@ -121,9 +121,6 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
Status PyFuncOp::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  auto package = pybind11::module::import("pickle");
  auto module = package.attr("dumps");
  args["tensor_op_params"] = module(py_func_ptr_, 0).cast<std::string>();
  args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
  args["is_python_front_end_op"] = true;
  *out_json = args;
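
Before this change, to_json embedded the transform's pickled bytes in the op's json entry by calling Python's pickle module through pybind11 (pickle.dumps with protocol 0, cast to a string). A minimal Python-side sketch of that round trip; SampleTransform and the dict layout are illustrative only, not the exact on-disk format:

import pickle

class SampleTransform:
    """Illustrative stand-in for a Python front-end transform object."""
    def __init__(self, scale=2):
        self.scale = scale

    def __call__(self, x):
        return x * self.scale

# Serialize: protocol 0 yields ASCII-compatible bytes, mirroring dumps(py_func_ptr_, 0) above.
op_entry = {
    "tensor_op_name": SampleTransform.__name__,
    "tensor_op_params": pickle.dumps(SampleTransform(scale=3), 0).decode(),
    "is_python_front_end_op": True,
}

# Deserialize: this is the step the commit drops, since pickling callables
# fails for some Python functions (the issue this commit works around).
restored = pickle.loads(op_entry["tensor_op_params"].encode())
assert restored(2) == 6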

@@ -17,7 +17,6 @@ Functions to support dataset serialize and deserialize.
"""
import json
import os
import pickle
import sys
import mindspore.common.dtype as mstype
@@ -30,6 +29,9 @@ def serialize(dataset, json_filepath=""):
"""
Serialize dataset pipeline into a json file.
Currently some python objects are not supported to be serialized.
For python function serialization of map operator, de.serialize will only return its function name.
Args:
dataset (Dataset): the starting node.
json_filepath (str): a filepath where a serialized json file will be generated.
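
For context, a hedged usage sketch of de.serialize on a pipeline that maps a Python function; the dataset path and the lambda are placeholders, and per the docstring above only the function's name ends up in the json:

import mindspore.dataset as ds

# Placeholder source path; any pipeline with a Python-function map would do.
data = ds.ImageFolderDataset("/path/to/images")
data = data.map(operations=[lambda x: x], input_columns=["image"])

# Writes the pipeline graph to pipeline.json; the Python callable itself is
# no longer pickled, only its name is recorded.
ds.serialize(data, json_filepath="pipeline.json")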
@@ -56,6 +58,8 @@ def deserialize(input_dict=None, json_filepath=None):
"""
Construct a de pipeline from a json file produced by de.serialize().
Currently python function deserialization of map operator are not supported.
Args:
input_dict (dict): a Python dictionary containing a serialized dataset graph
json_filepath (str): a path to the json file.
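
Correspondingly, deserializing a graph that contains a Python-function op is expected to fail; a short sketch of the guarded call, reusing the file name from the sketch above:

import mindspore.dataset as ds

try:
    data = ds.deserialize(json_filepath="pipeline.json")
except NotImplementedError as e:
    # Raised when the json contains a py_transform op entry.
    print(e)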
@@ -349,8 +353,8 @@ def construct_tensor_ops(operations):
op_params = op.get('tensor_op_params')
if op.get('is_python_front_end_op'): # check if it's a py_transform op
result.append(pickle.loads(op_params.encode()))
else:
raise NotImplementedError("python function is not yet supported by de.deserialize().")
if op_name == "HwcToChw": op_name = "HWC2CHW"
if op_name == "UniformAug": op_name = "UniformAugment"
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
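
The fields checked above come from each op's entry in the serialized graph. A simplified sketch of how the py-transform branch is detected, using the field names from this diff (build_op is a hypothetical stand-in for one iteration of construct_tensor_ops):

def build_op(op):
    """Hypothetical stand-in for one iteration of construct_tensor_ops."""
    op_name = op.get('tensor_op_name')
    if op.get('is_python_front_end_op'):
        # Python callables are no longer reconstructed from pickled bytes.
        raise NotImplementedError("python function is not yet supported by de.deserialize().")
    # C/C++ ops are still rebuilt from their recorded name and params.
    return op_name, op.get('tensor_op_params')

entry = {"tensor_op_name": "ToTensor", "is_python_front_end_op": True}
try:
    build_op(entry)
except NotImplementedError as e:
    assert "not yet supported" in str(e)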

@@ -375,7 +375,13 @@ def test_serdes_pyvision(remove_json_files=True):
        py_vision.ToTensor()
    ]
    data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
    # Python function deserialization via pickle is currently not supported, so we
    # disable this test case and treat it as an exception test case instead.
    try:
        util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
        assert False
    except NotImplementedError as e:
        assert "python function is not yet supported" in str(e)
def test_serdes_uniform_augment(remove_json_files=True):
