You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/dygraph/dygraph_to_static/error.py

153 lines
5.5 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import traceback
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
# Attribute name under which an ErrorData instance is stored on an exception
# object; it is the key passed to the setattr() calls in this module.
ERROR_DATA = "Error data about original source code information and traceback."
def attach_error_data(error, in_runtime=False):
    """
    Attaches an ErrorData record (original source information and traceback) to an error.

    The exception details are taken from `sys.exc_info()`, i.e. from the
    exception currently being handled, not from `error` itself.

    Args:
        error(Exception): A native error that should carry the error data.
        in_runtime(bool): `error` is raised in runtime if in_runtime is True, otherwise in compile time.

    Returns:
        The same `error` object, with the error data stored under the `ERROR_DATA` attribute.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()

    # The first extracted frame is skipped.
    # NOTE(review): presumably it is the framework's own wrapper frame - confirm.
    frames = traceback.extract_tb(exc_tb)[1:]

    data = ErrorData(exc_type, exc_value, frames, global_origin_info_map)
    data.in_runtime = in_runtime
    setattr(error, ERROR_DATA, data)

    # Temporary transformed files are cleaned up once the data is captured.
    remove_static_file()
    return error
def remove_static_file():
    """
    Removes temporary files created during the transformation of dygraph to static graph.

    For every key in `global_origin_info_map`, the first element of the key is
    treated as the path of a transformed source file; that file and the
    sibling ".pyc" file (same path with the extension replaced) are deleted.

    A file that is already gone is ignored, so a file removed by someone else
    between collection and deletion does not raise. Any other OSError
    (e.g. permission denied) still propagates.
    """
    # Local import keeps this block self-contained without touching file-level imports.
    import errno

    del_files = set()
    for loc in global_origin_info_map:
        static_filepath = loc[0]
        del_files.add(static_filepath)

        filename, _ = os.path.splitext(static_filepath)
        del_files.add(filename + ".pyc")

    for filepath in del_files:
        # Try the removal directly instead of checking os.path.exists first:
        # the check-then-remove pair is racy and could raise while an error
        # is being reported, hiding the user's original error.
        try:
            os.remove(filepath)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise
class TraceBackFrame(OriginInfo):
    """
    Traceback frame information.

    Holds the data of one traceback entry. Anything else (for example the
    message formatting used when building error messages) comes from the
    OriginInfo base class.
    """

    def __init__(self, location, function_name, source_code):
        """
        Args:
            location: Location of the frame.
            function_name: Name of the function the frame belongs to.
            source_code: Source code text of the frame.
        """
        # The fields are stored directly; OriginInfo.__init__ is not called.
        self.location = location
        self.function_name = function_name
        self.source_code = source_code
class ErrorData(object):
    """
    Error data attached to an exception which is raised in un-transformed code.

    Attributes:
        error_type: Class of the original exception; also used to build new exceptions.
        error_value: The original exception value. May be replaced by a simplified one in runtime mode.
        origin_traceback: Iterable of traceback entries, each unpackable as (filepath, lineno, funcname, code).
        origin_info_map: Mapping queried with `Location(...).line_location` to find the original dygraph code information.
        in_runtime(bool): True if the error is raised in runtime, otherwise in compile time.
    """

    def __init__(self, error_type, error_value, origin_traceback,
                 origin_info_map):
        self.error_type = error_type
        self.error_value = error_value
        self.origin_traceback = origin_traceback
        self.origin_info_map = origin_info_map
        self.in_runtime = False

    def create_exception(self):
        """
        Creates a new exception of the original type carrying the custom message.

        Returns:
            An instance of `error_type` built from `create_message()`, with this
            ErrorData stored on it under the `ERROR_DATA` attribute.
        """
        message = self.create_message()
        new_exception = self.error_type(message)
        setattr(new_exception, ERROR_DATA, self)
        return new_exception

    def create_message(self):
        """
        Creates a custom error message which includes trace stack with source code information of dygraph from user.

        Returns:
            str: In runtime mode, the header followed by the (simplified) error value.
            Otherwise, the header, one formatted entry per traceback frame and
            the final "ErrorType: message" line.
        """
        message_lines = []

        # Step1: Adds header message to prompt users that the following is the original information.
        header_message = "In user code:"
        message_lines.append(header_message)
        message_lines.append("")

        # Simplify error value to improve readability if error is raised in runtime
        if self.in_runtime:
            self._simplify_error_value()
            message_lines.append(str(self.error_value))
            return '\n'.join(message_lines)

        # Step2: Optimizes stack information with source code information of dygraph from user.
        for filepath, lineno, funcname, code in self.origin_traceback:
            loc = Location(filepath, lineno)

            dygraph_func_info = self.origin_info_map.get(loc.line_location,
                                                         None)
            if dygraph_func_info:
                # TODO(liym27): more information to prompt users that this is the original information.
                # Replaces trace stack information about transformed static code with original dygraph code.
                # The value fetched above is reused instead of looking it up a second time.
                traceback_frame = dygraph_func_info
            else:
                traceback_frame = TraceBackFrame(loc, funcname, code)

            message_lines.append(traceback_frame.formated_message())

        # Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
        error_message = " " * 4 + traceback.format_exception_only(
            self.error_type, self.error_value)[0].strip("\n")
        message_lines.append(error_message)

        return '\n'.join(message_lines)

    def _simplify_error_value(self):
        """
        Simplifies error value to improve readability if error is raised in runtime.

        Everything up to and including the line "outputs = static_func(*inputs)"
        (compared after stripping leading spaces) is dropped from the text of
        `error_value`, and `error_value` is rebuilt as `error_type(<remaining text>)`.
        If that marker line is not present (for example because the value was
        already simplified by an earlier call), `error_value` is left unchanged
        instead of raising ValueError.

        NOTE(liym27): The op callstack information about transformed static code has been replaced with original dygraph code.

        TODO(liym27):
            1. Need a more robust way because the code of start_trace may change.
            2. Set the switch to determine whether to simplify error_value
        """
        assert self.in_runtime is True

        error_value_lines = str(self.error_value).split("\n")
        error_value_lines_strip = [mes.lstrip(" ") for mes in error_value_lines]

        start_trace = "outputs = static_func(*inputs)"
        try:
            start_idx = error_value_lines_strip.index(start_trace)
        except ValueError:
            # Marker not found: keep the full error value rather than masking
            # the user's error with an unrelated ValueError.
            return

        error_value_str = '\n'.join(error_value_lines[start_idx + 1:])
        self.error_value = self.error_type(error_value_str)