|
|
|
@ -55,7 +55,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) {
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
|
auto src_pd = conv_bwd_weights_pd_->src_primitive_desc();
|
|
|
|
|
auto user_pd = user_memory_p->get_primitive_desc();
|
|
|
|
|
return this->AcquireMemory(src_pd, user_pd, user_memory_p,
|
|
|
|
@ -64,7 +64,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) {
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
|
auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_primitive_desc();
|
|
|
|
|
auto user_pd = user_memory_p->get_primitive_desc();
|
|
|
|
|
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
|
|
|
|
@ -80,7 +80,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) {
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
|
auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_primitive_desc();
|
|
|
|
|
auto user_pd = user_memory_p->get_primitive_desc();
|
|
|
|
|
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
|
|
|
|
@ -89,7 +89,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) {
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
|
auto weights_pd = conv_bwd_data_pd_->weights_primitive_desc();
|
|
|
|
|
auto user_pd = user_weights_memory_p->get_primitive_desc();
|
|
|
|
|
return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p,
|
|
|
|
@ -109,7 +109,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) {
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
|
auto src_pd = conv_pd_->src_primitive_desc();
|
|
|
|
|
auto user_pd = user_memory_p->get_primitive_desc();
|
|
|
|
|
return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
|
|
|
|
@ -118,7 +118,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) {
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
|
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
|
|
|
|
|
auto weights_pd = conv_pd_->weights_primitive_desc();
|
|
|
|
|
return this->AcquireMemory(weights_pd, user_weights_pd,
|
|
|
|
@ -197,12 +197,12 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
// Generate keys for storing/retriving primitives for this operator
|
|
|
|
|
// TODO(jczaja): Make hashing function more optimial
|
|
|
|
|
static std::string GetHash(memory::dims& input_dims,
|
|
|
|
|
memory::dims& weights_dims,
|
|
|
|
|
std::vector<int>& strides,
|
|
|
|
|
std::vector<int>& paddings,
|
|
|
|
|
std::vector<int>& dilations, int groups,
|
|
|
|
|
const std::string& suffix) {
|
|
|
|
|
static std::string GetHash(memory::dims& input_dims, // NOLINT
|
|
|
|
|
memory::dims& weights_dims, // NOLINT
|
|
|
|
|
std::vector<int>& strides, // NOLINT
|
|
|
|
|
std::vector<int>& paddings, // NOLINT
|
|
|
|
|
std::vector<int>& dilations, // NOLINT
|
|
|
|
|
int groups, const std::string& suffix) {
|
|
|
|
|
return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
|
|
|
|
|
dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
|
|
|
|
|
suffix;
|
|
|
|
|