|
|
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
@ -22,15 +22,20 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
|
|
|
|
|
"Inputs(X) of FusionSeqConcatFCOp should larger than 1.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
|
|
|
|
|
"Input(FCWeight) of FusionSeqConcatFC should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FusionSeqConcatFC should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
|
|
|
|
|
"Output(FCOut) of FusionSeqConcatFC should not be null.");
|
|
|
|
|
void FusionSeqExpandConcatFCOp::InferShape(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
ctx->Inputs("X").size(), 1UL,
|
|
|
|
|
"Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("FCWeight"),
|
|
|
|
|
"Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("FCOut"),
|
|
|
|
|
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
|
|
|
|
|
|
|
|
|
|
auto ins_dims = ctx->GetInputsDim("X");
|
|
|
|
|
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
|
|
|
|
@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
ctx->ShareLoD("X", "Out", 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
|
|
|
|
|
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FusionSeqConcatFCOpMaker::Make() {
|
|
|
|
|
void FusionSeqExpandConcatFCOpMaker::Make() {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(LoDTensor) input LodDTensors, the first one must be have ref lod "
|
|
|
|
|
"for sequence expand, and the rest input should have same lod.")
|
|
|
|
@ -100,7 +105,7 @@ The concat axis should be 1.
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
|
|
|
|
|
class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp,
|
|
|
|
|
ops::FusionSeqConcatFCOpMaker,
|
|
|
|
|
REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
|
|
|
|
|
ops::FusionSeqExpandConcatFCOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc,
|
|
|
|
|
ops::FusionSeqConcatFCKernel<float>,
|
|
|
|
|
ops::FusionSeqConcatFCKernel<double>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
|
|
|
|
|
ops::FusionSeqExpandConcatFCOpKernel<float>,
|
|
|
|
|
ops::FusionSeqExpandConcatFCOpKernel<double>);
|