Merge pull request #15237 from tensor-tang/fuse/seqpool_concat_2
Fuse/seqpool concat 2
commit 48410b9bfe
@@ -0,0 +1,194 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"

#define MAX_CONCAT_INPUTS 200

namespace paddle {
namespace framework {
namespace ir {

PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
                                  const std::string& name_scope,
                                  int num_inputs) {
  auto is_concat_op_with_inputs = [](Node* x, int num) -> bool {
    return x && x->IsOp() && x->Op()->Type() == "concat" &&
           x->Op()->Input("X").size() == static_cast<size_t>(num);
  };

  auto is_nth_input_var_of_concat = [=](Node* x, int idx) -> bool {
    return x && x->IsVar() && VarLinksToOp(x, "concat") &&
           x->outputs.size() == 1 && IsNthInput(x, x->outputs[0], "X", idx) &&
           is_concat_op_with_inputs(x->outputs[0], num_inputs);
  };

  auto is_seqpool_op_with_pooltype_of_nth_input_of_concat = [=](
      Node* x, const std::string& type, int idx) -> bool {
    bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" &&
              x->Op()->HasAttr("pooltype") &&
              boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type &&
              x->outputs.size() == 2;  // seqpool should only have 2 outputs
    if (ok) {
      // Only one output of the seqpool op is the nth input var of concat;
      // the other one should be an unused, empty var.
      if (is_nth_input_var_of_concat(x->outputs[0], idx)) {
        ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0;
      } else {
        ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) &&
             x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0;
      }
    }
    return ok;
  };

  auto* concat_op = pattern->NewNode(
      [=](Node* x) { return is_concat_op_with_inputs(x, num_inputs); },
      name_scope + "/concat_op");
  concat_op->assert_op_attr<int>("axis", 1);

  auto* concat_out_var = pattern->NewNode(
      [=](Node* x) {
        return x && x->IsVar() && VarLinksFromOp(x, "concat") &&
               x->inputs.size() == 1 &&
               is_concat_op_with_inputs(x->inputs[0], num_inputs);
      },
      name_scope + "/concat_out_var");
  concat_out_var->assert_is_only_output_of_op("concat");

  std::vector<PDNode*> seqpool_ops_input_var(num_inputs);
  std::vector<PDNode*> seqpool_ops_output_var(num_inputs);
  std::vector<PDNode*> seqpool_ops(num_inputs);

  for (int i = 0; i < num_inputs; ++i) {
    seqpool_ops_output_var[i] = pattern->NewNode(
        [=](Node* x) {
          return x && x->IsVar() && is_nth_input_var_of_concat(x, i) &&
                 x->inputs.size() == 1 &&
                 is_seqpool_op_with_pooltype_of_nth_input_of_concat(
                     x->inputs[0], "SUM", i);
        },
        name_scope + "/sequence_pool_out_" + std::to_string(i));

    seqpool_ops[i] = pattern->NewNode(
        [=](Node* x) {
          return x && x->IsOp() &&
                 is_seqpool_op_with_pooltype_of_nth_input_of_concat(x, "SUM",
                                                                    i);
        },
        name_scope + "/sequence_pool_op_" + std::to_string(i));

    seqpool_ops_input_var[i] = pattern->NewNode(
        [=](Node* x) {
          return x && x->IsVar() && x->outputs.size() >= 1 &&
                 is_seqpool_op_with_pooltype_of_nth_input_of_concat(
                     x->outputs[0], "SUM", i);
        },
        name_scope + "/sequence_pool_in_" + std::to_string(i));

    // Links
    seqpool_ops[i]
        ->LinksFrom({seqpool_ops_input_var[i]})
        .LinksTo({seqpool_ops_output_var[i]});
  }
  concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var});
  return concat_out_var;
}

int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
                int num_inputs) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();
  BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs);

  auto retrieve_node = [](const std::string& name,
                          const GraphPatternDetector::subgraph_t& subgraph,
                          const PDPattern& pat) -> Node* {
    PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
                   "pattern has no Node called %s", name.c_str());
    Node* p = subgraph.at(pat.RetrieveNode(name));
    PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str());
    return p;
  };

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle SeqPool Concat fuse";
    std::vector<std::string> input_names(num_inputs);
    std::vector<Node*> input_vars(num_inputs);
    auto& fused_pattern = gpd.pattern();
    for (int i = 0; i < num_inputs; ++i) {
      input_vars[i] =
          retrieve_node(name_scope + "/sequence_pool_in_" + std::to_string(i),
                        subgraph, fused_pattern);
      input_names[i] = input_vars[i]->Name();
    }
    auto* concat_op =
        retrieve_node(name_scope + "/concat_op", subgraph, fused_pattern);
    auto* concat_out_var =
        retrieve_node(name_scope + "/concat_out_var", subgraph, fused_pattern);
    auto* seqpool_op0 = retrieve_node(name_scope + "/sequence_pool_op_0",
                                      subgraph, fused_pattern);

    // Create New OpDesc
    OpDesc op_desc;
    op_desc.SetType("fusion_seqpool_concat");
    op_desc.SetInput("X", input_names);
    op_desc.SetAttr("pooltype", seqpool_op0->Op()->GetAttr("pooltype"));
    op_desc.SetAttr("axis", concat_op->Op()->GetAttr("axis"));
    op_desc.SetOutput("Out", {concat_out_var->Name()});
    auto* op = graph->CreateOpNode(&op_desc);
    for (size_t i = 0; i < input_vars.size(); ++i) {
      IR_NODE_LINK_TO(input_vars[i], op);
    }
    IR_NODE_LINK_TO(op, concat_out_var);

    // Remove the matched subgraph, but keep the fused op's inputs and output.
    std::unordered_set<const Node*> marked_nodes;
    for (auto& item : subgraph) {
      marked_nodes.insert(item.second);
    }
    for (size_t i = 0; i < input_vars.size(); ++i) {
      marked_nodes.erase(input_vars[i]);
    }
    marked_nodes.erase(concat_out_var);
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };

  gpd(graph, handler);
  return fusion_count;
}

std::unique_ptr<ir::Graph> SeqPoolConcatFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(name_scope_, graph.get());
  int fusion_count = 0;
  // Try the largest concat patterns first, so bigger fusions win before
  // smaller sub-patterns get a chance to match.
  for (int i = MAX_CONCAT_INPUTS; i > 0; --i) {
    fusion_count += BuildFusion(
        graph.get(), name_scope_ + "/" + std::to_string(i), param_scope(), i);
  }
  AddStatis(fusion_count);

  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(seqpool_concat_fuse_pass,
              paddle::framework::ir::SeqPoolConcatFusePass);
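For orientation, here is a minimal NumPy sketch (not part of this PR) of the subgraph the pattern above matches: several sequence_pool ops with pooltype SUM, each pooling one LoD input, all feeding a single concat with axis=1. The fused op has to reproduce exactly this unfused computation.

import numpy as np

# Two LoD inputs with the same batch size (3 sequences each) and width w = 4.
# Each lod entry is the number of rows belonging to one sequence.
w = 4
lods = [[2, 3, 5], [1, 5, 2]]
xs = [np.random.rand(sum(lod), w).astype('float32') for lod in lods]


def seq_pool_sum(x, lod):
    # Sum the rows of each sequence segment -> one pooled row per sequence.
    offsets = np.cumsum([0] + lod)
    return np.stack([x[offsets[i]:offsets[i + 1]].sum(axis=0)
                     for i in range(len(lod))])


# Unfused graph: one sequence_pool (SUM) per input, then concat on axis 1.
pooled = [seq_pool_sum(x, lod) for x, lod in zip(xs, lods)]
out = np.concatenate(pooled, axis=1)
print(out.shape)  # (3, 8): batch_size rows, num_inputs * w columns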
@@ -0,0 +1,38 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class SeqPoolConcatFusePass : public FusePassBase {
 public:
  virtual ~SeqPoolConcatFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;

  const std::string name_scope_{"seqpool_concat_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,132 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_seqpool_concat_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

namespace paddle {
namespace operators {

void FusionSeqPoolConcatOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
                    "Inputs(X) of FusionSeqPoolConcatOp should not be empty.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of FusionSeqPoolConcatOp should not be null.");
  int axis = ctx->Attrs().Get<int>("axis");
  PADDLE_ENFORCE_EQ(axis, 1,
                    "FusionSeqPoolConcatOp only supports concat axis=1 yet.");

  auto ins_dims = ctx->GetInputsDim("X");
  const size_t n = ins_dims.size();
  PADDLE_ENFORCE_GT(n, 0UL, "Input tensors count should be greater than 0.");
  if (n == 1) {
    LOG(WARNING) << "Only have one input, may waste memory";
  }

  // The output height should be confirmed in Compute,
  // since input lod is not accessible here.
  PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2UL,
                    "The dims size of first input should be 2.");
  ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast<int>(n)});
}

framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace());
}

void FusionSeqPoolConcatOpMaker::Make() {
  AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
  AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
  AddAttr<std::string>("pooltype",
                       "(string, default 'SUM') The pooling type applied by "
                       "the fused SequencePoolOp.")
      .SetDefault("SUM")
      .InEnum({"AVERAGE", "SUM", "SQRT"});
  AddAttr<int>("axis",
               "The axis along which the input tensors will be concatenated.")
      .SetDefault(1);
  AddComment(R"DOC(
Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
)DOC");
}

template <typename T>
class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");
    std::string pooltype = ctx.Attr<std::string>("pooltype");
    auto x0_lod = ins[0]->lod();
    auto x0_dims = ins[0]->dims();
    auto y_dims = out->dims();
    size_t bs = x0_lod[0].size() - 1;
    out->Resize({static_cast<int64_t>(bs), y_dims[1]});
    framework::LoD y_lod(1);
    y_lod[0].resize(bs + 1);
    for (size_t i = 0; i <= bs; ++i) {
      y_lod[0][i] = i;
    }
    out->set_lod(y_lod);
    auto place = ctx.GetPlace();
    T* y_data = out->mutable_data<T>(place);

    int w = ins[0]->numel() / x0_dims[0];
    PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
                      "The output dims[1] should be divisible by w.");
    jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
    if (pooltype == "AVERAGE") {
      attr.type = jit::SeqPoolType::kAvg;
    } else if (pooltype == "SQRT") {
      attr.type = jit::SeqPoolType::kSqrt;
    }
    auto seqpool =
        jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
            attr);
    size_t n = ins.size();
    for (size_t i = 0; i < n; ++i) {
      auto x_dims = ins[i]->dims();
      auto x_lod = ins[i]->lod()[0];
      const T* src = ins[i]->data<T>();
      // Input i owns the column block [i * w, (i + 1) * w) of the output.
      T* dst = y_data + i * w;
      PADDLE_ENFORCE_EQ(static_cast<int>(ins[i]->numel() / x_dims[0]), w,
                        "Width of all inputs should be equal.");
      PADDLE_ENFORCE_EQ(x_lod.size(), bs + 1,
                        "Batch size of all inputs should be equal.");
      for (size_t j = 0; j < bs; ++j) {
        attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
        seqpool(src, dst, &attr);
        dst += n * w;
        src += attr.h * attr.w;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqpool_concat, ops::FusionSeqPoolConcatOp,
                  ops::FusionSeqPoolConcatOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OP_CPU_KERNEL(fusion_seqpool_concat,
                       ops::FusionSeqPoolConcatKernel<float>,
                       ops::FusionSeqPoolConcatKernel<double>);
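A note on the kernel's memory layout: input i writes its pooled rows directly into the output's column block [i*w, (i+1)*w), and the destination pointer advances by n*w after every sequence, so the concat happens implicitly through strided writes instead of materializing per-input intermediates. A rough NumPy sketch of that write pattern (illustrative shapes only, not the actual Paddle kernel):

import numpy as np


def fused_seqpool_concat_sum(xs, lods, w):
    # Pooled row of input i, sequence j is written straight to
    # out[j, i*w:(i+1)*w]: no per-input intermediate tensor, no separate concat.
    n, bs = len(xs), len(lods[0])
    out = np.zeros((bs, n * w), dtype=xs[0].dtype)
    for i, (x, lod) in enumerate(zip(xs, lods)):
        offsets = np.cumsum([0] + lod)
        for j in range(bs):
            out[j, i * w:(i + 1) * w] = x[offsets[j]:offsets[j + 1]].sum(axis=0)
    return out


w = 4
lods = [[2, 3, 5], [1, 5, 2]]
xs = [np.random.rand(sum(lod), w).astype('float32') for lod in lods]
print(fused_seqpool_concat_sum(xs, lods, w).shape)  # (3, 8)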
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

class FusionSeqPoolConcatOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusionSeqPoolConcatOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,118 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
from test_reorder_lod_tensor import convert_to_offset
from test_seq_pool import compute_seqpool_sum, compute_seqpool_avg, compute_seqpool_sqrt


class TestFusionSeqPoolConcatOp(OpTest):
    def setUp(self):
        self.w = 11
        self.lods = [[[2, 3, 5]], [[1, 5, 2]]]
        self.set_conf()
        self.set_pooltype()
        self.op_type = 'fusion_seqpool_concat'
        self.axis = 1
        bs = len(self.lods[0][0])
        inputs = []
        outs = []
        i = 0
        for lod in self.lods:
            assert bs == len(lod[0]), 'All lod sizes should be equal'
            x = np.random.uniform(0.1, 1,
                                  [sum(lod[0]), self.w]).astype('float32')
            offset = convert_to_offset(lod)
            out = np.zeros((bs, self.w)).astype('float32')
            if self.pooltype == "SUM":
                compute_seqpool_sum(x, offset, out)
            elif self.pooltype == "AVERAGE":
                compute_seqpool_avg(x, offset, out)
            elif self.pooltype == "SQRT":
                compute_seqpool_sqrt(x, offset, out)
            else:
                raise Exception("Unsupported pool type!")
            inputs.append(('x_{0}'.format(i), (x, lod)))
            outs.append(out)
            i = i + 1

        self.inputs = {'X': inputs}
        self.outputs = {'Out': np.concatenate(outs, axis=self.axis)}
        self.attrs = {
            'pooltype': self.pooltype,
            'axis': self.axis,
        }

    def set_pooltype(self):
        self.pooltype = "SUM"

    def set_conf(self):
        pass

    def test_check_output(self):
        self.check_output()


class TestFusionSeqPoolConcatOpCase1(TestFusionSeqPoolConcatOp):
    def set_conf(self):
        self.lods = [[[1]]]


class TestFusionSeqPoolConcatOpCase2(TestFusionSeqPoolConcatOp):
    def set_conf(self):
        self.lods = [[[1]], [[1]], [[1]]]


class TestFusionSeqPoolConcatOpCase3(TestFusionSeqPoolConcatOp):
    def set_conf(self):
        self.lods = [[[1, 3, 4, 6]]]
        self.w = 10


class TestFusionSeqPoolConcatOpCase4(TestFusionSeqPoolConcatOp):
    def set_conf(self):
        self.lods = [[[2, 13, 4]], [[1, 1, 1]], [[5, 3, 1]], [[9, 10, 3]]]
        self.w = 3


# test avg pool and sqrt
def create_test_avg_sqrt_class(parent):
    class TestSeqPoolAvgCase(parent):
        def set_pooltype(self):
            self.pooltype = "AVERAGE"

    class TestSeqPoolSqrtCase(parent):
        def set_pooltype(self):
            self.pooltype = "SQRT"

    cls_name_avg = "{0}_{1}".format(parent.__name__, "avg")
    cls_name_sqrt = "{0}_{1}".format(parent.__name__, "sqrt")
    TestSeqPoolAvgCase.__name__ = cls_name_avg
    TestSeqPoolSqrtCase.__name__ = cls_name_sqrt
    globals()[cls_name_avg] = TestSeqPoolAvgCase
    globals()[cls_name_sqrt] = TestSeqPoolSqrtCase


create_test_avg_sqrt_class(TestFusionSeqPoolConcatOp)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase1)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase2)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase3)
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase4)

if __name__ == '__main__':
    unittest.main()
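The helpers imported at the top of this test (convert_to_offset, compute_seqpool_sum/avg/sqrt) come from other test files in the repository; below is a rough sketch of what they are assumed to do, inferred only from how they are called here:

import numpy as np


def convert_to_offset(lod):
    # Turn LoD lengths such as [[2, 3, 5]] into offsets [[0, 2, 5, 10]].
    offset = [[0] for _ in lod]
    for i, level in enumerate(lod):
        for seq_len in level:
            offset[i].append(offset[i][-1] + seq_len)
    return offset


def compute_seqpool_sum(x, offset, out):
    # Sum-pool every [start, end) row segment of x into one output row.
    for i in range(len(offset[0]) - 1):
        out[i] = x[offset[0][i]:offset[0][i + 1]].sum(axis=0)


def compute_seqpool_avg(x, offset, out):
    for i in range(len(offset[0]) - 1):
        out[i] = x[offset[0][i]:offset[0][i + 1]].mean(axis=0)


def compute_seqpool_sqrt(x, offset, out):
    # Sum-pool scaled by 1 / sqrt(sequence length).
    for i in range(len(offset[0]) - 1):
        seg = x[offset[0][i]:offset[0][i + 1]]
        out[i] = seg.sum(axis=0) / np.sqrt(seg.shape[0])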