scaffolding for the new Fluid API

trainerSaveLoadParams
Helin Wang 7 years ago
parent 738585476d
commit a2ffbd5326

@ -20,6 +20,16 @@ from framework import *
import executor import executor
from executor import * from executor import *
import trainer
from trainer import Trainer
from trainer import Event
import inferencer
from inferencer import Inferencer
import params
from params import Params
import io import io
import evaluator import evaluator
import initializer import initializer
@ -47,7 +57,8 @@ from parallel_executor import ParallelExecutor
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
trainer.__all__ + inferencer.__all__ + params.__all__ + [
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',

@ -0,0 +1,28 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'Inferencer',
]
class Inferencer(object):
def __init__(self, network_func, params, place=None):
self.network_func = network_func
self.params = params
self.place = place
def infer(self, inputs):
pass

@ -0,0 +1,33 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core
__all__ = [
'Params',
]
class Params(object):
def __init__(self, path=None):
self.scope = core.Scope()
if path:
self._load(path)
def _load(self, path):
pass
def save(self, path):
pass

@ -0,0 +1,46 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
__all__ = [
'Event',
'Trainer',
]
class Event(Enum):
BEGIN_EPOCH = 0
END_EPOCH = 1
BEGIN_STEP = 2
END_STEP = 3
def __init__(self):
self.step = 0
self.epoch = 0
self.type = Event.BEGIN_EPOCH
class Trainer(object):
def __init__(self, network_func, optimizer, params=None, place=None):
self.network_func = network_func
self.optimizer = optimizer
self.params = params
self.place = place
def train(self, reader, num_epochs, event_handler):
pass
def test(self, reader):
pass
Loading…
Cancel
Save