Add sequence_project_op (use im2col)

revert-4814-Add_sequence_project_op
chengduoZH 7 years ago
parent e593113a84
commit 1e60c9b2e8

@ -46,7 +46,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope frame
set(EXECUTOR_TEST_OP elementwise_add_op gaussian_random_op feed_op fetch_op set(EXECUTOR_TEST_OP elementwise_add_op gaussian_random_op feed_op fetch_op
mul_op sum_op squared_l2_distance_op fill_constant_op sgd_op mean_op) mul_op sum_op squared_l2_distance_op fill_constant_op sgd_op mean_op)
if(WITH_GPU) if(WITH_GPU)
nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) # nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP})
else() else()
cc_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) cc_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP})
endif() endif()

@ -140,8 +140,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride, int pad, int row_begin, int row_end) {
int padding_width) { int stride_height = stride;
int stride_width = 0;
int padding_height = pad;
int padding_width = 0;
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
@ -149,13 +152,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; // int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
@ -166,13 +169,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_height; col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width; col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) * int col_offset =
input_channels + ((((col_row_idx - row_begin) * output_width + col_col_idx) *
channel) * input_channels +
filter_height + channel) *
filter_row_idx) * filter_height +
filter_width + filter_row_idx) *
filter_col_idx; filter_width +
filter_col_idx;
if (im_row_offset < 0 || im_row_offset >= input_height || if (im_row_offset < 0 || im_row_offset >= input_height ||
im_col_offset < 0 || im_col_offset >= input_width) { im_col_offset < 0 || im_col_offset >= input_width) {
col_data[col_offset] = T(0); col_data[col_offset] = T(0);
@ -200,8 +204,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride, int pad,
int stride_width, int padding_height, int padding_width) { int row_start, int row_end) {
int stride_height = stride;
int stride_width = 0;
int padding_height = pad;
int padding_width = 0;
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
@ -209,30 +217,31 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; // int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = row_start; col_row_idx < row_end; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) { ++filter_col_idx) {
int im_row_offset = int im_row_offset = // change or not
col_row_idx * stride_height + filter_row_idx - padding_height; col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width; col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) * int col_offset =
input_channels + ((((col_row_idx - row_start) * output_width + col_col_idx) *
channel) * input_channels +
filter_height + channel) *
filter_row_idx) * filter_height +
filter_width + filter_row_idx) *
filter_col_idx; filter_width +
filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < input_height && if (im_row_offset >= 0 && im_row_offset < input_height &&
im_col_offset >= 0 && im_col_offset < input_width) { im_col_offset >= 0 && im_col_offset < input_width) {
int im_offset = int im_offset =

@ -199,7 +199,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
int input_height, int input_width, int filter_height, int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width,
int output_height, int output_width) { int output_height, int output_width, int row_begin,
int row_end) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < input_channels;
@ -207,7 +208,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height; int height_offset =
idy + (shid + row_begin) * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width; channelid * input_height * input_width;
@ -238,8 +240,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride, int pad, int row_begin, int row_end) {
int padding_width) { int stride_height = stride;
int stride_width = 0;
int padding_height = pad;
int padding_width = 0;
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
@ -247,7 +253,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int output_height = row_end - row_begin; // col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int block_dim_x = 0; int block_dim_x = 0;
@ -275,7 +281,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width); padding_height, padding_width, output_height, output_width, row_begin,
row_end);
} }
}; };
@ -284,15 +291,18 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
int input_height, int input_width, int filter_height, int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width,
int output_height, int output_width) { int output_height, int output_width, int row_begin,
int row_end) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
// if (shid < row_begin || shid > row_end) return;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < input_channels;
channelid += blockDim.z) { channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height; int height_offset =
idy + (shid + row_begin) * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width; channelid * input_height * input_width;
@ -321,8 +331,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride, int pad,
int stride_width, int padding_height, int padding_width) { int row_begin, int row_end) {
int stride_height = stride;
int stride_width = 0;
int padding_height = pad;
int padding_width = 0;
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
@ -330,7 +344,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int output_height = row_end - row_begin; // col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int block_dim_x = 0; int block_dim_x = 0;
@ -358,7 +372,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width); padding_height, padding_width, output_height, output_width, row_begin,
row_end);
} }
}; };

