|
|
|
@ -262,12 +262,15 @@ void MKLDNNConvLayer::resetBwdWgtPD(
|
|
|
|
|
padR,
|
|
|
|
|
padKind);
|
|
|
|
|
pd.reset(new conv_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_));
|
|
|
|
|
CHECK(pd->src_primitive_desc() == inVal_->getPrimitiveDesc())
|
|
|
|
|
<< "primitive desc of in value should equal";
|
|
|
|
|
CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc())
|
|
|
|
|
<< "primitive desc of out grad should equal the out value";
|
|
|
|
|
CHECK(pd->diff_weights_primitive_desc() == wgtVal_->getPrimitiveDesc())
|
|
|
|
|
<< "primitive desc of weight grad should equal the weight value";
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(inVal_, pd->src_primitive_desc());
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(
|
|
|
|
|
outVal_,
|
|
|
|
|
pd->diff_dst_primitive_desc(),
|
|
|
|
|
"primitive desc of out value and grad should be equal");
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(
|
|
|
|
|
wgtVal_,
|
|
|
|
|
pd->diff_weights_primitive_desc(),
|
|
|
|
|
"primitive desc of weight value and grad should be equal");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNConvLayer::resetBwdDataPD(
|
|
|
|
@ -292,10 +295,14 @@ void MKLDNNConvLayer::resetBwdDataPD(
|
|
|
|
|
padR,
|
|
|
|
|
padding_kind::zero);
|
|
|
|
|
pd.reset(new conv_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_));
|
|
|
|
|
CHECK(pd->diff_src_primitive_desc() == inVal_->getPrimitiveDesc())
|
|
|
|
|
<< "primitive desc of in grad should equal the in value";
|
|
|
|
|
CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc())
|
|
|
|
|
<< "primitive desc of out grad should equal";
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(
|
|
|
|
|
inVal_,
|
|
|
|
|
pd->diff_src_primitive_desc(),
|
|
|
|
|
"primitive desc of in value and grad should be equal");
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(
|
|
|
|
|
outVal_,
|
|
|
|
|
pd->diff_dst_primitive_desc(),
|
|
|
|
|
"primitive desc of out value and grad should be equal");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MKLDNNConvLayer::resetBwdBuffers(
|
|
|
|
@ -310,17 +317,20 @@ void MKLDNNConvLayer::resetBwdBuffers(
|
|
|
|
|
|
|
|
|
|
resetWithMatrix(
|
|
|
|
|
wgt, weight_->getWGrad(), wgtPD->diff_weights_primitive_desc());
|
|
|
|
|
CHECK(wgtVal_ != nullptr &&
|
|
|
|
|
wgt->getPrimitiveDesc() == wgtVal_->getPrimitiveDesc())
|
|
|
|
|
<< "primitive desc of weight grad and value should be equal";
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(
|
|
|
|
|
wgtVal_,
|
|
|
|
|
wgt->getPrimitiveDesc(),
|
|
|
|
|
"primitive desc of weight grad and value should be equal");
|
|
|
|
|
|
|
|
|
|
bias = nullptr;
|
|
|
|
|
if (biases_ && biases_->getWGrad()) {
|
|
|
|
|
resetWithMatrix(
|
|
|
|
|
bias, biases_->getWGrad(), wgtPD->diff_bias_primitive_desc());
|
|
|
|
|
CHECK(bias && biasVal_ &&
|
|
|
|
|
bias->getPrimitiveDesc() == biasVal_->getPrimitiveDesc())
|
|
|
|
|
<< "primitive desc of bias grad should equal the bias value";
|
|
|
|
|
CHECK(bias);
|
|
|
|
|
CHECK_PRIMITIVE_DESC_EQ(
|
|
|
|
|
biasVal_,
|
|
|
|
|
bias->getPrimitiveDesc(),
|
|
|
|
|
"primitive desc of bias grad and value should be equal");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dataPD == nullptr) {
|
|
|
|
|