|
|
@ -268,9 +268,9 @@ protected:
|
|
|
|
/**
|
|
|
|
/**
|
|
|
|
* reset the output grad matrix from primitive desc.
|
|
|
|
* reset the output grad matrix from primitive desc.
|
|
|
|
* and reset the merge grad primitive if needed.
|
|
|
|
* and reset the merge grad primitive if needed.
|
|
|
|
* note: when this layer have serval output,
|
|
|
|
* note: when this layer has serval outputs,
|
|
|
|
* do not support mixing with cpu device,
|
|
|
|
* it could not be mixed with cpu device,
|
|
|
|
* because can not get memory desc from cpu device.
|
|
|
|
* since it can not get memory desc from cpu device.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
virtual void resetOutGrad(MKLDNNMatrixPtr& out,
|
|
|
|
virtual void resetOutGrad(MKLDNNMatrixPtr& out,
|
|
|
|
mkldnn::memory::primitive_desc pd) {
|
|
|
|
mkldnn::memory::primitive_desc pd) {
|
|
|
@ -281,7 +281,7 @@ protected:
|
|
|
|
if (outputMap_.size() <= 1) {
|
|
|
|
if (outputMap_.size() <= 1) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::vector<double> scales;
|
|
|
|
std::vector<double> scales(outputMap_.size(), 1.0);
|
|
|
|
std::vector<mkldnn::memory::primitive_desc> srcPDs;
|
|
|
|
std::vector<mkldnn::memory::primitive_desc> srcPDs;
|
|
|
|
std::vector<mkldnn::primitive::at> srcs;
|
|
|
|
std::vector<mkldnn::primitive::at> srcs;
|
|
|
|
for (auto it = outputMap_.begin(); it != outputMap_.end(); ++it) {
|
|
|
|
for (auto it = outputMap_.begin(); it != outputMap_.end(); ++it) {
|
|
|
@ -297,7 +297,6 @@ protected:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
srcPDs.push_back(src->getPrimitiveDesc());
|
|
|
|
srcPDs.push_back(src->getPrimitiveDesc());
|
|
|
|
srcs.push_back(*src);
|
|
|
|
srcs.push_back(*src);
|
|
|
|
scales.push_back(1.0);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): remove me when mkldnn sum support different formats
|
|
|
|
// TODO(TJ): remove me when mkldnn sum support different formats
|
|
|
|