@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import sys
 import numpy as np
 
@@ -21,33 +22,46 @@ from paddle.fluid import framework
 __all__ = ['PyLayer']
 
 
+@contextlib.contextmanager
+def trace_scope(scope, block):
+    tmp_scope = framework._imperative_tracer().scope
+    tmp_block = framework._imperative_tracer().block
+    framework._imperative_tracer().scope = scope
+    framework._imperative_tracer().block = block
+    yield
+    framework._imperative_tracer().scope = tmp_scope
+    framework._imperative_tracer().block = tmp_block
+
+
 class PyLayer(core.Layer):
     def __init__(self):
         self._scope = core.Scope()
         self._block = framework.default_main_program().current_block()
 
     def __call__(self, inputs):
-        if not isinstance(inputs, list) and not isinstance(inputs, tuple):
-            inputs = [inputs]
-
-        var_inputs = []
-        for x in inputs:
-            if isinstance(x, np.ndarray):
-                tensor = core.LoDTensor()
-                tensor.set(x, core.CPUPlace())
-                x = framework.Variable(
-                    framework.default_main_program().current_block(),
-                    type=core.VarDesc.VarType.LOD_TENSOR,
-                    name=None,
-                    shape=x.shape,
-                    dtype=x.dtype)
-            elif not isinstance(x, framework.Variable):
-                raise ValueError("not var or ndarray %s" % type(x))
-            self._scope.var(x.name)
-            var_inputs.append(x)
-        outputs = self.forward(var_inputs)
-        for out in outputs:
-            self._scope.var(out.name)
-        return outputs
+        with trace_scope(self._scope, self._block.desc):
+            if not isinstance(inputs, list) and not isinstance(inputs, tuple):
+                inputs = [inputs]
+
+            var_inputs = []
+            for x in inputs:
+                if isinstance(x, np.ndarray):
+                    py_var = framework.Variable(
+                        self._block,
+                        type=core.VarDesc.VarType.LOD_TENSOR,
+                        name=None,
+                        shape=x.shape,
+                        dtype=x.dtype)
+                    var = self._scope.var(py_var.name)
+                    tensor = var.get_tensor()
+                    tensor.set(x, core.CPUPlace())
+                    var_inputs.append(py_var)
+                elif isinstance(x, framework.Variable):
+                    var_inputs.append(x)
+                else:
+                    raise ValueError("not var or ndarray %s" % type(x))
+            outputs = self.forward(var_inputs)
+            return outputs
 
     def forward(self, inputs):
         print("at python.")
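
Usage sketch (illustrative, not part of the patch above): a minimal example of
driving the new __call__ path, assuming this file is importable as
paddle.fluid.imperative.layers and that an imperative tracer has already been
installed so framework._imperative_tracer() is valid; MyPyLayer and the 2x2
input below are hypothetical.

    import numpy as np
    from paddle.fluid.imperative.layers import PyLayer

    class MyPyLayer(PyLayer):
        def forward(self, inputs):
            # inputs arrive as framework.Variable objects; a real layer
            # would trace ops on them and return output Variables.
            return inputs

    layer = MyPyLayer()
    # A raw ndarray is wrapped into a Variable and copied into the layer's
    # scope before forward() runs inside trace_scope().
    outs = layer(np.ones((2, 2), dtype=np.float32))

Design note: trace_scope restores the tracer's previous scope and block only
after the yield returns, so an exception raised inside forward() would leave
the tracer pointing at this layer's scope; a try/finally around the yield
would make the restore exception-safe.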