Change ComputeINT8 to a template version to remove the dst_datatype checking code (#18756)

* Change ComputeINT8 to a template so that the if-else checks on dst_dt can be removed. CI will be enabled after the review comments are addressed.

* Reverse the declaration scope of user_residual_memory_p and user_bias_memory_p
test=develop
Authored by lidanqing, 6 years ago; committed by Tao Luo
parent d9e7b5b5e9
commit 9ecd8ee789

File diff suppressed because it is too large

@@ -20,7 +20,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/place.h"
 namespace paddle {
 namespace platform {
@@ -82,22 +81,24 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
 template <typename Type>
 mkldnn::memory::data_type MKLDNNGetDataType() {
-  return mkldnn::memory::data_undef;
+  return mkldnn::memory::data_type::data_undef;
 }
 template <>
 inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
-  return mkldnn::memory::f32;
+  return mkldnn::memory::data_type::f32;
 }
+template <>
+inline mkldnn::memory::data_type MKLDNNGetDataType<int32_t>() {
+  return mkldnn::memory::data_type::s32;
+}
 template <>
 inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() {
-  return mkldnn::memory::s8;
+  return mkldnn::memory::data_type::s8;
 }
 template <>
 inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
-  return mkldnn::memory::u8;
+  return mkldnn::memory::data_type::u8;
 }
 inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) {

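The specializations above are what make the templated ComputeINT8 possible: once the output element type is a template parameter, the destination data type is resolved at instantiation time instead of through runtime if-else checks on dst_datatype. Below is a minimal, self-contained sketch of that dispatch pattern; ComputeINT8Sketch, GetDataTypeSketch, and the local data_type enum are hypothetical stand-ins for illustration, not the actual PaddlePaddle code:

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for mkldnn::memory::data_type (illustration only).
enum class data_type { undef, f32, s32, s8, u8 };

template <typename Type>
data_type GetDataTypeSketch() { return data_type::undef; }
template <> data_type GetDataTypeSketch<float>() { return data_type::f32; }
template <> data_type GetDataTypeSketch<int32_t>() { return data_type::s32; }
template <> data_type GetDataTypeSketch<int8_t>() { return data_type::s8; }
template <> data_type GetDataTypeSketch<uint8_t>() { return data_type::u8; }

// Hypothetical stand-in for the templated ComputeINT8: the former
// "if (dst_dt == u8) {...} else if (dst_dt == s8) {...}" branching
// disappears because T_out fixes the destination type per instantiation.
template <typename T_out>
void ComputeINT8Sketch() {
  data_type dst_dt = GetDataTypeSketch<T_out>();  // compile-time dispatch
  std::cout << "dst_dt tag: " << static_cast<int>(dst_dt) << "\n";
}

int main() {
  ComputeINT8Sketch<uint8_t>();  // unsigned INT8 output
  ComputeINT8Sketch<int8_t>();   // signed INT8 output
  return 0;
}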
@@ -1160,18 +1160,24 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
                        scale_data, mask);
   }
-  mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
-                                       bool fuse_brelu,
-                                       float fuse_brelu_threshold) const {
+  mkldnn::primitive_attr CreatePostOps(
+      bool fuse_relu, bool fuse_residual_conn, bool fuse_brelu,
+      float fuse_brelu_threshold,
+      const std::vector<float> output_shift_scale = {},
+      float sum_scale = 1.0f) const {
     mkldnn::primitive_attr conv_attr;
     mkldnn::post_ops post_operations;
+    if (output_shift_scale.size() > 0) {
+      int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
+      conv_attr.set_output_scales(mask, output_shift_scale);
+    }
     // Fusion with Elementwise layer relies on adding a sum post-operation with
     // the scale parameter. It is assumed that when fuse_residual_connection is
     // true, the output tensor contains the data coming from residual
     // connection. The result of this post_op is:
     // Output = scale * Output + Conv_Out.
     if (fuse_residual_conn) {
-      post_operations.append_sum(1.0f);
+      post_operations.append_sum(sum_scale);
     }
     // Fusion with ReLU layer is executed through the PostOps feature. Create a
     // PostOps object and configure it to execute an eltwise relu operation.
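For reference, the two quantization knobs added here do different jobs: set_output_scales applies the requantization scales to the convolution result (mask 1 << 1 marks the scales as varying along dimension 1, the output-channel axis, while mask 0 means a single common scale), and the sum post-op scale rescales the residual data already present in the output buffer. A minimal sketch of the attribute setup, assuming the mkldnn 0.x C++ API this file is written against; the scale values are purely illustrative:

#include <mkldnn.hpp>
#include <vector>

int main() {
  std::vector<float> output_shift_scale = {0.5f, 0.25f};  // one per output channel
  float sum_scale = 0.125f;  // rescales the residual data consumed by append_sum

  mkldnn::primitive_attr conv_attr;
  // mask 0: one scale for the whole tensor; mask 1 << 1: per-channel scales.
  int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
  conv_attr.set_output_scales(mask, output_shift_scale);

  mkldnn::post_ops post_operations;
  // Residual fusion: Output = sum_scale * Output + Conv_Out
  post_operations.append_sum(sum_scale);
  conv_attr.set_post_ops(post_operations);
  return 0;
}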
@@ -1202,7 +1208,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
       const std::vector<int>& paddings, const mkldnn::engine& engine,
       const bool fuse_relu, const bool fuse_residual_conn,
       const bool fuse_brelu, const float fuse_brelu_threshold,
-      mkldnn::prop_kind fwd_prop_kind) {
+      mkldnn::prop_kind fwd_prop_kind,
+      const std::vector<float> output_shift_scale = {},
+      const float sum_scale = 1.0f) {
     // Conv PD has to be passed to Grad op that
     // may be executed by a different thread, hence
     // for that one we use a key that does not contain TID
@@ -1232,8 +1240,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
           src, weights, dst, stride_dims, padding_dims,
           padding_dims, mkldnn::padding_kind::zero);
-      mkldnn::primitive_attr conv_attr = CreatePostOps(
-          fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
+      mkldnn::primitive_attr conv_attr =
+          CreatePostOps(fuse_relu, fuse_residual_conn, fuse_brelu,
+                        fuse_brelu_threshold, output_shift_scale, sum_scale);
       conv_pd_.reset(new typename forward_t::primitive_desc(
           conv_desc, conv_attr, engine));
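Design note on the signature above: both new parameters are defaulted (an empty vector and 1.0f), so every existing FP32 call site of AcquireConvolutionPrimitiveDescriptor keeps compiling unchanged; an empty scale vector makes CreatePostOps skip set_output_scales, and sum_scale == 1.0f reproduces the old append_sum(1.0f) behaviour. A minimal sketch of the pattern (AcquireSketch is hypothetical):

#include <iostream>
#include <vector>

// Hypothetical function mirroring the defaulted parameters above.
void AcquireSketch(const std::vector<float> output_shift_scale = {},
                   float sum_scale = 1.0f) {
  // An empty scale vector selects the FP32 behaviour; sum_scale == 1.0f
  // matches the old unconditional append_sum(1.0f).
  std::cout << "scales: " << output_shift_scale.size()
            << ", sum_scale: " << sum_scale << "\n";
}

int main() {
  AcquireSketch();                       // FP32 call site compiles unchanged
  AcquireSketch({0.5f, 0.25f}, 0.125f);  // INT8 call site supplies the scales
  return 0;
}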
@@ -1393,10 +1402,10 @@ template <typename T>
 static void SetDstMemoryHandler(
     const framework::ExecutionContext& ctx, framework::Tensor* output,
     const std::shared_ptr<ConvMKLDNNHandler>& handler,
-    std::shared_ptr<mkldnn::memory>* dst_memory_p) {
+    std::shared_ptr<mkldnn::memory> dst_memory_p) {
   T* output_data =
       output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
-  (*dst_memory_p)->set_data_handle(to_void_cast<T>(output_data));
+  dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
 }
 template <typename T>

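The SetDstMemoryHandler change from std::shared_ptr<mkldnn::memory>* to a by-value std::shared_ptr works because the function only mutates the pointed-to memory object through set_data_handle and never reseats the caller's pointer, so a shared_ptr copy observing the same object is sufficient. A minimal sketch of why the copy behaves identically (Memory and SetHandleSketch are hypothetical stand-ins):

#include <cassert>
#include <memory>

// Hypothetical stand-in for mkldnn::memory (illustration only).
struct Memory {
  void set_data_handle(void* p) { handle = p; }
  void* handle = nullptr;
};

// Copying the shared_ptr is enough: both copies own the same Memory,
// so a mutation through the copy is visible to the caller.
static void SetHandleSketch(std::shared_ptr<Memory> mem, void* data) {
  mem->set_data_handle(data);
}

int main() {
  auto mem = std::make_shared<Memory>();
  int buffer[4] = {0, 1, 2, 3};
  SetHandleSketch(mem, buffer);
  assert(mem->handle == buffer);  // the caller observes the new handle
  return 0;
}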