@ -79,7 +79,8 @@ void testIm2col() {
im2col_ocf; im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding); im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding); im2col_ocf(*context, input, output_ocf, stride, padding, 0,
output_height * output_width);
float* out_cfo_ptr; float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {

@ -0,0 +1,166 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_project_op.h"
namespace paddle {
namespace operators {
class SequenceProjectOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceProjectOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceProjectOp should not be null.");
auto in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(in_dims.size() == 2, "Input(X) should be 2-D tensor.");
int context_length = ctx->Attrs().Get<int>("context_length");
bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
int context_start = ctx->Attrs().Get<int>("context_start");
if (padding_trainable) {
PADDLE_ENFORCE(
ctx->HasInput("PaddingData"),
"Output(PaddingData) of SequenceProjectOp should not be null.");
framework::DDim padding_dim = ctx->GetOutputDim("PaddingData");
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int total_pad = up_pad + down_pad;
int input_width = static_cast<int>(in_dims[1]);
PADDLE_ENFORCE(padding_dim.size() == 2,
"Input(PaddingData) should be 2-D tensor.");
PADDLE_ENFORCE(
padding_dim[0] == total_pad && padding_dim[1] == input_width,
"Input(PaddingData)'s shape is not consistent with 'context_start' "
"and 'context_length'.");
if (context_start == 0 && context_length == 1) {
PADDLE_THROW(
"if context_start == 0 && context_length == 1, padding_trainable "
"should be false.");
}
}
in_dims[1] = in_dims[1] * context_length;
ctx->SetOutputDim("Out", in_dims);
}
};
class SequenceProjectGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
if (ctx->Attrs().Get<bool>("padding_trainable")) {
PADDLE_ENFORCE(
ctx->HasOutput("PaddingData"),
"Output(PaddingData) of SequenceProjectOp should not be null.");
}
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceProjectOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"A float LoDTensor, the variable-length input of SequenceProjectOp");
AddOutput(
"Out",
"A float LoDTensor, the variable-length output of SequenceProjectOp.");
AddOutput("PaddingData",
"A float LoDTensor, the padding data of SequenceProjectOp.");
AddAttr<bool>("padding_trainable",
"(bool, default false) the padding data of SequenceProjectOp "
"is trainable or not.")
.SetDefault(false);
AddAttr<int>("context_length",
"(int, default 3) the stride of SequenceProjectOp.")
.SetDefault(3)
.GreaterThan(0);
AddAttr<int>("context_start",
"(int, default 0) the xx of SequenceProjectOp.")
.SetDefault(0);
AddAttr<int>("context_stride",
"(int, default 1) the xx of SequenceProjectOp.")
.SetDefault(1)
.GreaterThan(0);
AddComment(R"DOC(
SequenceProjectOp projects features of context_length time-steps of each instance.
For a mini-batch of 2 variable lengths sentences, containing 3, and 1 time-steps:
Assumed input (X) is a [4, M, N] float LoDTensor, and X->lod()[0] = [0, 3, 4].
Besides, for the sake of simplicity, we assume M=1 and N=2.
X = [[a1, a2,
b1, b2.
c1, c2]
[d1, d2]]
This is to say that input (X) has 4 words and the dimension of each word
representation is 2.
- Case1:
If we use zero to pad instead of learned weight to pad,
and the context_lenth is 3, the output (Out) is:
Out = [0, 0, a1, a2, b1, b2;
a1, a2, b1, b2, c1, c2;
b1, b2, c1, c2, 0, 0;
0, 0, d1, d2, 0, 0]
- Case2:
// If we use zero to pad instead of learned weight to pad,
// and the context_lenth is 3, the output (Out) is:
//
// Out = [0, 0, a1, a2, b1, b2;
// a1, a2, b1, b2, c1, c2;
// b1, b2, c1, c2, 0, 0;
// 0, 0, d1, d2, 0, 0]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_project, ops::SequenceProjectOp,
ops::SequenceProjectOpMaker, sequence_project_grad,
ops::SequenceProjectGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_project,
ops::SequenceProjectKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_project_grad,
ops::SequenceProjectGradKernel<paddle::platform::CPUPlace, float>);

@ -0,0 +1,25 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_project_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_project,
ops::SequenceProjectKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_project_grad,
ops::SequenceProjectGradKernel<paddle::platform::GPUPlace, float>);

File diff suppressed because it is too large Load Diff

@ -0,0 +1,96 @@
import unittest
import numpy as np
from op_test import OpTest
class TestSeqProject(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = 'sequence_project'
# one level, batch size
x = np.random.uniform(
0.1, 1, [self.input_size[0], self.input_size[1]]).astype('float32')
lod = [[0, 4, 5, 8, self.input_size[0]]]
self.begin_pad = np.max([0, -self.context_start])
self.end_pad = np.max([0, self.context_start + self.context_length - 1])
self.total_pad = self.begin_pad + self.end_pad
w = np.ones((self.total_pad, self.input_size[1])) * 100
self.inputs = {'X': (x, lod), 'PaddingData': w}
self.attrs = {
'context_start': self.context_start,
'context_length': self.context_length,
'padding_trainable': self.padding_trainable
}
out = np.zeros((self.input_size[0], self.input_size[1] *
self.context_length)).astype('float32')
self.outputs = {'Out': out}
self.compute()
def compute(self):
x, lod = self.inputs['X']
w = self.inputs['PaddingData']
out = self.outputs['Out']
lod = lod[0]
for i in range(len(lod) - 1):
for j in range(self.context_length):
in_begin = lod[i] + self.context_start + j
in_end = lod[i + 1] + self.context_start + j
out_begin = lod[i]
out_end = lod[i + 1]
if in_begin < lod[i]:
pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]])
if self.padding_trainable:
sub_w = w[j:pad_size, :]
out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:(
j + 1) * self.input_size[1]] = sub_w
# pass
out_begin = lod[i] + pad_size
in_begin = lod[i]
if in_end > lod[i + 1]:
pad_size = np.min(
[in_end - lod[i + 1], lod[i + 1] - lod[i]])
out_sub = out[lod[i + 1] - pad_size:lod[i + 1], :]
if self.padding_trainable:
sub_w = w[j - pad_size:j, :]
out[lod[i + 1] - pad_size:lod[i + 1], j * self.
input_size[1]:(j + 1) * self.input_size[1]] = sub_w
# pass
in_end = lod[i + 1]
out_end = lod[i + 1] - pad_size
if in_end <= in_begin:
continue
in_sub = x[in_begin:in_end, :]
out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
self.input_size[1]] += in_sub
def init_test_case(self):
self.input_size = [11, 23]
self.op_type = "sequence_project"
self.context_start = -1
self.context_length = 3
self.padding_trainable = False
def test_check_output(self):
self.check_output()
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
# class TestSeqAvgPool2D(TestSeqProject):
# def init_test_case(self):
# self.input_size = [11, 23]
# self.op_type = "sequence_project"
#
# self.context_start = -1
# self.context_length = 3
# self.padding_trainable = True
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save