|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
from six.moves import reduce
|
|
|
|
|
from .. import core
|
|
|
|
|
from ..layers import utils
|
|
|
|
@ -3457,19 +3458,6 @@ class Flatten(layers.Layer):
|
|
|
|
|
self.stop_axis = stop_axis
|
|
|
|
|
|
|
|
|
|
def forward(self, input):
|
|
|
|
|
out = self._helper.create_variable_for_type_inference(input.dtype)
|
|
|
|
|
x_shape = self._helper.create_variable_for_type_inference(input.dtype)
|
|
|
|
|
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
dy_out, _ = core.ops.flatten_contiguous_range(
|
|
|
|
|
input, 'start_axis', self.start_axis, 'stop_axis',
|
|
|
|
|
self.stop_axis)
|
|
|
|
|
return dy_out
|
|
|
|
|
self._helper.append_op(
|
|
|
|
|
type="flatten_contiguous_range",
|
|
|
|
|
inputs={"X": input},
|
|
|
|
|
outputs={"Out": out,
|
|
|
|
|
"XShape": x_shape},
|
|
|
|
|
attrs={"start_axis": self.start_axis,
|
|
|
|
|
"stop_axis": self.stop_axis})
|
|
|
|
|
out = paddle.tensor.manipulation.flatten(
|
|
|
|
|
input, start_axis=self.start_axis, stop_axis=self.stop_axis)
|
|
|
|
|
return out
|
|
|
|
|