commit
fd3e32ea7d
@ -0,0 +1,168 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/sequence_expand_as_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::LoDTensor;
|
||||
|
||||
class SequenceExpandAsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of SequenceExpandAsOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Y"),
|
||||
"Input(Y) of SequenceExpandAsOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of SequenceExpandAsOp should not be null.");
|
||||
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto out_dims = x_dims;
|
||||
|
||||
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
||||
"Dimension number of Input(X) should be at least 2.");
|
||||
|
||||
if (ctx->IsRuntime()) {
|
||||
framework::Variable* x_var =
|
||||
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
|
||||
framework::Variable* y_var =
|
||||
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
|
||||
|
||||
auto& x_dim = x_var->Get<LoDTensor>().dims();
|
||||
auto& y_lod = y_var->Get<LoDTensor>().lod();
|
||||
|
||||
PADDLE_ENFORCE_EQ(y_lod.size(), 1,
|
||||
"Level number of Input(Y)'s lod should be 1.");
|
||||
|
||||
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dim[0]), y_lod[0].size() - 1,
|
||||
"The first dimension of Input(X) should be equal "
|
||||
"to the size of Input(Y)'s 0 level lod.");
|
||||
|
||||
int64_t out_first_dim = 0;
|
||||
if (y_lod[0].size() <= 1) {
|
||||
out_first_dim = x_dims[0];
|
||||
} else {
|
||||
for (size_t i = 1; i < y_lod[0].size(); ++i) {
|
||||
out_first_dim += (y_lod[0][i] - y_lod[0][i - 1]);
|
||||
}
|
||||
}
|
||||
out_dims[0] = out_first_dim;
|
||||
} else {
|
||||
out_dims[0] = -1;
|
||||
}
|
||||
|
||||
ctx->SetOutputDim("Out", out_dims);
|
||||
ctx->ShareLoD("Y", /*->*/ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
class SequenceExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X",
|
||||
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
|
||||
"level is at most 1.");
|
||||
AddInput("Y",
|
||||
"(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
|
||||
"lod (specified level) is referred by Input(X).");
|
||||
AddOutput("Out",
|
||||
"(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
|
||||
"generated from Input(X) by referring lod of Input(Y).");
|
||||
AddComment(R"DOC(
|
||||
Sequence Expand As Operator.
|
||||
|
||||
This operator expands `X` according to the zeroth level lod of `Y`. Current
|
||||
implementation requires the level number of Input(Y)'s lod should be 1, and
|
||||
the first dimension of Input(X) should be equal to the size of Input(Y)'s zeroth
|
||||
level lod, and lod of Input(X) is not considered.
|
||||
|
||||
Following are cases to better explain how this works:
|
||||
|
||||
Case 1:
|
||||
|
||||
Given a 1-level LoDTensor input(X)
|
||||
X.data = [[a], [b], [c], [d]]
|
||||
X.dims = [4, 1]
|
||||
and input(Y)
|
||||
Y.lod = [[0, 3, 6, 7, 8]]
|
||||
ref_level: 0
|
||||
then we get 1-level LoDTensor
|
||||
Out.lod = [[0, 3, 6, 7, 8]]
|
||||
Out.data = [[a], [a], [a], [b], [b], [b], [c], [d]]
|
||||
Out.dims = [8, 1]
|
||||
|
||||
Case 2:
|
||||
|
||||
Given a common Tensor input(X)
|
||||
X.data = [[a, b], [c, d], [e, f]]
|
||||
X.dims = [3, 2]
|
||||
and input(Y)
|
||||
Y.lod = [[0, 2, 3, 6]]
|
||||
ref_level: 0
|
||||
then we get a common LoDTensor
|
||||
Out.lod = [[0, 2, 3, 6]]
|
||||
Out.data = [[a, b], [a, b] [c, d], [e, f], [e, f], [e, f]]
|
||||
Out.dims = [6, 2]
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
||||
"Input(Out@GRAD) should not be null.");
|
||||
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto x_grad_name = framework::GradVarName("X");
|
||||
|
||||
if (ctx->HasOutput(x_grad_name)) {
|
||||
ctx->SetOutputDim(x_grad_name, x_dims);
|
||||
ctx->ShareLoD("X", x_grad_name);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(sequence_expand_as, ops::SequenceExpandAsOp,
|
||||
ops::SequenceExpandAsOpMaker,
|
||||
paddle::framework::DefaultGradOpDescMaker<true>);
|
||||
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
sequence_expand_as,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
sequence_expand_as_grad,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext,
|
||||
int64_t>);
|
@ -0,0 +1,134 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <algorithm>
|
||||
#include "paddle/fluid/operators/sequence_expand_as_op.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
template <typename T>
|
||||
static __global__ void sequence_expand_as_kernel(const T *in_data,
|
||||
const size_t *expand_offset,
|
||||
const size_t src_hight,
|
||||
const size_t src_widht,
|
||||
T *out_data) {
|
||||
for (int h_id = blockIdx.x; h_id < src_hight; h_id += gridDim.x) {
|
||||
int span = expand_offset[h_id + 1] - expand_offset[h_id];
|
||||
if (span == 0) continue;
|
||||
const T *src = in_data + h_id * src_widht;
|
||||
for (int w_id = threadIdx.x; w_id < src_widht; w_id += blockDim.x) {
|
||||
T ele = src[w_id];
|
||||
int offset = expand_offset[h_id] * src_widht;
|
||||
for (int k = 0; k < span; ++k) {
|
||||
out_data[offset + k * src_widht + w_id] = ele;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void sequence_expand_as_grad_kernel(
|
||||
const T *dout_data, const size_t *expand_offset, const size_t dst_hight,
|
||||
const size_t dst_width, T *dx_data) {
|
||||
for (int h_id = blockIdx.x; h_id < dst_hight; h_id += gridDim.x) {
|
||||
T *dst = dx_data + h_id * dst_width;
|
||||
int span = expand_offset[h_id + 1] - expand_offset[h_id];
|
||||
|
||||
for (int w_id = threadIdx.x; w_id < dst_width; w_id += blockDim.x) {
|
||||
T result = 0;
|
||||
for (int k = 0; k < span; ++k) {
|
||||
int offset = (expand_offset[h_id] + k) * dst_width;
|
||||
const T *src = dout_data + offset;
|
||||
result += src[w_id];
|
||||
}
|
||||
dst[w_id] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
|
||||
void operator()(
|
||||
const platform::CUDADeviceContext &context, const LoDTensor &x,
|
||||
const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
|
||||
LoDTensor *out) {
|
||||
int hight = x.dims()[0];
|
||||
int width = framework::product(x.dims()) / hight;
|
||||
|
||||
const int kThreadsPerBlock = 1024;
|
||||
int thread_x = kThreadsPerBlock;
|
||||
if (width < kThreadsPerBlock) { // block_cols is aligned by 32.
|
||||
thread_x = ((width + 31) >> 5) << 5;
|
||||
}
|
||||
|
||||
int max_threads = context.GetMaxPhysicalThreadCount();
|
||||
int block_x = std::max(max_threads / thread_x, 1);
|
||||
|
||||
dim3 block_size(thread_x);
|
||||
dim3 grid_size(block_x);
|
||||
sequence_expand_as_kernel<<<grid_size, block_size, 0, context.stream()>>>(
|
||||
x.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight, width,
|
||||
out->mutable_data<T>(context.GetPlace()));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SequenceExpandAsGradFunctor<platform::CUDADeviceContext, T> {
|
||||
void operator()(const platform::CUDADeviceContext &context,
|
||||
const LoDTensor &dout,
|
||||
const framework::Vector<size_t> &ref_lod, /*expand based lod*/
|
||||
LoDTensor *dx) {
|
||||
int hight = dx->dims()[0];
|
||||
int width = framework::product(dx->dims()) / hight;
|
||||
|
||||
const int kThreadsPerBlock = 1024;
|
||||
int thread_x = kThreadsPerBlock;
|
||||
if (width < kThreadsPerBlock) { // block_cols is aligned by 32.
|
||||
thread_x = ((width + 31) >> 5) << 5;
|
||||
}
|
||||
|
||||
int max_threads = context.GetMaxPhysicalThreadCount();
|
||||
int block_x = std::max(max_threads / thread_x, 1);
|
||||
|
||||
dim3 block_size(thread_x);
|
||||
dim3 grid_size(block_x);
|
||||
sequence_expand_as_grad_kernel<<<grid_size, block_size, 0,
|
||||
context.stream()>>>(
|
||||
dout.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight, width,
|
||||
dx->mutable_data<T>(context.GetPlace()));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
sequence_expand_as,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, double>,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, int>,
|
||||
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
sequence_expand_as_grad,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext,
|
||||
double>,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext, int>,
|
||||
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext,
|
||||
int64_t>);
|
@ -0,0 +1,148 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <numeric> // std::iota
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/detail/safe_ref.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
struct SequenceExpandFunctor {
|
||||
void operator()(
|
||||
const DeviceContext &ctx, const framework::LoDTensor &x,
|
||||
const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
|
||||
framework::LoDTensor *out);
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
struct SequenceExpandAsGradFunctor {
|
||||
void operator()(
|
||||
const DeviceContext &ctx, const framework::LoDTensor &dout,
|
||||
const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
|
||||
framework::LoDTensor *dx);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
|
||||
void operator()(
|
||||
const platform::CPUDeviceContext &context, const framework::LoDTensor &x,
|
||||
const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
|
||||
framework::LoDTensor *out) {
|
||||
int64_t hight = x.dims()[0];
|
||||
int64_t width = framework::product(x.dims()) / hight;
|
||||
|
||||
const T *in_data = x.data<T>();
|
||||
T *out_data = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
for (int h_id = 0; h_id < hight; ++h_id) {
|
||||
size_t span = ref_lod[h_id + 1] - ref_lod[h_id];
|
||||
if (span == 0) continue;
|
||||
const T *src = in_data + h_id * width;
|
||||
for (int64_t w_id = 0; w_id < width; ++w_id) {
|
||||
T ele = src[w_id];
|
||||
size_t offset = ref_lod[h_id] * width;
|
||||
for (size_t k = 0; k < span; ++k) {
|
||||
out_data[offset + k * width + w_id] = ele;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SequenceExpandAsKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &context) const override {
|
||||
auto *x = context.Input<framework::LoDTensor>("X");
|
||||
auto *y = context.Input<framework::LoDTensor>("Y");
|
||||
auto *out = context.Output<framework::LoDTensor>("Out");
|
||||
|
||||
auto &y_lod = y->lod();
|
||||
PADDLE_ENFORCE_EQ(y_lod.size(), 1, "LoD of Y should be 1.");
|
||||
PADDLE_ENFORCE_GT(y_lod[0].size(), 1, ".");
|
||||
|
||||
out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto &dev_ctx = context.template device_context<DeviceContext>();
|
||||
SequenceExpandFunctor<DeviceContext, T> seq_espand_functor;
|
||||
seq_espand_functor(dev_ctx, *x, y_lod[0], out);
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
*Given Grad(Out)
|
||||
*
|
||||
* Grad(Out).lod = [[0, 3, 6]]
|
||||
* Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
* Then
|
||||
* Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)]
|
||||
* = [0.6, 1.5]
|
||||
* Grad(X).lod = Input(X).lod
|
||||
*
|
||||
* */
|
||||
template <typename T>
|
||||
struct SequenceExpandAsGradFunctor<platform::CPUDeviceContext, T> {
|
||||
void operator()(
|
||||
const platform::CPUDeviceContext &context,
|
||||
const framework::LoDTensor &dout,
|
||||
const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
|
||||
framework::LoDTensor *dx) {
|
||||
int64_t hight = dx->dims()[0];
|
||||
int64_t width = framework::product(dx->dims()) / hight;
|
||||
|
||||
const T *dout_data = dout.data<T>();
|
||||
T *dx_data = dx->mutable_data<T>(context.GetPlace());
|
||||
|
||||
for (int64_t h_id = 0; h_id < hight; ++h_id) {
|
||||
T *dst = dx_data + h_id * width;
|
||||
size_t span = ref_lod[h_id + 1] - ref_lod[h_id];
|
||||
for (int64_t w_id = 0; w_id < width; ++w_id) {
|
||||
T result = 0;
|
||||
for (size_t k = 0; k < span; ++k) {
|
||||
size_t offset = (ref_lod[h_id] + k) * width;
|
||||
result += dout_data[offset + w_id];
|
||||
}
|
||||
dst[w_id] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SequenceExpandAsGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &context) const override {
|
||||
auto *g_out =
|
||||
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
|
||||
auto *y = context.Input<framework::LoDTensor>("Y");
|
||||
auto *g_x =
|
||||
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
|
||||
|
||||
g_x->mutable_data<T>(context.GetPlace());
|
||||
|
||||
SequenceExpandAsGradFunctor<DeviceContext, T> functor;
|
||||
functor(context.template device_context<DeviceContext>(), *g_out,
|
||||
y->lod()[0], g_x);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestSequenceExpandAs(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = 'sequence_expand_as'
|
||||
self.set_data()
|
||||
self.compute()
|
||||
|
||||
def set_data(self):
|
||||
x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
|
||||
y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
|
||||
y_lod = [[1, 3, 4]]
|
||||
self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
|
||||
|
||||
def compute(self):
|
||||
x = self.inputs['X']
|
||||
x_data, x_lod = x if type(x) == tuple else (x, None)
|
||||
y_data, y_lod = self.inputs['Y']
|
||||
|
||||
assert len(y_lod) == 1 and len(y_lod[0]) == x_data.shape[0]
|
||||
|
||||
repeats = []
|
||||
for i in range(len(y_lod[0])):
|
||||
repeat_num = y_lod[0][i]
|
||||
if repeat_num == 0:
|
||||
continue
|
||||
repeats.extend([i for _ in range(repeat_num)])
|
||||
|
||||
out_data = x_data[repeats]
|
||||
self.outputs = {'Out': (out_data, y_lod)}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(["X"], "Out")
|
||||
|
||||
|
||||
class TestSequenceExpandAsCase1(TestSequenceExpandAs):
|
||||
def set_data(self):
|
||||
x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
|
||||
x_lod = [[2, 3]]
|
||||
y_data = np.random.uniform(0.1, 1, [10, 1]).astype('float32')
|
||||
y_lod = [[2, 2, 0, 3, 3]]
|
||||
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
|
||||
|
||||
|
||||
class TestSequenceExpandAsCase2(TestSequenceExpandAs):
|
||||
def set_data(self):
|
||||
x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
|
||||
x_lod = [[1]]
|
||||
y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32')
|
||||
y_lod = [[2]]
|
||||
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue