enable header format

revert-3824-remove_grad_op_type
tensor-tang 8 years ago
parent c307ee303b
commit 0945dc1b99

@ -48,7 +48,8 @@ Parameter::Parameter(const ParameterConfig& config, bool useGpu, bool doInit)
deviceId_(-1),
sharedCount_(0),
updateCounter_(0),
updated_(false) {
updated_(false),
headerFormat_(PARAM_FORMAT_ORIGINAL) {
setID(-1); /* capture uninitialized id */
if (useGpu_ && FLAGS_parallel_nn) {
/* gpu environment is specified by device property */
@ -285,7 +286,7 @@ bool Parameter::save(const std::string& filename) const {
bool Parameter::save(std::ostream& s) const {
CpuVector vec(*bufs_[PARAMETER_VALUE].get());
Header header;
header.version = kFormatVersion;
header.format = headerFormat_;
header.valueSize = sizeof(real);
header.size = getSize();
@ -344,8 +345,9 @@ bool Parameter::load(std::istream& s) {
Header header;
CHECK(s.read(reinterpret_cast<char*>(&header), sizeof(header)))
<< "Fail to read parameter " << getName();
CHECK_EQ(header.version, kFormatVersion) << "Incorrect format version: "
<< header.version;
CHECK(isHeaderFormatSupported(header.format)) << "Incorrect format version: "
<< header.format;
headerFormat_ = header.format;
CHECK_EQ(header.size, getSize())
<< "The size (" << header.size << ") in the file does not match the size "
<< "(" << getSize() << ") of the parameter: " << getName();

@ -34,6 +34,12 @@ limitations under the License. */
namespace paddle {
typedef enum {
PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format
PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi
PARAM_FORMAT_ITEMS, // the total format items numbers
} PARAM_FORMAT;
class SparsePrefetchRowCpuMatrix;
class Parameter;
@ -242,14 +248,30 @@ public:
/// Initialize the value to 0
void zeroMem();
static const int kFormatVersion = 0;
/// file header structure
struct Header {
int32_t version; // = 0, file format version
int32_t format; // = PARAM_FORMAT
uint32_t valueSize; // = sizeof(real)
uint64_t size; // = getSize()
};
/**
* @brief Is the header supported
*/
static bool isHeaderFormatSupported(int32_t fmt) {
return fmt < PARAM_FORMAT_ITEMS;
}
/**
* @brief Get the format in header
*/
int getHeaderFormat() { return headerFormat_; }
/**
* @brief Set the format in header
*/
void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; }
/**
* @brief Parameter Update Hook.
*
@ -321,6 +343,9 @@ protected:
bool updated_;
SparseFormat format_;
// The header format for saving or loading param
int32_t headerFormat_;
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
public:

@ -1032,8 +1032,8 @@ void ParameterServer2::loadValueVector(const LoadValueRequest& request,
Parameter::Header header;
CHECK(fs.read(reinterpret_cast<char*>(&header), sizeof(header)))
<< "Fail to read parameters in pserver";
CHECK_EQ(header.version, Parameter::kFormatVersion)
<< "Incorrect format version: " << header.version;
CHECK(Parameter::isHeaderFormatSupported(header.format))
<< "Incorrect format version: " << header.format;
CHECK_EQ(header.size, (size_t)size_)
<< "The size (" << header.size << ") in the file does not match the size "
<< "(" << size_ << ") of the pserver: " << serverId_;
@ -1063,7 +1063,8 @@ void ParameterServer2::saveValueVector(const SaveValueRequest& request,
CpuVector& vec = vectors_[PARAMETER_APPLY] ? *vectors_[PARAMETER_APPLY]
: *vectors_[PARAMETER_VALUE];
Parameter::Header header;
header.version = Parameter::kFormatVersion;
// TODO(TJ): save param headerFormat_
header.format = PARAM_FORMAT_ORIGINAL;
header.valueSize = sizeof(real);
header.size = size_;

Loading…
Cancel
Save