|
|
|
@ -35,9 +35,17 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
typedef enum {
|
|
|
|
|
PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format
|
|
|
|
|
PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi
|
|
|
|
|
PARAM_FORMAT_ITEMS, // the total format items numbers
|
|
|
|
|
/// The paddle original basic format
|
|
|
|
|
PARAM_FORMAT_ORIGINAL = 0,
|
|
|
|
|
|
|
|
|
|
/// See mkldnn_memory_format_t in
|
|
|
|
|
/// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h
|
|
|
|
|
/// for a detailed description.
|
|
|
|
|
/// 2D weights tensor in the format (output channels, input channels).
|
|
|
|
|
PARAM_FORMAT_MKLDNN_OI,
|
|
|
|
|
|
|
|
|
|
/// The total format items numbers
|
|
|
|
|
PARAM_FORMAT_ITEMS,
|
|
|
|
|
} PARAM_FORMAT;
|
|
|
|
|
|
|
|
|
|
class SparsePrefetchRowCpuMatrix;
|
|
|
|
@ -256,19 +264,19 @@ public:
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Is the header supported
|
|
|
|
|
* @brief Is the header format supported.
|
|
|
|
|
*/
|
|
|
|
|
static bool isHeaderFormatSupported(int32_t fmt) {
|
|
|
|
|
return fmt < PARAM_FORMAT_ITEMS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Get the format in header
|
|
|
|
|
* @brief Get the format in header.
|
|
|
|
|
*/
|
|
|
|
|
int getHeaderFormat() { return headerFormat_; }
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Set the format in header
|
|
|
|
|
* @brief Set the format in header.
|
|
|
|
|
*/
|
|
|
|
|
void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; }
|
|
|
|
|
|
|
|
|
@ -343,7 +351,7 @@ protected:
|
|
|
|
|
bool updated_;
|
|
|
|
|
SparseFormat format_;
|
|
|
|
|
|
|
|
|
|
// The header format for saving or loading param
|
|
|
|
|
/// The header format for saving or loading param
|
|
|
|
|
int32_t headerFormat_;
|
|
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
|
|
|
|
|