commit
664159ad42
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
|
||||
GraphPatternDetector gpd;
|
||||
auto* pattern = gpd.mutable_pattern();
|
||||
|
||||
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
|
||||
->assert_is_op_input("sequence_conv")
|
||||
->assert_var_not_persistable();
|
||||
patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
|
||||
fuse_pattern(x);
|
||||
|
||||
// Create New OpDesc
|
||||
auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
|
||||
Node* eltadd_bias, Node* relu_out) {
|
||||
OpDesc op_desc;
|
||||
op_desc.SetType("fusion_seqconv_eltadd_relu");
|
||||
op_desc.SetInput("X", {input->Name()});
|
||||
op_desc.SetInput("Filter", {seqconv_weight->Name()});
|
||||
op_desc.SetInput("Bias", {eltadd_bias->Name()});
|
||||
op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
|
||||
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
|
||||
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
|
||||
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
|
||||
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
|
||||
const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
|
||||
op_desc.SetOutput("ColMat", {ColMat});
|
||||
op_desc.SetOutput("Out", {relu_out->Name()});
|
||||
scope->Var(ColMat)->GetMutable<LoDTensor>();
|
||||
|
||||
auto* op = graph->CreateOpNode(&op_desc);
|
||||
IR_NODE_LINK_TO(input, op);
|
||||
IR_NODE_LINK_TO(seqconv_weight, op);
|
||||
IR_NODE_LINK_TO(eltadd_bias, op);
|
||||
IR_NODE_LINK_TO(op, relu_out);
|
||||
return op;
|
||||
};
|
||||
|
||||
int fusion_count{0};
|
||||
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
||||
Graph* g) {
|
||||
VLOG(4) << "handle SeqConv EltAdd Relu fuse";
|
||||
GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
|
||||
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);
|
||||
|
||||
fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
|
||||
relu_out);
|
||||
std::unordered_set<const Node*> marked_nodes(
|
||||
{seqconv, seqconv_out, eltadd, eltadd_out, relu});
|
||||
GraphSafeRemoveNodes(graph, marked_nodes);
|
||||
++fusion_count;
|
||||
};
|
||||
|
||||
gpd(graph, handler);
|
||||
|
||||
return fusion_count;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
FusePassBase::Init(name_scope_, graph.get());
|
||||
|
||||
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
|
||||
AddStatis(fusion_count);
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
|
||||
paddle::framework::ir::SeqConvEltAddReluFusePass);
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class SeqConvEltAddReluFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~SeqConvEltAddReluFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
|
||||
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,229 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h"
|
||||
#include <algorithm> // for min, max
|
||||
#include <string>
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
#include "paddle/fluid/operators/math/fc_compute.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
void FusionSeqConvEltAddReluOp::InferShape(
|
||||
framework::InferShapeContext* ctx) const {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of FusionSeqConvEltAddReluOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("Filter"),
|
||||
"Input(Filter) of FusionSeqConvEltAddReluOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("Bias"),
|
||||
"Input(Bias) of FusionSeqConvEltAddReluOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("Out"),
|
||||
"Output(Out) of FusionSeqConvEltAddReluOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("ColMat"),
|
||||
"Output(ColMat) of FusionSeqConvEltAddReluOp should not be null.");
|
||||
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto w_dims = ctx->GetInputDim("Filter");
|
||||
int context_length = ctx->Attrs().Get<int>("contextLength");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->Attrs().Get<int>("contextStride") == 1,
|
||||
"Currently, FusionSeqConvEltAddReluOp only supports contextStride=1.");
|
||||
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
|
||||
"Input(X, Filter) should be 2-D tensor.");
|
||||
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
|
||||
"Input(X, Filter) should be 2-D tensor.");
|
||||
PADDLE_ENFORCE(w_dims[0] == context_length * x_dims[1],
|
||||
"Filter's height should be context_length * "
|
||||
"input_hidden_size .");
|
||||
PADDLE_ENFORCE_GT(context_length + ctx->Attrs().Get<int>("contextStart"), 0,
|
||||
"contextStart size should be smaller than contextLength.");
|
||||
|
||||
ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]});
|
||||
ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]});
|
||||
ctx->ShareLoD("X", "Out");
|
||||
}
|
||||
|
||||
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
|
||||
void FusionSeqConvEltAddReluOpMaker::Make() {
|
||||
AddInput("X",
|
||||
"(LoDTensor) the input is a LodTensor, which support "
|
||||
"variable-time length input sequence. The underlying tensor in "
|
||||
"this LoDTensor is a matrix with shape (T X M), where T is the "
|
||||
"total time steps in this mini-batch, M is the dim size of x.");
|
||||
// PaddingData only support false yet, should be ensured at pass.
|
||||
AddInput("Filter",
|
||||
"(Tensor) same as the input(Filter) of sequence conv op is an "
|
||||
"learnable parameter."
|
||||
"This is a tensor with shape (K, N), where K is the "
|
||||
"context_length * dim size of x, N is the output feature size.");
|
||||
AddInput("Bias",
|
||||
"(Tensor) the learnable weights. shape (1, N), where N is the "
|
||||
"output feature size");
|
||||
AddOutput(
|
||||
"Out",
|
||||
"(LoDTensor) the output(Out) is a LodTensor, which support "
|
||||
"variable-time length output sequence. The underlying tensor in "
|
||||
"this LoDTensor is a matrix with shape (T, N), where, T is the "
|
||||
"total time steps in this mini-batch, N is the output feature size.");
|
||||
AddOutput("ColMat",
|
||||
"(Tensor) (T, K), where T is where T is the "
|
||||
"total time steps in this mini-batch, K is height of Filter")
|
||||
.AsIntermediate();
|
||||
AddAttr<int>("contextLength",
|
||||
"(int) the contextLength of FusionSeqConvEltAddReluOp is the "
|
||||
"height of the convolution kernel.")
|
||||
.GreaterThan(0);
|
||||
AddAttr<int>("contextStart",
|
||||
"(int, default:0) the contextStart of FusionSeqConvEltAddReluOp "
|
||||
"represents the beginning of the convolution of the number of "
|
||||
"rows of sequence, which can be negative. The negative number "
|
||||
"means to pad contextStart time-steps of zeros or learnable "
|
||||
"parameters at the beginning of each instance. The positive "
|
||||
"number means to skip contextStart time-steps of each "
|
||||
"instance.")
|
||||
.SetDefault(0);
|
||||
AddAttr<int>(
|
||||
"contextStride",
|
||||
"(int, default:1) the contextStride of FusionSeqConvEltAddReluOp "
|
||||
"represents the stride length of convolution kernel. "
|
||||
"Currently, FusionSeqConvEltAddReluOp only supports"
|
||||
"contextStride=1.")
|
||||
.SetDefault(1)
|
||||
.GreaterThan(0);
|
||||
AddComment(R"DOC(
|
||||
Fusion Sequence Conv and ElementwiseAdd Operator.
|
||||
)DOC");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
using DeviceContext = paddle::platform::CPUDeviceContext;
|
||||
auto* x = ctx.Input<LoDTensor>("X");
|
||||
auto* w = ctx.Input<Tensor>("Filter");
|
||||
auto* b = ctx.Input<Tensor>("Bias");
|
||||
auto* y = ctx.Output<LoDTensor>("Out");
|
||||
auto* col = ctx.Output<Tensor>("ColMat");
|
||||
|
||||
auto x_lod = x->lod();
|
||||
auto x_dims = x->dims();
|
||||
auto w_dims = w->dims();
|
||||
PADDLE_ENFORCE_EQ(b->numel(), w_dims[1],
|
||||
"bias size should be equal to output feature size.");
|
||||
PADDLE_ENFORCE_EQ(x_lod.size(), 1UL,
|
||||
"Only support one level sequence now.");
|
||||
|
||||
const T* x_data = x->data<T>();
|
||||
const T* w_data = w->data<T>();
|
||||
const T* b_data = b->data<T>();
|
||||
T* y_data = y->mutable_data<T>(ctx.GetPlace());
|
||||
T* col_data = col->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int context_start = ctx.Attr<int>("contextStart");
|
||||
int context_length = ctx.Attr<int>("contextLength");
|
||||
int up_pad = std::max(0, -context_start);
|
||||
int down_pad = std::max(0, context_start + context_length - 1);
|
||||
// im2col
|
||||
int src_mat_w = static_cast<int>(x_dims[1]);
|
||||
int src_mat_w_sz = src_mat_w * sizeof(T);
|
||||
int col_mat_w = static_cast<int>(w_dims[0]);
|
||||
int col_mat_w_sz = col_mat_w * sizeof(T);
|
||||
for (int i = 0; i < static_cast<int>(x_lod[0].size()) - 1; ++i) {
|
||||
int st = x_lod[0][i];
|
||||
int ed = x_lod[0][i + 1];
|
||||
const T* src_data = x_data + st * src_mat_w;
|
||||
T* dst_data = col_data + st * col_mat_w;
|
||||
int seq_len = ed - st;
|
||||
if (seq_len > up_pad + down_pad) {
|
||||
// zero all up_pad and fill data
|
||||
std::memset(dst_data, 0, up_pad * col_mat_w_sz);
|
||||
dst_data = dst_data + up_pad * src_mat_w;
|
||||
int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz;
|
||||
for (int j = 0; j < up_pad; ++j) {
|
||||
// blas.VCOPY?
|
||||
std::memcpy(dst_data, src_data, copy_size);
|
||||
dst_data += (col_mat_w - src_mat_w);
|
||||
copy_size += src_mat_w_sz;
|
||||
}
|
||||
// fill data
|
||||
for (int j = 0; j < seq_len - up_pad - down_pad; ++j) {
|
||||
std::memcpy(dst_data, src_data, copy_size);
|
||||
dst_data += col_mat_w;
|
||||
src_data += src_mat_w;
|
||||
}
|
||||
// zero all down_pad and fill data
|
||||
std::memset(dst_data, 0, down_pad * col_mat_w_sz);
|
||||
copy_size -= src_mat_w_sz;
|
||||
for (int j = 0; j < down_pad; ++j) {
|
||||
std::memcpy(dst_data, src_data, copy_size);
|
||||
dst_data += col_mat_w;
|
||||
src_data += src_mat_w;
|
||||
copy_size -= src_mat_w_sz;
|
||||
}
|
||||
} else {
|
||||
PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1);
|
||||
std::memset(dst_data, 0, seq_len * col_mat_w_sz);
|
||||
dst_data = dst_data + up_pad * src_mat_w;
|
||||
int zero_sz = up_pad * src_mat_w_sz;
|
||||
int cur_src_sz = seq_len * src_mat_w_sz;
|
||||
for (int j = 0; j < std::min(up_pad, seq_len); ++j) {
|
||||
int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
|
||||
std::memcpy(dst_data, src_data, copy_size);
|
||||
dst_data += (col_mat_w - src_mat_w);
|
||||
zero_sz -= src_mat_w_sz;
|
||||
}
|
||||
// from bottom
|
||||
dst_data = col_data + ed * col_mat_w;
|
||||
src_data = x_data + st * src_mat_w;
|
||||
zero_sz = down_pad * src_mat_w_sz;
|
||||
for (int j = 1; j <= std::min(down_pad, seq_len); ++j) {
|
||||
int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
|
||||
std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T),
|
||||
src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w,
|
||||
copy_size);
|
||||
dst_data -= col_mat_w;
|
||||
zero_sz -= src_mat_w_sz;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
|
||||
math::FCCompute<DeviceContext, T>(blas, x_dims[0], w_dims[1], w_dims[0],
|
||||
col_data, w_data, y_data, b_data, true);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(fusion_seqconv_eltadd_relu, ops::FusionSeqConvEltAddReluOp,
|
||||
ops::FusionSeqConvEltAddReluOpMaker,
|
||||
paddle::framework::DefaultGradOpDescMaker<true>);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(fusion_seqconv_eltadd_relu,
|
||||
ops::FusionSeqConvEltAddReluKernel<float>,
|
||||
ops::FusionSeqConvEltAddReluKernel<double>);
|
@ -0,0 +1,42 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override;
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override;
|
||||
};
|
||||
|
||||
class FusionSeqConvEltAddReluOpMaker
|
||||
: public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override;
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import random
|
||||
from op_test import OpTest
|
||||
from test_seq_conv import seqconv
|
||||
|
||||
|
||||
class TestSeqConvEltAddRelu(OpTest):
|
||||
def set_conf(self):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = 'fusion_seqconv_eltadd_relu'
|
||||
self.lod = [[6, 4]]
|
||||
self.in_fea_size = 16
|
||||
self.out_fea_size = 8
|
||||
self.context_length = 4
|
||||
self.context_stride = 1
|
||||
self.context_start = 0
|
||||
self.set_conf()
|
||||
|
||||
assert self.context_stride == 1
|
||||
|
||||
T = sum(self.lod[0])
|
||||
x = np.random.uniform(-1, 1, [T, self.in_fea_size]).astype('float32')
|
||||
w = np.random.uniform(
|
||||
-1, 1, [self.in_fea_size * self.context_length,
|
||||
self.out_fea_size]).astype('float32')
|
||||
b = np.random.uniform(-2, 1, [1, self.out_fea_size]).astype('float32')
|
||||
out = seqconv(x, self.lod, w, self.context_length, self.context_start)
|
||||
out = np.maximum(out + b, 0)
|
||||
|
||||
self.inputs = {'X': (x, self.lod), 'Filter': w, 'Bias': b}
|
||||
self.attrs = {
|
||||
'contextStart': self.context_start,
|
||||
'contextLength': self.context_length,
|
||||
'contextStride': self.context_stride
|
||||
}
|
||||
self.outputs = {'Out': out}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestSeqConvEltAddReluBS1(TestSeqConvEltAddRelu):
|
||||
def set_conf(self):
|
||||
self.lod = [[10]]
|
||||
|
||||
|
||||
class TestSeqConvEltAddReluBS1Case2(TestSeqConvEltAddRelu):
|
||||
def set_conf(self):
|
||||
self.lod = [[2]]
|
||||
|
||||
|
||||
class TestSeqConvEltAddReluCase1(TestSeqConvEltAddRelu):
|
||||
def set_conf(self):
|
||||
self.lod = [[3, 5, 1, 6]]
|
||||
self.context_length = 3
|
||||
self.context_start = -2
|
||||
|
||||
|
||||
class TestSeqConvEltAddReluCase2(TestSeqConvEltAddRelu):
|
||||
def set_conf(self):
|
||||
self.lod = [[10, 1, 2, 4, 1, 5, 6]]
|
||||
self.in_fea_size = 2
|
||||
self.context_length = 4
|
||||
self.context_start = -1
|
||||
|
||||
|
||||
class TestSeqConvEltAddReluCase3(TestSeqConvEltAddRelu):
|
||||
def set_conf(self):
|
||||
self.lod = [[10, 1, 2, 4, 1, 5, 6]]
|
||||
self.context_length = 5
|
||||
self.context_start = -4
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue