Add load/save method for Parameter

avx_docs
Yu Yang 8 years ago
parent 9601c2fcda
commit 8b833d5a8a

@ -551,6 +551,10 @@ public:
ParameterConfig* getConfig();
void setValueUpdated();
bool save(const std::string& filename) const;
bool load(const std::string& filename) const;
private:
static Parameter* createFromRawPtr(void* ptr);
static Parameter* createFromSharedPtr(void* ptr);

@ -70,3 +70,11 @@ ParameterConfig* Parameter::getConfig() {
size_t Parameter::getID() const { return m->getPtr()->getID(); }
void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }
bool Parameter::save(const std::string& filename) const {
return m->getPtr()->save(filename);
}
bool Parameter::load(const std::string& filename) const {
return m->getPtr()->load(filename);
}

@ -0,0 +1,6 @@
___fc_layer_0__.w0
___fc_layer_0__.wbias
_hidden1.w0
_hidden1.wbias
_hidden2.w0
_hidden2.wbias

@ -45,6 +45,7 @@ class TestGradientMachine(unittest.TestCase):
assert isinstance(val, swig_paddle.Vector)
arr = numpy.full((len(val), ), 0.1, dtype="float32")
val.copyFromNumpyArray(arr)
self.assertTrue(param.save(param.getName()))
param_config = param.getConfig().toProto()
assert isinstance(param_config,
paddle.proto.ParameterConfig_pb2.ParameterConfig)
@ -92,6 +93,9 @@ class TestGradientMachine(unittest.TestCase):
self.assertTrue(self.isCalled)
for param in machine.getParameters():
self.assertTrue(param.load(param.getName()))
def test_train_one_pass(self):
conf_file_path = './testTrainConfig.py'
trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile(

Loading…
Cancel
Save