[Dy2Stat] Support list pop (#24250)

* Replace dygraph_to_static_func with @declarative or program_translator.get_func in test_list.py

* Add comments in ConditionalBlock.

* Support popping the last item of a list (see the usage sketch below).

* Support popping the i-th item.

* Support an empty tensor array as the Input of the assign op, and set the kernel type to float in that case.
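
A minimal usage sketch of what this enables (illustrative only: the import path and shapes follow the @declarative style that test_list.py switches to in this PR, and are not verified against the exact test file):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.jit import declarative

    @declarative
    def test_pop(x):
        x = fluid.dygraph.to_variable(x)
        a = []
        for i in range(3):
            a.append(x + i)   # `a` is transformed into a LoDTensorArray
        last = a.pop()        # becomes convert_list_pop(a, None), i.e. pop item -1
        first = a.pop(0)      # becomes convert_list_pop(a, 0)
        return last, first

    with fluid.dygraph.guard():
        last, first = test_pop(np.ones([2, 2]).astype('float32'))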

@@ -1326,7 +1326,7 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
"The Input Variable(%s) of %s Op used to determine kernel data type "
"is empty or not LoDTensor or SelectedRows.",
"is empty or not LoDTensor or SelectedRows or LoDTensorArray.",
name, Type());
return data_type;
}

@@ -476,14 +476,14 @@ TEST(IndicateVarDataTypeTest, other) {
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_other_data_type_test");
BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs());
BuildVar("Other", {"lod_rank_table_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("lod_tensor_array_1");
var->GetMutable<paddle::framework::LoDTensorArray>();
auto* var = scope.Var("lod_rank_table_1");
var->GetMutable<paddle::framework::LoDRankTable>();
bool caught = false;
try {
@@ -491,11 +491,13 @@ TEST(IndicateVarDataTypeTest, other) {
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of "
"indicate_other_data_type_test Op used to "
"determine kernel data type "
"is empty or not LoDTensor or SelectedRows") !=
std::string::npos);
EXPECT_TRUE(
ex_msg.find(
"The Input Variable(Other) of "
"indicate_other_data_type_test Op used to "
"determine kernel data type "
"is empty or not LoDTensor or SelectedRows or LoDTensorArray") !=
std::string::npos);
}
ASSERT_TRUE(caught);
}

@@ -58,6 +58,17 @@ class AssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const framework::Variable *var = ctx.InputVar("X");
if (var->IsType<framework::LoDTensorArray>()) {
auto t_arr = var->Get<framework::LoDTensorArray>();
// NOTE(liym27): Support an empty tensor array as Input,
// and set the kernel type to float in that case.
if (t_arr.size() == 0) {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
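
In this PR the empty-array input arises when slice_tensor_array (added below) takes its null-array branch and tensor_array_pop assigns the result back over the input array. A hedged sketch of the minimal trigger, assuming the fluid 1.x layers API:

    import paddle.fluid as fluid
    from paddle.fluid.layers import assign, create_array

    arr = create_array("float32")   # empty LoDTensorArray, nothing written yet
    out = create_array("float32")
    assign(input=arr, output=out)   # no element to infer the kernel dtype
                                    # from, so it now falls back to FP32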

@@ -47,7 +47,8 @@ class SliceOp : public framework::OperatorWithKernel {
// the output shape is determined by SliceKernel::Compute at runtime.
return;
} else {
// NOTE: A better way is needed to get accurate dims of tensor array.
// NOTE(liym27): A better way is needed to get accurate dims of tensor
// array.
// The resulting dim of GetInputDim("Input") is the dim of the
// last item written into TensorArray "Input". This may be a bug to fix.
ctx->SetOutputDim("Out", ctx->GetInputDim("Input"));

@@ -32,6 +32,9 @@ from .program_translator import *
from . import convert_call_func
from .convert_call_func import *
from . import list_transformer
from .list_transformer import *
__all__ = []
__all__ += ast_transformer.__all__
__all__ += loop_transformer.__all__
@@ -39,3 +42,4 @@ __all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__
__all__ += program_translator.__all__
__all__ += convert_call_func.__all__
__all__ += list_transformer.__all__

@@ -14,10 +14,96 @@
from __future__ import print_function

import astor
import gast

from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform, ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform
from paddle.fluid.framework import core, default_main_program, Variable
from paddle.fluid.layers import array_length, array_read, array_write, create_array
from paddle.fluid.layers import assign, cast, fill_constant, slice
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment

__all__ = ['convert_list_pop']

def create_array_in_parent_block(null_array):
    # TODO(liym27): Create a null tensor_array with the same name in the
    # parent block to avoid a bug in control flow, because in
    # `null_array = create_array("float32")`, `null_array` is not an output
    # of a real OP. See class ConditionalBlock for details.
    prog = default_main_program()
    parent_idx = prog.current_block().parent_idx
    while parent_idx != -1:
        parent_block = prog.block(parent_idx)
        parent_block.create_var(
            name=null_array.name,
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype="float32")
        parent_idx = parent_block.parent_idx

# TODO(liym27): A better way to slice a tensor array is needed.
# Maybe slice op could support start == end.
def slice_tensor_array(array, start, end):
    end = cast(end, "int32")

    def true_fn():
        null_array = create_array("float32")
        create_array_in_parent_block(null_array)
        return null_array

    def false_fn(array, start, end):
        new_array = slice(array, starts=[start], ends=[end], axes=[0])
        return new_array

    # slice op cannot produce an empty array when start == end, so return a
    # null array from true_fn in that case.
    new_array = cond(start == end, true_fn,
                     lambda: false_fn(array, start, end))
    return new_array

def tensor_array_pop(array, idx):
    assert isinstance(idx, int)

    def cond(i, new_array):
        return less_than(i, arr_len)

    def body(i, new_array):
        item = array_read(array=array, i=i)
        array_write(item, array_length(new_array), new_array)
        i = increment(i)
        return i, new_array

    arr_len = array_length(array)
    if idx < 0:
        idx = idx + arr_len
    else:
        idx = fill_constant(shape=[1], dtype="int64", value=idx)

    pop_item = array_read(array, idx)

    # Build new_array = array[:idx] + array[idx+1:]: slice the items before
    # idx, then append the items after idx one by one, and write the result
    # back over the input array.
    new_array = slice_tensor_array(array, 0, idx)
    i = idx + 1
    _, new_array = while_loop(cond, body, [i, new_array])
    assign(input=new_array, output=array)

    return pop_item

def convert_list_pop(target, idx=None):
    """
    Converts target.pop(idx): pops from the tensor array if `target` has been
    transformed into one, and falls back to list.pop otherwise.
    """
    if idx is None:
        idx = -1

    is_variable = isinstance(target, Variable)
    if is_variable:
        is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
    if is_variable and is_tensor_array:
        result = tensor_array_pop(target, idx)
    else:
        result = target.pop(idx)
    return result

class ListTransformer(gast.NodeTransformer):
@@ -45,12 +131,21 @@ class ListTransformer(gast.NodeTransformer):
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        if isinstance(node.func, gast.Attribute):
            func_name = node.func.attr
            if func_name == "pop":
                node = self._replace_list_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node
        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)
        self.generic_visit(node)
        return node

    def visit_If(self, node):
def visit_If(self, node):
@@ -203,3 +298,21 @@ class ListTransformer(gast.NodeTransformer):
                self.list_name_to_updated[target_id] == False:
            del self.list_name_to_updated[target_id]
        return False

    def _replace_list_pop(self, node):
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        if node.args:
            idx_node = node.args[0]
            idx_str = ast_to_source_code(idx_node).strip()
        else:
            idx_str = "None"

        new_call_str = "fluid.dygraph.dygraph_to_static.convert_list_pop({}, {})".format(
            target_str, idx_str)
        new_call_node = gast.parse(new_call_str).body[0].value
        return new_call_node
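
The net effect of _replace_list_pop is a plain source-to-source rewrite. A sketch of the intended transformation (the variable names are illustrative):

    # original user code inside a function being transformed:
    item = a.pop(2)
    # after ListTransformer.visit_Call:
    item = fluid.dygraph.dygraph_to_static.convert_list_pop(a, 2)

    # pop with no argument passes idx_str == "None":
    last = a.pop()
    # after the rewrite:
    last = fluid.dygraph.dygraph_to_static.convert_list_pop(a, None)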

@@ -344,10 +344,10 @@ class ProgramTranslator(object):
prog_trans.enable(False)
x = np.ones([1, 2])
# The declarative is disabled so the func is run in dygraph
with fluid.dygraph.guard():
print(func(x).numpy()) # [[2. 2.]]
"""
check_type(enable_declarative, "enable_declarative", bool,
"ProgramTranslator.enable")
@@ -361,7 +361,7 @@
Args:
dygraph_func (callable): the dygraph function.
*args, **kwargs : the input arguments of dygraph_func.
Returns:
VarBase or tuple of VarBase: the dygraph VarBase containing digital
@@ -763,7 +763,7 @@ class ProgramTranslator(object):
assert abs(index_of_loss) < len(outputs), \
"index_of_loss: {} shall not exceed the length of outputs: {}.".format(
index_of_loss, len(outputs))
loss_var = outputs[index_of_loss]
check_type(loss_var, "loss_var", framework.Variable,

@@ -2001,6 +2001,9 @@ class ConditionalBlock(object):
intermediate = set()
params = set()
# NOTE: This assumes that all variables are inputs or outputs of the ops in
# inside_block, but some variables are created without appending a real op.
# For example, in `arr = create_array(dtype)`, `arr` is not an output of any op.
for each_op in inside_block.ops:
assert isinstance(each_op, Operator)
for iname in each_op.input_names:
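
A hedged sketch of the pattern this NOTE (and create_array_in_parent_block in list_transformer.py above) guards against; true_fn is illustrative:

    def true_fn():
        # `create_array` only declares a LOD_TENSOR_ARRAY variable; it
        # appends no op, so the scan over inside_block.ops above never sees
        # `arr` as an input or output, and ConditionalBlock would miss it.
        arr = create_array("float32")
        return arr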
