Merge pull request #15304 from tensor-tang/fuse/second_order_mul_sub
Fuse second order mul sub and fuse repeated fc relu
commit a7fc3d42a0
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

/**
 * Fuse repeated FC + ReLU subgraphs into a single
 * fusion_repeated_fc_relu op.
 */
class RepeatedFCReluFusePass : public FusePassBase {
 public:
  virtual ~RepeatedFCReluFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  const std::string name_scope_{"repeated_fc_relu_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
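The matching and rewriting logic lives in the pass's .cc file, which this hunk does not show. As a minimal sketch of the usual FusePassBase workflow (the pattern-building and node-rewiring steps below are placeholders, not this PR's actual implementation):

std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(name_scope_, graph.get());
  GraphPatternDetector gpd;
  // Build a pattern that matches chains of fc -> relu -> fc -> relu ...
  // (helper omitted; real passes assemble it from PDNode predicates).
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // Create one fusion_repeated_fc_relu node, wire it to the chain's X,
    // all W/Bias inputs and the final Out, then erase the matched nodes.
  };
  gpd(graph.get(), handler);
  return graph;
}

// Conventionally, the same file would register the pass, e.g.:
// REGISTER_PASS(repeated_fc_relu_fuse_pass,
//               paddle::framework::ir::RepeatedFCReluFusePass);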
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

/**
 * Fuse ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar
 * (sign order matches the fused kernel's computation).
 */
class SquaredMatSubFusePass : public FusePassBase {
 public:
  virtual ~SquaredMatSubFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  const std::string name_scope_{"squared_mat_sub_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
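Why this particular expression is worth fusing: expanding it entrywise shows it isolates the second-order cross terms of the product (the "second order mul sub" of the PR title). For $A \in \mathbb{R}^{m \times k}$ and $B \in \mathbb{R}^{k \times n}$:

$$\left[(AB)^{\circ 2}\right]_{ij} = \Big(\sum_{p} a_{ip} b_{pj}\Big)^{2} = \sum_{p} a_{ip}^{2} b_{pj}^{2} + \sum_{p \neq q} a_{ip} a_{iq} b_{pj} b_{qj},$$

$$\left[A^{\circ 2} B^{\circ 2}\right]_{ij} = \sum_{p} a_{ip}^{2} b_{pj}^{2},$$

so the difference computed by the fused op is exactly $\sum_{p \neq q} a_{ip} a_{iq} b_{pj} b_{qj}$, scaled by `scalar`; it vanishes when $k = 1$. (Here $\cdot^{\circ 2}$ is the elementwise square, written `.^2` in the code comments.)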
@@ -0,0 +1,149 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

namespace paddle {
namespace operators {

void FusionRepeatedFCReluOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "Input(X) of FusionRepeatedFCReluOp should not be null.");
  auto sz = ctx->Inputs("W").size();
  PADDLE_ENFORCE_GT(
      sz, 1UL,
      "Inputs(W) of FusionRepeatedFCReluOp should be larger than 1.");
  PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz,
                    "Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
                    "equal to inputs(W) size.");
  PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1,
                    "Size of output(ReluOut) of FusionRepeatedFCReluOp should "
                    "be equal to inputs(W) size - 1.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of FusionRepeatedFCReluOp should not be null.");

  auto i_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(i_dims.size(), 2UL, "Input shape size should be 2.");

  auto w_dims = ctx->GetInputsDim("W");
  auto b_dims = ctx->GetInputsDim("Bias");
  PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(),
                    "Shape size of weight and bias should be equal.");
  PADDLE_ENFORCE_EQ(w_dims.size(), sz,
                    "Size of weights should be equal to inputs(W) size.");
  PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0],
                    "Input width should be equal to weight height.");

  for (size_t i = 1; i < sz; ++i) {
    PADDLE_ENFORCE_EQ(w_dims[i].size(), 2UL,
                      "Every weight shape size should be 2.");
    PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1],
                      "The length of Bias must be equal to w_dims[i][1].");
  }
  ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
  ctx->ShareLoD("X", /*->*/ "Out");
}

framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
                                 ctx.GetPlace());
}

void FusionRepeatedFCReluOpMaker::Make() {
  AddInput("X", "(LoDTensor) Input tensor of this operator.");
  AddInput("W", "(Tensor) The weight tensors of this operator.").AsDuplicable();
  AddInput("Bias", "(Tensor) The bias tensors of this operator.")
      .AsDuplicable();
  AddOutput("ReluOut", "(Tensor) The output tensor of each relu operator.")
      .AsDuplicable()
      .AsIntermediate();
  AddOutput("Out", "(LoDTensor) Output tensor of this operator.");
  AddComment(R"DOC(
  Fusion Repeated FC with Relu Operator.
)DOC");
}

// Compute one fused FC + ReLU layer: y = relu(x * w + b),
// where x is m x k, w is k x n, b has n elements, and y is m x n.
template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, int m, int n,
                    int k) {
  auto matmul =
      jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);
  auto addbias_relu =
      jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(n);
  matmul(x, w, y, m, n, k);
  T* dst = y;
  for (int i = 0; i < m; ++i) {
    // Add the bias to each output row and apply ReLU in place.
    addbias_relu(b, dst, dst, n);
    dst += n;
  }
}

template <typename T>
class FusionRepeatedFCReluKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto in = ctx.Input<Tensor>("X");
    auto weights = ctx.MultiInput<Tensor>("W");
    auto biases = ctx.MultiInput<Tensor>("Bias");
    auto relus = ctx.MultiOutput<Tensor>("ReluOut");
    auto* out = ctx.Output<Tensor>("Out");
    auto place = ctx.GetPlace();
    int weight_sz = static_cast<int>(weights.size());

    // First layer: X -> ReluOut[0].
    auto i_dims = in->dims();
    auto w_dims = weights[0]->dims();
    int m = i_dims[0];
    int n = w_dims[1];
    int k = w_dims[0];
    relus[0]->Resize({m, n});
    fc_relu(in->data<T>(), weights[0]->data<T>(), biases[0]->data<T>(),
            relus[0]->mutable_data<T>(place), m, n, k);

    // Middle layers: ReluOut[i-1] -> ReluOut[i].
    for (int i = 1; i < weight_sz - 1; ++i) {
      auto i_dims = relus[i - 1]->dims();
      auto w_dims = weights[i]->dims();
      int m = i_dims[0];
      int n = w_dims[1];
      int k = w_dims[0];
      relus[i]->Resize({m, n});
      fc_relu(relus[i - 1]->data<T>(), weights[i]->data<T>(),
              biases[i]->data<T>(), relus[i]->mutable_data<T>(place), m, n, k);
    }

    // Last layer writes directly to Out, which is why only
    // weight_sz - 1 intermediate ReluOut tensors are needed.
    auto i_dims_last = relus[weight_sz - 2]->dims();
    auto w_dims_last = weights[weight_sz - 1]->dims();
    m = i_dims_last[0];
    n = w_dims_last[1];
    k = w_dims_last[0];
    fc_relu(relus[weight_sz - 2]->data<T>(), weights[weight_sz - 1]->data<T>(),
            biases[weight_sz - 1]->data<T>(), out->mutable_data<T>(place), m, n,
            k);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_repeated_fc_relu, ops::FusionRepeatedFCReluOp,
                  ops::FusionRepeatedFCReluOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OP_CPU_KERNEL(fusion_repeated_fc_relu,
                       ops::FusionRepeatedFCReluKernel<float>,
                       ops::FusionRepeatedFCReluKernel<double>);
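To make the fused data flow concrete, here is a small self-contained reference of the same computation in plain C++. It has no Paddle or jit dependencies; fc_relu_ref and everything in main are illustrative stand-ins for the jit::kMatMul and jit::kVAddRelu kernels, not part of the operator:

#include <algorithm>
#include <cassert>
#include <vector>

// Reference semantics of one fused layer: y = relu(x * w + b).
// x: m x k, w: k x n, b: n, y: m x n, all row-major.
void fc_relu_ref(const std::vector<float>& x, const std::vector<float>& w,
                 const std::vector<float>& b, std::vector<float>& y, int m,
                 int n, int k) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += x[i * k + p] * w[p * n + j];
      y[i * n + j] = std::max(acc + b[j], 0.f);
    }
  }
}

int main() {
  // Two stacked FC+ReLU layers: X (1x2) -> hidden (1x3) -> out (1x2).
  std::vector<float> x = {1.f, -2.f};
  std::vector<float> w0 = {1.f, 0.f, 1.f, 0.f, 1.f, 1.f};   // 2x3
  std::vector<float> b0 = {0.f, 0.5f, 0.f};
  std::vector<float> w1 = {1.f, 0.f, 0.f, 1.f, 1.f, -1.f};  // 3x2
  std::vector<float> b1 = {0.f, 0.f};
  std::vector<float> h(3), y(2);
  fc_relu_ref(x, w0, b0, h, 1, 3, 2);  // plays the role of ReluOut[0]
  fc_relu_ref(h, w1, b1, y, 1, 2, 3);  // plays the role of Out
  assert(y[0] >= 0.f && y[1] >= 0.f);  // ReLU guarantees non-negativity
  return 0;
}

Stacking fc_relu_ref calls mirrors how the kernel threads ReluOut[i-1] into layer i, with the final layer writing straight to Out.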
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

class FusionRepeatedFCReluOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusionRepeatedFCReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,137 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

namespace paddle {
namespace operators {

void FusionSquaredMatSubOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "Input(X) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Y"),
                 "Input(Y) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("SquaredX"),
      "Output(SquaredX) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("SquaredY"),
      "Output(SquaredY) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("SquaredXY"),
      "Output(SquaredXY) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of FusionSquaredMatSubOp should not be null.");

  auto x_dims = ctx->GetInputDim("X");
  auto y_dims = ctx->GetInputDim("Y");
  PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
                    "Input tensors dims size should be equal.");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input tensors should be matrices.");
  PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0],
                    "Input matrices should be multipliable: X's width must "
                    "equal Y's height.");

  ctx->SetOutputDim("SquaredX", x_dims);
  ctx->SetOutputDim("SquaredY", y_dims);
  ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
  ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]});
}

framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
                                 ctx.GetPlace());
}

void FusionSquaredMatSubOpMaker::Make() {
  AddInput("X", "(Tensor) Input Mat A of this operator.");
  AddInput("Y", "(Tensor) Input Mat B of this operator.");
  AddOutput("SquaredX", "(Tensor) Squared X.").AsIntermediate();
  AddOutput("SquaredY", "(Tensor) Squared Y.").AsIntermediate();
  AddOutput("SquaredXY", "(Tensor) Squared X*Y.").AsIntermediate();
  AddOutput("Out", "(Tensor) Output tensor of this operator.");
  AddAttr<float>("scalar", "The scalar applied to the output matrix.")
      .SetDefault(1.f);
  AddComment(R"DOC(
  Fusion Squared Matrix Subtraction Operator.

    ( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
)DOC");
}

template <typename T>
class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto x = ctx.Input<Tensor>("X");
    auto y = ctx.Input<Tensor>("Y");
    auto* squared_x = ctx.Output<Tensor>("SquaredX");
    auto* squared_y = ctx.Output<Tensor>("SquaredY");
    auto* squared_xy = ctx.Output<Tensor>("SquaredXY");
    auto* out = ctx.Output<Tensor>("Out");
    auto place = ctx.GetPlace();
    T scalar = static_cast<T>(ctx.Attr<float>("scalar"));

    auto x_dims = x->dims();
    auto y_dims = y->dims();
    int m = x_dims[0];
    int k = x_dims[1];
    int n = y_dims[1];
    int o_numel = m * n;

    auto vsquare_x =
        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(m * k);
    auto vsquare_y =
        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(k * n);
    auto vsquare_xy =
        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
    auto vsub =
        jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
    auto vscal =
        jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
    auto matmul =
        jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);

    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    T* squared_x_data = squared_x->mutable_data<T>(place);
    T* squared_y_data = squared_y->mutable_data<T>(place);
    T* squared_xy_data = squared_xy->mutable_data<T>(place);
    T* o_data = out->mutable_data<T>(place);

    // SquaredXY = (X * Y).^2
    matmul(x_data, y_data, squared_xy_data, m, n, k);
    vsquare_xy(squared_xy_data, squared_xy_data, o_numel);

    // Out = X.^2 * Y.^2 (Out doubles as scratch for the second product).
    vsquare_x(x_data, squared_x_data, m * k);
    vsquare_y(y_data, squared_y_data, k * n);
    matmul(squared_x_data, squared_y_data, o_data, m, n, k);

    // Out = ( SquaredXY - Out ) .* scalar
    vsub(squared_xy_data, o_data, o_data, o_numel);
    vscal(&scalar, o_data, o_data, o_numel);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_squared_mat_sub, ops::FusionSquaredMatSubOp,
                  ops::FusionSquaredMatSubOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OP_CPU_KERNEL(fusion_squared_mat_sub,
                       ops::FusionSquaredMatSubKernel<float>,
                       ops::FusionSquaredMatSubKernel<double>);
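As a sanity check on the kernel's step ordering, a minimal self-contained sketch in plain C++ (illustrative names, no Paddle dependencies) computes the same quantity naively; the numbers also confirm the cross-term identity derived earlier:

#include <cassert>
#include <cmath>
#include <vector>

// Naive reference for Out = ((X*Y).^2 - (X.^2 * Y.^2)) .* scalar,
// X: m x k, Y: k x n, row-major.
std::vector<float> squared_mat_sub_ref(const std::vector<float>& x,
                                       const std::vector<float>& y, int m,
                                       int n, int k, float scalar) {
  std::vector<float> out(m * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float xy = 0.f, x2y2 = 0.f;
      for (int p = 0; p < k; ++p) {
        float xp = x[i * k + p], yp = y[p * n + j];
        xy += xp * yp;        // entry of X * Y
        x2y2 += xp * xp * yp * yp;  // entry of X.^2 * Y.^2
      }
      out[i * n + j] = (xy * xy - x2y2) * scalar;
    }
  }
  return out;
}

int main() {
  // X = [1 2], Y = [3; 4] (m = 1, k = 2, n = 1).
  // (X*Y).^2 = (1*3 + 2*4)^2 = 121; X.^2 * Y.^2 = 1*9 + 4*16 = 73.
  // Cross terms: 2 * (1 * 2 * 3 * 4) = 48 = 121 - 73.
  std::vector<float> x = {1.f, 2.f}, y = {3.f, 4.f};
  auto out = squared_mat_sub_ref(x, y, 1, 1, 2, 1.f);
  assert(std::fabs(out[0] - 48.f) < 1e-6f);
  return 0;
}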
@@ -0,0 +1,42 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Computes ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar
// (sign order matches the kernel in the .cc file).
class FusionSquaredMatSubOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusionSquaredMatSubOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle