|
|
|
@ -34,6 +34,12 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
typedef enum {
|
|
|
|
|
PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format
|
|
|
|
|
PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi
|
|
|
|
|
PARAM_FORMAT_ITEMS, // the total format items numbers
|
|
|
|
|
} PARAM_FORMAT;
|
|
|
|
|
|
|
|
|
|
class SparsePrefetchRowCpuMatrix;
|
|
|
|
|
|
|
|
|
|
class Parameter;
|
|
|
|
@ -242,14 +248,30 @@ public:
|
|
|
|
|
/// Initialize the value to 0
|
|
|
|
|
void zeroMem();
|
|
|
|
|
|
|
|
|
|
static const int kFormatVersion = 0;
|
|
|
|
|
/// file header structure
|
|
|
|
|
struct Header {
|
|
|
|
|
int32_t version; // = 0, file format version
|
|
|
|
|
int32_t format; // = PARAM_FORMAT
|
|
|
|
|
uint32_t valueSize; // = sizeof(real)
|
|
|
|
|
uint64_t size; // = getSize()
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Is the header supported
|
|
|
|
|
*/
|
|
|
|
|
static bool isHeaderFormatSupported(int32_t fmt) {
|
|
|
|
|
return fmt < PARAM_FORMAT_ITEMS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Get the format in header
|
|
|
|
|
*/
|
|
|
|
|
int getHeaderFormat() { return headerFormat_; }
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Set the format in header
|
|
|
|
|
*/
|
|
|
|
|
void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; }
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Parameter Update Hook.
|
|
|
|
|
*
|
|
|
|
@ -321,6 +343,9 @@ protected:
|
|
|
|
|
bool updated_;
|
|
|
|
|
SparseFormat format_;
|
|
|
|
|
|
|
|
|
|
// The header format for saving or loading param
|
|
|
|
|
int32_t headerFormat_;
|
|
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|