|
|
|
@ -53,15 +53,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
key_ += "-BWD";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetDstMemorySize() {
|
|
|
|
|
size_t GetDstMemorySize() const {
|
|
|
|
|
return conv_pd_->dst_primitive_desc().get_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetDiffWeightsMemorySize() {
|
|
|
|
|
size_t GetDiffWeightsMemorySize() const {
|
|
|
|
|
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetDiffSourceMemorySize() {
|
|
|
|
|
size_t GetDiffSourceMemorySize() const {
|
|
|
|
|
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -491,7 +491,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.AcquireDiffDstMemoryFromWeightsPrimitive(
|
|
|
|
|
user_diff_dst_memory_p, pipeline);
|
|
|
|
|
|
|
|
|
|
size_t size = handler.GetDiffWeightsMemorySize();
|
|
|
|
|
const size_t size = handler.GetDiffWeightsMemorySize();
|
|
|
|
|
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
|
|
|
|
|
|
|
|
|
|
auto diff_weights_memory_p =
|
|
|
|
@ -516,7 +516,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
|
|
|
|
|
pipeline);
|
|
|
|
|
|
|
|
|
|
size_t size = handler.GetDiffSourceMemorySize();
|
|
|
|
|
const size_t size = handler.GetDiffSourceMemorySize();
|
|
|
|
|
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
|
|
|
|
|
|
|
|
|
|
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
|
|
|
|
|