[Dy2Static] Add convert_ifelse to run the transformed code dynamically (#24866)

* cast var in convert_logical_XX. 

* Add convert_ifelse function in convert_operators.py  

* Add logical_transformer. Remove LogicalTransformer from loop_transformer.py 

* Revert modified tests in PR24799(convert_while_stmt). 

* Comment and modify code that doesn't support `return` statement. 

* Remove unnecessary class: MergeAssignTransformer, NodeTestTransformer and IfConditionVisitor in ifelse_transformer.
revert-24981-add_device_attr_for_regulization
liym27 6 years ago committed by GitHub
parent ef9b36873d
commit a9dca5805a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,6 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicAp
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
@ -75,6 +76,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform break/continue in loops
BreakContinueTransformer(node_wrapper).transform()
# Transform logical and/or/not
LogicalTransformer(node_wrapper).transform()
# Transform for loop and while loop
LoopTransformer(node_wrapper).transform()

@ -13,8 +13,9 @@
# limitations under the License.
from paddle.fluid.framework import Variable
from paddle.fluid.layers import control_flow, logical_and, logical_or, logical_not
from paddle.fluid.layers import control_flow, logical_and, logical_or, logical_not, cast
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.data_feeder import convert_dtype
def convert_while_loop(cond, body, loop_vars):
@ -30,6 +31,8 @@ def convert_while_loop(cond, body, loop_vars):
A list or tuple of variables which returned by ``body`` .
"""
# NOTE: It may be slower if cond is very expensive, but usually cond is just O(1).
# If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
pred = cond(*loop_vars)
if isinstance(pred, Variable):
loop_vars = _run_paddle_while_loop(cond, body, loop_vars)
@ -73,6 +76,8 @@ def convert_logical_and(x, y):
def _run_paddle_logical_and(x, y):
x = cast_bool_if_necessary(x)
y = cast_bool_if_necessary(y)
return logical_and(x, y)
@ -104,6 +109,8 @@ def convert_logical_or(x, y):
def _run_paddle_logical_or(x, y):
x = cast_bool_if_necessary(x)
y = cast_bool_if_necessary(y)
return logical_or(x, y)
@ -131,8 +138,45 @@ def convert_logical_not(x):
def _run_paddle_logical_not(x):
x = cast_bool_if_necessary(x)
return logical_not(x)
def _run_py_logical_not(x):
return not x
def convert_ifelse(pred, true_fn, false_fn):
"""
A function representation of a Python ``if/else`` statement.
Args:
pred(bool|Variable): A boolean variable which determines whether to return the result of ``true_fn`` or ``false_fn`` .
true_fn(callable): A callable to be performed if ``pred`` is true.
false_fn(callable): A callable to be performed if ``pred`` is false.
Returns:
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
"""
if isinstance(pred, Variable):
return _run_paddle_cond(pred, true_fn, false_fn)
else:
return _run_py_ifelse(pred, true_fn, false_fn)
def _run_paddle_cond(pred, true_fn, false_fn):
pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, true_fn, false_fn)
def _run_py_ifelse(pred, true_fn, false_fn):
return true_fn() if pred else false_fn()
def cast_bool_if_necessary(var):
assert isinstance(var, Variable)
if convert_dtype(var.dtype) not in ['bool']:
var = cast(var, dtype="bool")
return var

@ -0,0 +1,74 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
class LogicalTransformer(gast.NodeTransformer):
"""
Transform python boolean op into Paddle logical op
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self):
return self.visit(self.root)
def visit_UnaryOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand)
new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_not({})".format(
arg)
# NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node
return node
def visit_BoolOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.And):
new_node = self._create_bool_op_node(node.values, 'and')
elif isinstance(node.op, gast.Or):
new_node = self._create_bool_op_node(node.values, 'or')
else:
raise TypeError(
"Only supports and/or syntax in control flow if statement.")
return new_node
def _create_bool_op_node(self, nodes, api_type):
assert len(
nodes
) > 1, "The length of BoolOp should be at least 2, but received {}.".format(
len(nodes))
if len(nodes) > 2:
# Creates logic_and/logic_or node recursively.
pre_logic_node = self._create_bool_op_node(nodes[:2], api_type)
if len(nodes[2:]) == 1:
post_logic_node = nodes[2]
else:
post_logic_node = self._create_bool_op_node(nodes[2:], api_type)
nodes = [pre_logic_node] + [post_logic_node]
args = [ast_to_source_code(child) for child in nodes]
new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_{}(x={}, y={})".format(
api_type, args[0], args[1])
# NOTE: gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node

