|
|
|
@ -13,16 +13,11 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import contextlib
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
import six
|
|
|
|
|
import os
|
|
|
|
|
from PIL import Image
|
|
|
|
|
import paddle
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
from paddle.fluid.optimizer import SGDOptimizer
|
|
|
|
|
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, BatchNorm, Embedding, GRUUnit
|
|
|
|
|
from paddle.fluid.dygraph.base import to_variable
|
|
|
|
|
from test_imperative_base import new_program_scope
|
|
|
|
@ -37,13 +32,13 @@ class Config(object):
|
|
|
|
|
# size for word embedding
|
|
|
|
|
word_vector_dim = 128
|
|
|
|
|
# max length for label padding
|
|
|
|
|
max_length = 15
|
|
|
|
|
max_length = 5
|
|
|
|
|
# optimizer setting
|
|
|
|
|
LR = 1.0
|
|
|
|
|
learning_rate_decay = None
|
|
|
|
|
|
|
|
|
|
# batch size to train
|
|
|
|
|
batch_size = 32
|
|
|
|
|
batch_size = 16
|
|
|
|
|
# class number to classify
|
|
|
|
|
num_classes = 481
|
|
|
|
|
|
|
|
|
@ -445,10 +440,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
|
|
|
|
|
(i - 1) * Config.max_length,
|
|
|
|
|
i * Config.max_length,
|
|
|
|
|
dtype='int64').reshape([1, Config.max_length])))
|
|
|
|
|
#if Config.use_gpu:
|
|
|
|
|
# place = fluid.CUDAPlace(0)
|
|
|
|
|
#else:
|
|
|
|
|
# place = fluid.CPUPlace()
|
|
|
|
|
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
fluid.default_startup_program().random_seed = seed
|
|
|
|
|
fluid.default_main_program().random_seed = seed
|
|
|
|
@ -461,10 +453,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
|
|
|
|
|
[50000], [Config.LR, Config.LR * 0.01])
|
|
|
|
|
else:
|
|
|
|
|
learning_rate = Config.LR
|
|
|
|
|
#optimizer = fluid.optimizer.Adadelta(learning_rate=learning_rate,
|
|
|
|
|
# epsilon=1.0e-6, rho=0.9)
|
|
|
|
|
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
|
|
|
|
|
# place = fluid.CPUPlace()
|
|
|
|
|
dy_param_init_value = {}
|
|
|
|
|
for param in ocr_attention.parameters():
|
|
|
|
|
dy_param_init_value[param.name] = param.numpy()
|
|
|
|
|