|
|
@ -12,8 +12,8 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
from py_paddle import swig_paddle, DataProviderWrapperConverter
|
|
|
|
from py_paddle import swig_paddle, DataProviderConverter
|
|
|
|
from paddle.trainer.PyDataProviderWrapper import DenseSlot
|
|
|
|
from paddle.trainer.PyDataProvider2 import dense_vector
|
|
|
|
from paddle.trainer.config_parser import parse_config
|
|
|
|
from paddle.trainer.config_parser import parse_config
|
|
|
|
|
|
|
|
|
|
|
|
TEST_DATA = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
|
TEST_DATA = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
@ -89,12 +89,12 @@ TEST_DATA = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
def main():
|
|
|
|
conf = parse_config("./mnist_model/trainer_config.conf.norm", "")
|
|
|
|
conf = parse_config("./mnist_model/trainer_config.py", "")
|
|
|
|
print conf.data_config.load_data_args
|
|
|
|
print conf.data_config.load_data_args
|
|
|
|
network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
|
|
|
|
network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
|
|
|
|
assert isinstance(network, swig_paddle.GradientMachine) # For code hint.
|
|
|
|
assert isinstance(network, swig_paddle.GradientMachine) # For code hint.
|
|
|
|
network.loadParameters("./mnist_model/")
|
|
|
|
network.loadParameters("./mnist_model/")
|
|
|
|
converter = DataProviderWrapperConverter(False, [DenseSlot(784)])
|
|
|
|
converter = DataProviderConverter([dense_vector(784)])
|
|
|
|
inArg = converter(TEST_DATA)
|
|
|
|
inArg = converter(TEST_DATA)
|
|
|
|
print network.forwardTest(inArg)
|
|
|
|
print network.forwardTest(inArg)
|
|
|
|
|
|
|
|
|
|
|
|