Merge pull request #15304 from tensor-tang/fuse/second_order_mul_sub

Fuse/second order mul sub and fuse repeated fc relu
tensor-tang 6 years ago committed by GitHub
commit a7fc3d42a0

@ -43,6 +43,8 @@ pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference)
pass_library(repeated_fc_relu_fuse_pass inference)
pass_library(squared_mat_sub_fuse_pass inference)
pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)

File diff suppressed because it is too large

@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/**
* Fuse Repeated FC Relu
*/
class RepeatedFCReluFusePass : public FusePassBase {
public:
virtual ~RepeatedFCReluFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"repeated_fc_relu_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
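
For orientation, a sketch (not part of the patch) of the rewrite this pass performs; the repetition count follows the fused op below, which requires more than one FC:

// Before: X -> fc -> relu -> fc -> relu -> ... -> Out   (N repetitions, N > 1)
// After:  X -> fusion_repeated_fc_relu(W[0..N-1], Bias[0..N-1]) -> Out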

@ -129,7 +129,8 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
return concat_out_var;
}
int BuildFusion(Graph* graph, const std::string& name_scope, int num_inputs) {
static int BuildFusion(Graph* graph, const std::string& name_scope,
int num_inputs) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs);

File diff suppressed because it is too large

@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/**
 * Fuse ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar
*/
class SquaredMatSubFusePass : public FusePassBase {
public:
virtual ~SquaredMatSubFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"squared_mat_sub_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
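
A similar sketch of the subgraph this pass collapses (my reading of the fused kernel further below; the pattern-matching .cc is suppressed above):

// Before: matmul(X, Y) -> square, square(X) and square(Y) -> matmul,
//         elementwise subtract, then scale by `scalar`
// After:  Out = fusion_squared_mat_sub(X, Y, scalar)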

@ -98,6 +98,8 @@ class CpuPassStrategy : public PassStrategy {
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //

@ -37,15 +37,21 @@ function(inference_analysis_api_test_with_refer_result target install_dir filena
--refer_result=${install_dir}/result.txt)
endfunction()
# RNN1
if(NOT APPLE AND WITH_MKLML)
# RNN1
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz")
inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc SERIAL)
# seq_pool1
set(SEQ_POOL1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_pool")
download_model_and_data(${SEQ_POOL1_INSTALL_DIR} "seq_pool1_model_.tar.gz" "seq_pool1_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_seq_pool1 ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_tester.cc SERIAL)
else()
# TODO: fix these tests on MACOS and OPENBLAS; the reason is that
# fusion_seqexpand_concat_fc_op is not supported on MACOS and OPENBLAS
message(WARNING "This test has been disabled on OSX or with WITH_MKL=OFF until it is fixed: \n test_analyzer_rnn1")
message(WARNING "This test has been disabled on OSX or with WITH_MKL=OFF until it is fixed: \n test_analyzer_seq_pool1")
endif()
# RNN2
@ -90,11 +96,6 @@ set(SEQ_CONV1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_conv1")
download_model_and_data(${SEQ_CONV1_INSTALL_DIR} "seq_conv1_model.tar.gz" "seq_conv1_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} analyzer_seq_conv1_tester.cc)
# seq_pool1
set(SEQ_POOL1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_pool")
download_model_and_data(${SEQ_POOL1_INSTALL_DIR} "seq_pool1_model_.tar.gz" "seq_pool1_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_seq_pool1 ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_tester.cc)
# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR})

@ -21,6 +21,12 @@ namespace paddle {
namespace inference {
namespace analysis {
// diff: similarity_norm.tmp_0, for speed: fc_4.tmp_1
static const char out_var_name[] = "reduce_sum_0.tmp_0";
// for diff: 154, for speed 111
constexpr int num_slots = 154;
struct OneSlotInBatch {
std::string name;
std::vector<std::vector<float>> data;
@ -41,7 +47,6 @@ struct DataRecord {
void Load(const std::string &path) {
std::ifstream file(path);
constexpr int num_slots = 154;
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
@ -187,11 +192,15 @@ void analysis_fuse_statis(bool use_zerocopy) {
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
ASSERT_TRUE(fuse_statis.count("squared_mat_sub_fuse"));
ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 2);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 195);
EXPECT_EQ(num_ops, 171);
}
// Check the fuse status
@ -214,9 +223,6 @@ void PrepareZeroCopyInputs(
}
}
// diff: similarity_norm.tmp_0, // speed: fc_4.tmp_1
static const char out_var_name[] = "reduce_sum_0.tmp_0";
// return the output values
std::vector<float> zerocopy_profile(int repeat_times) {
AnalysisConfig config;
@ -322,7 +328,9 @@ TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
native_outputs.front().data.length());
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
for (size_t i = 0; i < zerocopy_output.size(); ++i) {
EXPECT_NEAR(zerocopy_output[i], native_data[i], 1e-3);
EXPECT_LT(
std::fabs((zerocopy_output[i] - native_data[i]) / zerocopy_output[i]),
1e-3);
}
}

@ -0,0 +1,149 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace operators {
void FusionRepeatedFCReluOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionRepeatedFCReluOp should not be null.");
auto sz = ctx->Inputs("W").size();
PADDLE_ENFORCE_GT(
sz, 1UL, "Inputs(W) of FusionRepeatedFCReluOp should be larger than 1.");
PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz,
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
"equal to the size of inputs(W).");
PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1,
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
"be equal to the size of inputs(W) minus 1.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionRepeatedFCReluOp should not be null.");
auto i_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(i_dims.size(), 2UL, "Input shape size should be 2");
auto w_dims = ctx->GetInputsDim("W");
auto b_dims = ctx->GetInputsDim("Bias");
PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(),
"Shape sizes of weight and bias should be equal.");
PADDLE_ENFORCE_EQ(w_dims.size(), sz,
"Shape size of weight should be equal to the size of inputs(W).");
PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0],
"Input width should be equal to the first weight's height.");
for (size_t i = 1; i < sz; ++i) {
PADDLE_ENFORCE_EQ(w_dims[i].size(), 2UL,
"Every weight shape size should be 2.");
PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1],
"The length of Bias must be equal to w_dims[i][1].");
}
ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
ctx.GetPlace());
}
void FusionRepeatedFCReluOpMaker::Make() {
AddInput("X", "(LoDTensor) Input tensors of this operator.");
AddInput("W", "(Tensor) The weight tensors of this operator.").AsDuplicable();
AddInput("Bias", "(Tensor) The bias tensors of this operator.")
.AsDuplicable();
AddOutput("ReluOut", "(Tensor) The output tensor of each relu operator.")
.AsDuplicable()
.AsIntermediate();
AddOutput("Out", "(LoDTensor) Output tensor of this operator.");
AddComment(R"DOC(
Fusion Repeated FC with Relu Operator.
)DOC");
}
template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, int m, int n,
int k) {
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);
auto addbias_relu =
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(n);
matmul(x, w, y, m, n, k);
T* dst = y;
for (int i = 0; i < m; ++i) {
addbias_relu(b, dst, dst, n);
dst += n;
}
}
template <typename T>
class FusionRepeatedFCReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto in = ctx.Input<Tensor>("X");
auto weights = ctx.MultiInput<Tensor>("W");
auto biases = ctx.MultiInput<Tensor>("Bias");
auto relus = ctx.MultiOutput<Tensor>("ReluOut");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
int weight_sz = static_cast<int>(weights.size());
auto i_dims = in->dims();
auto w_dims = weights[0]->dims();
int m = i_dims[0];
int n = w_dims[1];
int k = w_dims[0];
relus[0]->Resize({m, n});
fc_relu(in->data<T>(), weights[0]->data<T>(), biases[0]->data<T>(),
relus[0]->mutable_data<T>(place), m, n, k);
for (int i = 1; i < weight_sz - 1; ++i) {
auto i_dims = relus[i - 1]->dims();
auto w_dims = weights[i]->dims();
int m = i_dims[0];
int n = w_dims[1];
int k = w_dims[0];
relus[i]->Resize({m, n});
fc_relu(relus[i - 1]->data<T>(), weights[i]->data<T>(),
biases[i]->data<T>(), relus[i]->mutable_data<T>(place), m, n, k);
}
auto i_dims_last = relus[weight_sz - 2]->dims();
auto w_dims_last = weights[weight_sz - 1]->dims();
m = i_dims_last[0];
n = w_dims_last[1];
k = w_dims_last[0];
fc_relu(relus[weight_sz - 2]->data<T>(), weights[weight_sz - 1]->data<T>(),
biases[weight_sz - 1]->data<T>(), out->mutable_data<T>(place), m, n,
k);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_repeated_fc_relu, ops::FusionRepeatedFCReluOp,
ops::FusionRepeatedFCReluOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_repeated_fc_relu,
ops::FusionRepeatedFCReluKernel<float>,
ops::FusionRepeatedFCReluKernel<double>);
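
As a mental model, a naive reference (illustrative only, assuming row-major data; ref_fc_relu is not a name from the patch) for one fused FC+ReLU step:

#include <algorithm>

template <typename T>
void ref_fc_relu(const T* x, const T* w, const T* b, T* y, int m, int n,
                 int k) {
  // y(m, n) = relu(x(m, k) * w(k, n) + b(n))
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      T acc = b[j];
      for (int p = 0; p < k; ++p) {
        acc += x[i * k + p] * w[p * n + j];
      }
      y[i * n + j] = std::max(acc, static_cast<T>(0));
    }
  }
}

// The kernel chains this step: relus[0] from X and W[0]/Bias[0], each middle
// relus[i] from relus[i - 1], and the final step writes into Out.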

@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionRepeatedFCReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionRepeatedFCReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle

@ -0,0 +1,137 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace operators {
void FusionSquaredMatSubOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredX"),
"Output(SquaredX) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredY"),
"Output(SquaredY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredXY"),
"Output(SquaredXY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSquaredMatSubOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Input tensors' dims sizes should be equal.");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input tensors should be matrices.");
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0],
"Input matrices must be multipliable: X's width should equal Y's height.");
ctx->SetOutputDim("SquaredX", x_dims);
ctx->SetOutputDim("SquaredY", y_dims);
ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]});
}
framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
ctx.GetPlace());
}
void FusionSquaredMatSubOpMaker::Make() {
AddInput("X", "(Tensor) Input Mat A of this operator.");
AddInput("Y", "(Tensor) Input Mat B of this operator.");
AddOutput("SquaredX", "(Tensor) Squared X.").AsIntermediate();
AddOutput("SquaredY", "(Tensor) Squared Y.").AsIntermediate();
AddOutput("SquaredXY", "(Tensor) Squared X*Y.").AsIntermediate();
AddOutput("Out", "(Tensor) Output tensor of concat operator.");
AddAttr<float>("scalar", "The scalar on output matrix.").SetDefault(1.f);
AddComment(R"DOC(
Fusion Squared Matrix Multiply and Subtract Operator.
( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
)DOC");
}
template <typename T>
class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<Tensor>("X");
auto y = ctx.Input<Tensor>("Y");
auto* squared_x = ctx.Output<Tensor>("SquaredX");
auto* squared_y = ctx.Output<Tensor>("SquaredY");
auto* squared_xy = ctx.Output<Tensor>("SquaredXY");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
T scalar = static_cast<T>(ctx.Attr<float>("scalar"));
auto x_dims = x->dims();
auto y_dims = y->dims();
int m = x_dims[0];
int k = x_dims[1];
int n = y_dims[1];
int o_numel = m * n;
auto vsquare_x =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(m * k);
auto vsquare_y =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(k * n);
auto vsquare_xy =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
auto vsub =
jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
auto vscal =
jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);
const T* x_data = x->data<T>();
const T* y_data = y->data<T>();
T* squared_x_data = squared_x->mutable_data<T>(place);
T* squared_y_data = squared_y->mutable_data<T>(place);
T* squared_xy_data = squared_xy->mutable_data<T>(place);
T* o_data = out->mutable_data<T>(place);
matmul(x_data, y_data, squared_xy_data, m, n, k);
vsquare_xy(squared_xy_data, squared_xy_data, o_numel);
vsquare_x(x_data, squared_x_data, m * k);
vsquare_y(y_data, squared_y_data, k * n);
matmul(squared_x_data, squared_y_data, o_data, m, n, k);
vsub(squared_xy_data, o_data, o_data, o_numel);
vscal(&scalar, o_data, o_data, o_numel);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_squared_mat_sub, ops::FusionSquaredMatSubOp,
ops::FusionSquaredMatSubOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_squared_mat_sub,
ops::FusionSquaredMatSubKernel<float>,
ops::FusionSquaredMatSubKernel<double>);
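
Likewise a naive reference for this kernel (illustrative only; it follows the vsub operand order in Compute, i.e. (X * Y).^2 minus X.^2 * Y.^2):

template <typename T>
void ref_squared_mat_sub(const T* x, const T* y, T* out, int m, int n, int k,
                         T scalar) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      T xy = 0;    // (X * Y)(i, j)
      T x2y2 = 0;  // (X.^2 * Y.^2)(i, j)
      for (int p = 0; p < k; ++p) {
        const T a = x[i * k + p];
        const T b = y[p * n + j];
        xy += a * b;
        x2y2 += a * a * b * b;
      }
      out[i * n + j] = (xy * xy - x2y2) * scalar;
    }
  }
}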

@ -0,0 +1,42 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
// ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar
class FusionSquaredMatSubOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSquaredMatSubOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle

@ -210,6 +210,24 @@ void BenchSeqPoolKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
for (int k : TestSizes()) {
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
RandomVec<T>(k * n, b.data(), -2.f, 2.f);
const T* a_data = a.data();
const T* b_data = b.data();
T* c_data = c.data();
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(k, a_data, b_data,
c_data, m, n, k);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
@ -236,6 +254,7 @@ int main(int argc, char* argv[]) {
// xyn
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
BenchXYNKernel<jit::kVSquare, T, PlaceType>();
BenchXYNKernel<jit::kVExp, T, PlaceType>();
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
@ -251,4 +270,7 @@ int main(int argc, char* argv[]) {
// seq pool function
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
// matmul
BenchMatMulKernel<jit::kMatMul, T, PlaceType>();
}

@ -11,11 +11,12 @@ endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(kVSub)
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVSquare)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)

@ -91,6 +91,7 @@ void VActJitCode::genCode() {
}
DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VSquare);
DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid);
@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const {
8 /* average bytes for each instruction */;
}
size_t VSquareCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
size_t VIdentityCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const {
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);

@ -75,6 +75,12 @@ class VActFunc : public JitCode {
vmaxps(dst, src, zero);
}
// compute SQUARE with ymm, xmm
template <typename JMM>
void square_jmm(JMM& dst, JMM& src) { // NOLINT
vmulps(dst, src, src);
}
// compute EXP with ymm, xmm
template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
@ -228,6 +234,9 @@ class VActFunc : public JitCode {
case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15);
break;
case operand_type::SQUARE:
square_jmm<JMM>(dst, src);
break;
case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
@ -254,7 +263,7 @@ class VActJitCode : public VActFunc {
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
type_ == operand_type::IDENTITY)) {
type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
@ -266,6 +275,9 @@ class VActJitCode : public VActFunc {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::SQUARE:
base += "_Square";
break;
case operand_type::EXP:
base += "_Exp";
break;
@ -306,6 +318,7 @@ class VActJitCode : public VActFunc {
};
DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
DECLARE_ACT_JITCODE(VSquare, operand_type::SQUARE);
DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);

@ -43,6 +43,8 @@ void VXXJitCode::genCode() {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::SUB) {
vsubps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
@ -85,6 +87,9 @@ void VXXJitCode::genCode() {
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::SUB:
vsubps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);

@ -34,7 +34,8 @@ class VXXJitCode : public JitCode {
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD ||
type_ == operand_type::SUB)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
@ -51,6 +52,8 @@ class VXXJitCode : public JitCode {
base += "_Mul";
} else if (type_ == operand_type::ADD) {
base += "_Add";
} else if (type_ == operand_type::SUB) {
base += "_Sub";
}
if (scalar_index_ == 2) {
base += "_Scalar";

@ -51,6 +51,7 @@ typedef enum {
SUB,
RELU,
EXP,
SQUARE,
SIGMOID,
TANH,
IDENTITY

@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kVRelu);
ONE_CASE(kVIdentity);
ONE_CASE(kVExp);
ONE_CASE(kVSquare);
ONE_CASE(kVSigmoid);
ONE_CASE(kVTanh);
ONE_CASE(kLSTMCtHt);
@ -47,6 +48,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";

@ -30,6 +30,7 @@ typedef enum {
kVAddBias,
kVRelu,
kVIdentity,
kVSquare,
kVExp,
kVSigmoid,
kVTanh,
@ -42,6 +43,7 @@ typedef enum {
kLayerNorm,
kNCHW16CMulNC,
kSeqPool,
kMatMul,
} KernelType;
typedef enum {
@ -135,6 +137,13 @@ struct SeqPoolTuples {
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};
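// For MatMulTuples: func_type computes c(m, n) = a(m, k) * b(k, n), and the
// int attribute used as the jit::Get key is k (see the fused ops above).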
template <typename T>
struct MatMulTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int, int);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;

@ -3,10 +3,12 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)

@ -24,6 +24,20 @@ namespace jit {
namespace more {
namespace mkl {
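// Row-major GEMM: c(m, n) = 1.0 * a(m, k) * b(k, n) + 0.0 * c,
// hence lda = k, ldb = n, ldc = n below.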
template <>
void MatMul<float>(const float* a, const float* b, float* c, int m, int n,
int k) {
platform::dynload::cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
n, k, 1.f, a, k, b, n, 0.f, c, n);
}
template <>
void MatMul<double>(const double* a, const double* b, double* c, int m, int n,
int k) {
platform::dynload::cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
n, k, 1.0, a, k, b, n, 0.0, c, n);
}
template <>
void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
@ -72,6 +86,16 @@ void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
template <>
void VSquare<float>(const float* x, float* y, int n) {
platform::dynload::vsSqr(n, x, y);
}
template <>
void VSquare<double>(const double* x, double* y, int n) {
platform::dynload::vdSqr(n, x, y);
}
template <>
void VCopy<float>(const float* x, float* y, int n) {
platform::dynload::cblas_scopy(n, x, 1, y, 1);
@ -93,6 +117,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool MatMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx);
}
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
@ -113,6 +142,11 @@ bool VExpKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VSquareKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
return d > 7;
@ -139,12 +173,14 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE(MatMul);
AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
AWALYS_USE_ME_WITH_DOUBLE(VSquare);
#undef AWALYS_USE_ME_WITH_DOUBLE
} // namespace mkl
@ -159,10 +195,12 @@ namespace mkl = paddle::operators::jit::more::mkl;
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
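
For reference, callers pick this kernel up the same way the fused ops above do (shapes illustrative; a_data, b_data, c_data are assumed raw row-major buffers):

auto matmul =
    jit::Get<jit::kMatMul, jit::MatMulTuples<float>, platform::CPUPlace>(k);
matmul(a_data, b_data, c_data, m, n, k);  // c(m, n) = a(m, k) * b(k, n)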

@ -24,6 +24,9 @@ namespace jit {
namespace more {
namespace mkl {
template <typename T>
void MatMul(const T* a, const T* b, T* c, int m, int n, int k);
template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
@ -36,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n);
template <typename T>
void VExp(const T* x, T* y, int n);
template <typename T>
void VSquare(const T* x, T* y, int n);
template <typename T>
void VCopy(const T* x, T* y, int n);
@ -93,6 +99,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
const char* ImplType() const override { return "MKL"; } \
}
// ABCMNK: (const T* a, const T* b, T* c, int m, int n, int k)
DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
@ -104,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples);
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);

Some files were not shown because too many files have changed in this diff
