|
|
|
@ -53,6 +53,18 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
key_ += "-BWD";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetDstMemorySize() const {
|
|
|
|
|
return conv_pd_->dst_primitive_desc().get_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetDiffWeightsMemorySize() const {
|
|
|
|
|
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetDiffSourceMemorySize() const {
|
|
|
|
|
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
@ -294,7 +306,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
const T* filter_data = filter->data<T>();
|
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
|
|
|
|
|
std::vector<int> weights_tz =
|
|
|
|
@ -354,6 +365,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto user_weights_memory_p = handler.AcquireWeightsMemory(
|
|
|
|
|
user_weights_md, to_void_cast<T>(filter_data));
|
|
|
|
|
|
|
|
|
|
T* output_data =
|
|
|
|
|
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
|
|
|
|
|
// create reorder primitive if the input format is not the preferred one
|
|
|
|
|
auto src_memory_p =
|
|
|
|
|
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
|
|
|
|
@ -476,13 +489,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
T* input_grad_data = nullptr;
|
|
|
|
|
T* filter_grad_data = nullptr;
|
|
|
|
|
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
if (filter_grad) {
|
|
|
|
|
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
|
|
|
|
|
std::vector<int> weights_tz =
|
|
|
|
|
paddle::framework::vectorize2int(filter->dims());
|
|
|
|
@ -568,6 +574,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.AcquireDiffDstMemoryFromWeightsPrimitive(
|
|
|
|
|
user_diff_dst_memory_p, pipeline);
|
|
|
|
|
|
|
|
|
|
const size_t size = handler.GetDiffWeightsMemorySize();
|
|
|
|
|
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
|
|
|
|
|
|
|
|
|
|
auto diff_weights_memory_p =
|
|
|
|
|
handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
|
|
|
|
|
reinterpret_cast<void*>(filter_grad_data));
|
|
|
|
@ -590,6 +599,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
|
|
|
|
|
pipeline);
|
|
|
|
|
|
|
|
|
|
const size_t size = handler.GetDiffSourceMemorySize();
|
|
|
|
|
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
|
|
|
|
|
|
|
|
|
|
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
|
|
|
|
|
reinterpret_cast<void*>(input_grad_data));
|
|
|
|
|
|
|
|
|
|