refine comments

revert-3824-remove_grad_op_type
tensor-tang 8 years ago
parent 635b8672b2
commit 07d16e3e13

@ -345,10 +345,10 @@ void MKLDNNTester::run(const TestConfig& dnn,
return;
}
// After run some iters, the mkldnn weight has been stored in dnnLayer
// and we can also get the mkldnn weight paramter header format
// Weight param should always be index 0 (and bias index 1).
// TODO(TJ): should also considerate mean and var format when batchnorm ready
// After run some iterations, the mkldnn weight has been stored in dnnLayer
// and we can also get the mkldnn weight parameter header format.
// Weight parameter should always be index 0 (and bias index 1).
// TODO(TJ): should also consider mean and var format when batchnorm ready
int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat();
int refWgtFmt = parameters_[REF][0]->getHeaderFormat();
if (dnnWgtFmt == refWgtFmt) {

@ -35,9 +35,17 @@ limitations under the License. */
namespace paddle {
typedef enum {
PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format
PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi
PARAM_FORMAT_ITEMS, // the total format items numbers
/// The paddle original basic format
PARAM_FORMAT_ORIGINAL = 0,
/// See mkldnn_memory_format_t in
/// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h
/// for a detailed description.
/// 2D weights tensor in the format (output channels, input channels).
PARAM_FORMAT_MKLDNN_OI,
/// The total format items numbers
PARAM_FORMAT_ITEMS,
} PARAM_FORMAT;
class SparsePrefetchRowCpuMatrix;
@ -256,19 +264,19 @@ public:
};
/**
* @brief Is the header supported
* @brief Is the header format supported.
*/
static bool isHeaderFormatSupported(int32_t fmt) {
return fmt < PARAM_FORMAT_ITEMS;
}
/**
* @brief Get the format in header
* @brief Get the format in header.
*/
int getHeaderFormat() { return headerFormat_; }
/**
* @brief Set the format in header
* @brief Set the format in header.
*/
void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; }
@ -343,7 +351,7 @@ protected:
bool updated_;
SparseFormat format_;
// The header format for saving or loading param
/// The header format for saving or loading param
int32_t headerFormat_;
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;

Loading…
Cancel
Save