|
|
|
@ -17,6 +17,8 @@ from __future__ import print_function
|
|
|
|
|
import copy
|
|
|
|
|
import itertools
|
|
|
|
|
import six
|
|
|
|
|
import sys
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating
|
|
|
|
|
from . import unique_name
|
|
|
|
@ -46,23 +48,43 @@ class LayerHelper(object):
|
|
|
|
|
def startup_program(self):
|
|
|
|
|
return default_startup_program()
|
|
|
|
|
|
|
|
|
|
def _np_to_variable(self, x):
|
|
|
|
|
tensor = core.LoDTensor()
|
|
|
|
|
sys.stderr.write('%s %s\n' % (tensor, x))
|
|
|
|
|
tensor.set(x, core.CPUPlace())
|
|
|
|
|
return Variable(
|
|
|
|
|
self.main_program.current_block(),
|
|
|
|
|
type=core.VarDesc.VarType.LOD_TENSOR,
|
|
|
|
|
name=None,
|
|
|
|
|
shape=x.shape,
|
|
|
|
|
dtype=x.dtype)
|
|
|
|
|
|
|
|
|
|
def to_variable(self, x):
|
|
|
|
|
if isinstance(x, Variable):
|
|
|
|
|
return x
|
|
|
|
|
elif isinstance(x, np.ndarray):
|
|
|
|
|
return self._np_to_variable(x)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("inputs wrong type %s\n" % x)
|
|
|
|
|
|
|
|
|
|
def to_variables(self, inputs):
|
|
|
|
|
if isinstance(inputs, list) or isinstance(inputs, tuple):
|
|
|
|
|
return [self._to_variable(x) for x in inputs]
|
|
|
|
|
else:
|
|
|
|
|
return [self._to_variable(inputs)]
|
|
|
|
|
|
|
|
|
|
def append_op(self, *args, **kwargs):
|
|
|
|
|
return self.main_program.current_block().append_op(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def multiple_input(self, input_param_name='input'):
|
|
|
|
|
inputs = self.kwargs.get(input_param_name, [])
|
|
|
|
|
type_error = TypeError(
|
|
|
|
|
"Input of {0} layer should be Variable or sequence of Variable".
|
|
|
|
|
format(self.layer_type))
|
|
|
|
|
if isinstance(inputs, Variable):
|
|
|
|
|
inputs = [inputs]
|
|
|
|
|
elif not isinstance(inputs, list) and not isinstance(inputs, tuple):
|
|
|
|
|
raise type_error
|
|
|
|
|
ret = []
|
|
|
|
|
if isinstance(inputs, list) or isinstance(inputs, tuple):
|
|
|
|
|
for inp in inputs:
|
|
|
|
|
ret.append(self.to_variable(inp))
|
|
|
|
|
else:
|
|
|
|
|
for each in inputs:
|
|
|
|
|
if not isinstance(each, Variable):
|
|
|
|
|
raise type_error
|
|
|
|
|
return inputs
|
|
|
|
|
ret.append(self.to_variable(inputs))
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
def input(self, input_param_name='input'):
|
|
|
|
|
inputs = self.multiple_input(input_param_name)
|
|
|
|
|