|
|
|
@ -71,7 +71,7 @@ class FCPrimitiveFactory {
|
|
|
|
|
input_->set_data_handle(const_cast<T*>(in->data<T>()));
|
|
|
|
|
output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
|
|
|
|
|
if (out->format() == MKLDNNMemoryFormat::format_undef) {
|
|
|
|
|
auto output_format = output_->get_primitive_desc().desc().data.format;
|
|
|
|
|
auto output_format = platform::GetMKLDNNFormat(*output_);
|
|
|
|
|
out->set_format((MKLDNNMemoryFormat)output_format);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -199,8 +199,9 @@ class FCPrimitiveFactory {
|
|
|
|
|
auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
|
|
|
|
|
auto buffer_size = dst_prim_desc.get_size();
|
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace(), buffer_size);
|
|
|
|
|
output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format);
|
|
|
|
|
return memory(dst_prim_desc, to_void_cast<T>(output_data));
|
|
|
|
|
memory dst_mem(dst_prim_desc, to_void_cast<T>(output_data));
|
|
|
|
|
output->set_format(platform::GetMKLDNNFormat(dst_mem));
|
|
|
|
|
return dst_mem;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
|
|
|
|
|