@ -70,61 +70,6 @@ def create_while_node(condition_name, body_name, loop_var_names):
return assign_node
class LogicalOpTransformer(gast.NodeTransformer):
"""
Transform python boolean op into Paddle logical op
"""
def __init__(self, node):
self.root = node
def transform(self):
return self.visit(self.root)
def visit_UnaryOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand)
new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_not({})".format(
arg)
# gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node
return node
def visit_BoolOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.And):
new_node = self._create_bool_op_node(node.values, 'and')
elif isinstance(node.op, gast.Or):
new_node = self._create_bool_op_node(node.values, 'or')
else:
raise TypeError(
"Only supports and/or syntax in control flow if statement.")
return new_node
def _create_bool_op_node(self, nodes, api_type):
assert len(
nodes
) > 1, "The length of BoolOp should be at least 2, but received {}.".format(
len(nodes))
if len(nodes) > 2:
# Creates logic_and/logic_or node recursively.
pre_logic_node = self._create_bool_op_node(nodes[:2], api_type)
if len(nodes[2:]) == 1:
post_logic_node = nodes[2]
else:
post_logic_node = self._create_bool_op_node(nodes[2:], api_type)
nodes = [pre_logic_node] + [post_logic_node]
args = [ast_to_source_code(child) for child in nodes]
new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_{}(x={}, y={})".format(
api_type, args[0], args[1])
# gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node
class NameVisitor(gast.NodeVisitor):
'''
Analysis name liveness for loop transformer
@ -560,9 +505,6 @@ class LoopTransformer(gast.NodeTransformer):
if "." not in name:
new_stmts.append(create_static_variable_gast_node(name))
logical_op_transformer = LogicalOpTransformer(node.test)
cond_value_node = logical_op_transformer.transform()
condition_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_CONDITION_PREFIX),
args=gast.arguments(
@ -579,10 +521,11 @@ class LoopTransformer(gast.NodeTransformer):
kw_defaults=None,
kwarg=None,
defaults=[]),
body=[gast.Return(value=cond_value_node)],
body=[gast.Return(value=node.test)],
decorator_list=[],
returns=None,
type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(condition_func_node)

@ -14,7 +14,6 @@
from __future__ import print_function
import six
import paddle.fluid as fluid
@ -303,14 +302,6 @@ def if_tensor_case(x):
# It is equivalent to `if mean != 0`
if mean:
for i in range(0, 10):
# TODO(liym27): Delete it if the type of parameter `i` can be resolved in "if" stmt
if six.PY2:
i = fluid.layers.fill_constant(
shape=[1], value=i, dtype="int32")
else:
i = fluid.layers.fill_constant(
shape=[1], value=i, dtype="int64")
if i > 5:
x += 1
break

@ -90,15 +90,14 @@ def test_break_in_while(x):
def test_break_continue_in_for(x):
x = fluid.dygraph.to_variable(x)
# TODO(liym27): Uncomment code after "if" statement can be transformed correctly.
# for i in range(1, 10, 1):
# if i <= 4:
# x += 1
# continue
# else:
# x += 10010
# break
# x += 10086
for i in range(1, 10, 1):
if i <= 4:
x += 1
continue
else:
x += 10010
break
x += 10086
a = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
for i in range(1, 10, 1):
@ -117,16 +116,15 @@ def test_break_continue_in_for(x):
def test_for_in_else(x):
x = fluid.dygraph.to_variable(x)
# TODO(liym27): Uncomment code after "if" statement can be transformed correctly.
# # Case 1:
# if False:
# pass
# else:
# for i in range(0, 10):
# if i > 5:
# x += 1
# break
# x += i
# Case 1:
if False:
pass
else:
for i in range(0, 10):
if i > 5:
x += 1
break
x += i
# Case 2:
if False:

@ -18,10 +18,9 @@ import unittest
import textwrap
import gast
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import get_name_ids
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfConditionVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
class TestGetNameIds(unittest.TestCase):
@ -125,12 +124,7 @@ class TestIsControlFlowIf(unittest.TestCase):
node = gast.parse(code)
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertFalse(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(new_node == node_test)
self.assertTrue(len(assign_nodes) == 0)
self.assertFalse(is_control_flow_to_transform(node_test))
def test_expr(self):
# node is not ast.Compare
@ -140,12 +134,7 @@ class TestIsControlFlowIf(unittest.TestCase):
# x is a Tensor.
node = gast.parse("a + x.numpy()")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(len(assign_nodes) == 0)
self.assertTrue(is_control_flow_to_transform(node_test))
def test_is_None(self):
self.check_false_case("x is None")
@ -160,47 +149,25 @@ class TestIsControlFlowIf(unittest.TestCase):
node = gast.parse("fluid.layers.sum(x) and 2>1")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# Transformation result:
# bool_tensor_0 = fluid.layers.cast(x=fluid.layers.sum(x), dtype='bool')
# bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
# logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=bool_tensor_1)
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(len(assign_nodes) == 3)
self.assertTrue(is_control_flow_to_transform(node_test))
def test_if(self):
node = gast.parse("x.numpy()[1] > 1")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(len(assign_nodes) == 0)
self.assertTrue(is_control_flow_to_transform(node_test))
def test_if_with_and(self):
node = gast.parse("x and 1 < x.numpy()[1]")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(isinstance(new_node, gast.Name))
self.assertTrue(len(assign_nodes) == 2)
self.assertTrue(is_control_flow_to_transform(node_test))
def test_if_with_or(self):
node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(isinstance(new_node, gast.Name))
self.assertTrue(len(assign_nodes) == 2)
self.assertTrue(is_control_flow_to_transform(node_test))
def test_shape(self):
code = """
@ -214,12 +181,9 @@ class TestIsControlFlowIf(unittest.TestCase):
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(new_node == test_node)
self.assertTrue(len(assign_nodes) == 0)
self.assertTrue(
is_control_flow_to_transform(test_node, static_analysis_visitor))
def test_shape_with_andOr(self):
code = """
@ -233,18 +197,9 @@ class TestIsControlFlowIf(unittest.TestCase):
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
self.assertTrue(if_visitor.is_control_flow())
new_node, assign_nodes = if_visitor.transform()
# transformation result:
# bool_tensor_0 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(x is not None))
# logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=batch_size[0] > 16)
# bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
# logic_or_0 = fluid.layers.logical_or(x=logic_and_0, y=bool_tensor_1)
self.assertTrue(isinstance(new_node, gast.Name))
self.assertTrue(len(assign_nodes) == 4)
self.assertTrue(
is_control_flow_to_transform(test_node, static_analysis_visitor))
def test_paddle_api(self):
code = """
@ -257,13 +212,9 @@ class TestIsControlFlowIf(unittest.TestCase):
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(new_node == test_node)
self.assertTrue(len(assign_nodes) == 0)
self.assertTrue(
is_control_flow_to_transform(test_node, static_analysis_visitor))
def test_paddle_api_with_andOr(self):
code_or = """
@ -284,43 +235,34 @@ class TestIsControlFlowIf(unittest.TestCase):
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
self.assertTrue(if_visitor.is_control_flow())
new_node, assign_nodes = if_visitor.transform()
# Transformation result:
# bool_tensor_0 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
# bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(x is not None))
# logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=fluid.layers.shape(x)[0] > 16)
# logic_and_1 = fluid.layers.logical_and(x=logic_and_0, y=bool_tensor_1) for code_and
# logic_or_0= fluid.layers.logical_or(x=logic_and_0, y=bool_tensor_1) for code_and
self.assertTrue(isinstance(new_node, gast.Name))
self.assertTrue(len(assign_nodes) == 4)
self.assertTrue(
is_control_flow_to_transform(test_node,
static_analysis_visitor))
def test_with_node_var_type_map(self):
node = gast.parse("x > 1")
node_test = node.body[0].value
# if x is a Tensor
node_var_type_map = {"x": {NodeVarType.TENSOR}}
visitor = IsControlFlowVisitor(
node_test, node_var_type_map=node_var_type_map)
self.assertTrue(visitor.transform())
var_name_to_type = {"x": {NodeVarType.TENSOR}}
self.assertTrue(
is_control_flow_to_transform(
node_test, var_name_to_type=var_name_to_type))
# if x is not a Tensor
node_var_type_map = {"x": {NodeVarType.NUMPY_NDARRAY}}
visitor = IsControlFlowVisitor(
node_test, node_var_type_map=node_var_type_map)
self.assertFalse(visitor.transform())
var_name_to_type = {"x": {NodeVarType.NUMPY_NDARRAY}}
self.assertFalse(
is_control_flow_to_transform(
node_test, var_name_to_type=var_name_to_type))
def test_raise_error(self):
node = "a + b"
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, IfConditionVisitor(node))
self.assertRaises(TypeError, is_control_flow_to_transform(node))
self.assertTrue(
"Type of input node should be gast.AST" in str(e.exception))
"The type of input node must be gast.AST" in str(e.exception))
if __name__ == '__main__':

@ -65,6 +65,20 @@ def call_lambda_with_ifExpr(x):
return out
def call_lambda_with_ifExpr2(x):
x = fluid.dygraph.to_variable(x)
add_func = lambda x: x + 1
y = fluid.layers.mean(x)
# NOTE: y is Variable, but z<2 is python bool value
z = 0
out = add_func(y) if y or z < 2 else (lambda x: x**2)(y)
return out
class TestLambda(unittest.TestCase):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
@ -76,7 +90,7 @@ class TestLambda(unittest.TestCase):
def init_func(self):
self.dyfuncs = [
call_lambda_as_func, call_lambda_directly, call_lambda_in_func,
call_lambda_with_ifExpr
call_lambda_with_ifExpr, call_lambda_with_ifExpr2
]
def run_static(self, func):

@ -157,8 +157,6 @@ def test_list_pop_in_while_loop(x, iter_num):
a = []
i = 0
# TODO(liym27): Delete it if the type of parameter `i` can be resolved in "if" stmt
i = fluid.layers.fill_constant(shape=[1], value=i, dtype="int32")
while i < iter_num:
a.append(x + i)
i += 1

@ -106,9 +106,13 @@ class MNIST(fluid.dygraph.Layer):
acc = fluid.layers.accuracy(input=x, label=label)
loss = fluid.layers.cross_entropy(x, label)
avg_loss = fluid.layers.mean(loss)
return x, acc, avg_loss
else:
return x
# TODO: Uncomment code after "return" statement can be transformed correctly.
# return x, acc, avg_loss
# else:
# return x
return x, acc, avg_loss
def inference(self, inputs):
x = self._simple_img_conv_pool_1(inputs)

@ -61,6 +61,7 @@ def get_source_code(func):
class StaticCode1():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
def true_fn_0(x_v):
x_v = x_v - 1
@ -70,35 +71,56 @@ class StaticCode1():
x_v = x_v + 1
return x_v
x_v = fluid.layers.cond(
x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(x_v)
)
if label is not None:
def true_fn_1(label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
return loss
return
def false_fn_1():
return
fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
label is not None,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_1)(label, x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_1)())
return x_v
class StaticCode2():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
def true_fn_1(x_v):
def true_fn_2(x_v):
x_v = x_v - 1
return x_v
def false_fn_1(x_v):
def false_fn_2(x_v):
x_v = x_v + 1
return x_v
x_v = fluid.layers.cond(
x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_1)(x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_1)(x_v)
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_2)(x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_2)(x_v)
)
if label is not None:
def true_fn_3(label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
return loss
return
def false_fn_3():
return
fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
label is not None,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_3)(label, x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_3)())
return x_v

@ -18,7 +18,6 @@ import unittest
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from test_program_translator import get_source_code

@ -314,19 +314,21 @@ class YOLOv3(fluid.dygraph.Layer):
scores, perm=[0, 2, 1]))
self.downsample //= 2
if not self.is_train:
# get pred
yolo_boxes = fluid.layers.concat(self.boxes, axis=1)
yolo_scores = fluid.layers.concat(self.scores, axis=2)
pred = fluid.layers.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=cfg.valid_thresh,
nms_top_k=cfg.nms_topk,
keep_top_k=cfg.nms_posk,
nms_threshold=cfg.nms_thresh,
background_label=-1)
return pred
else:
return sum(self.losses)
# TODO(liym27): Uncomment code after "return" statement can be transformed correctly.
# if not self.is_train:
# # get pred
# yolo_boxes = fluid.layers.concat(self.boxes, axis=1)
# yolo_scores = fluid.layers.concat(self.scores, axis=2)
#
# pred = fluid.layers.multiclass_nms(
# bboxes=yolo_boxes,
# scores=yolo_scores,
# score_threshold=cfg.valid_thresh,
# nms_top_k=cfg.nms_topk,
# keep_top_k=cfg.nms_posk,
# nms_threshold=cfg.nms_thresh,
# background_label=-1)
# return pred
# else:
# return sum(self.losses)
return sum(self.losses)

Loading…
Cancel
Save