Merge branch 'feature/fix_param_hidden_in_pydp2' into feature/mnist_train_api

avx_docs
Yu Yang 8 years ago
commit ad93b8f964

@ -17,7 +17,7 @@ import random
from paddle.trainer.PyDataProvider2 import *
@provider(input_types=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
@provider(slots=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
def test_dense_no_seq(setting, filename):
for i in xrange(200):
yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)]

@ -232,7 +232,7 @@ def provider(input_types=None,
check=False,
check_fail_continue=False,
init_hook=None,
**kwargs):
**outter_kwargs):
"""
Provider decorator. Use it to make a function into PyDataProvider2 object.
In this function, user only need to get each sample for some train/test
@ -318,10 +318,10 @@ def provider(input_types=None,
self.logger = logging.getLogger("")
self.logger.setLevel(logging.INFO)
self.input_types = None
if 'slots' in kwargs:
if 'slots' in outter_kwargs:
self.logger.warning('setting slots value is deprecated, '
'please use input_types instead.')
self.slots = kwargs['slots']
self.slots = outter_kwargs['slots']
self.slots = input_types
self.should_shuffle = should_shuffle

Loading…
Cancel
Save