commit b548ecbc2b

paddle/fluid/framework/array.h
@@ -0,0 +1,48 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace framework {

template <typename T, size_t N>
class Array {
  static_assert(N > 0, "The size of array must be larger than 0");

 public:
  HOSTDEVICE Array() {}

  HOSTDEVICE explicit Array(const T &val) {
    for (size_t i = 0; i < N; ++i) data_[i] = val;
  }

  HOSTDEVICE const T *Get() const { return data_; }

  HOSTDEVICE T *GetMutable() { return data_; }

  HOSTDEVICE T &operator[](size_t index) { return data_[index]; }

  HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }

  HOSTDEVICE constexpr size_t size() const { return N; }

 private:
  T data_[N];
};

}  // namespace framework
}  // namespace paddle
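A quick host-side usage sketch (illustrative only, not part of the patch; it assumes the header above is on the include path, where HOSTDEVICE expands to nothing for a plain host build). The point of this container is that it keeps all N elements inline with no heap-owning members, so a copy of it can travel by value into device code, which the CUDA stack kernels further down rely on.

// sketch_array_usage.cc -- illustrative, not part of the patch.
#include <cstdio>
#include "paddle/fluid/framework/array.h"

int main() {
  // Fill-construct four floats via the explicit Array(const T&) constructor.
  paddle::framework::Array<float, 4> arr(1.0f);
  arr[2] = 3.5f;
  for (size_t i = 0; i < arr.size(); ++i) {
    std::printf("arr[%zu] = %f\n", i, arr[i]);
  }
  // Unlike std::vector, Array keeps its storage inside the object itself, so
  // it can be passed by value as a __global__ kernel argument (see the .cu
  // file below).
  return 0;
}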

paddle/fluid/operators/stack_op.cc
@@ -0,0 +1,66 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/stack_op.h"

namespace paddle {
namespace operators {

struct CPUStackFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
                  T* y, int pre, int n, int post) const {
    int total_num = pre * post * n;
    for (int idx = 0; idx < total_num; ++idx) {
      int i = idx / (n * post);
      int which_x = idx / post - i * n;
      int x_index = i * post + idx % post;
      y[idx] = x[which_x][x_index];
    }
  }
};
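The functor above views the output as one contiguous (pre, n, post) block: pre is the product of the input dims before axis, n is the number of stacked inputs, and post is the product of the input dims from axis onward, so element (i, j) of input k lands at y[(i * n + k) * post + j]. A minimal standalone check of that index arithmetic (illustrative, not part of the patch):

// check_stack_index.cc -- standalone sanity check of the (pre, n, post) math.
#include <cassert>
#include <vector>

int main() {
  const int pre = 2, n = 3, post = 4;
  std::vector<std::vector<int>> x(n, std::vector<int>(pre * post));
  for (int k = 0; k < n; ++k)
    for (int e = 0; e < pre * post; ++e) x[k][e] = k * 100 + e;

  std::vector<int> y(pre * n * post);
  for (int idx = 0; idx < pre * n * post; ++idx) {
    int i = idx / (n * post);             // index into the "pre" dimension
    int which_x = idx / post - i * n;     // which input tensor
    int x_index = i * post + idx % post;  // flat offset inside that input
    y[idx] = x[which_x][x_index];
  }

  // x[k]'s element (i, j) must have landed at y[(i * n + k) * post + j].
  for (int i = 0; i < pre; ++i)
    for (int k = 0; k < n; ++k)
      for (int j = 0; j < post; ++j)
        assert(y[(i * n + k) * post + j] == x[k][i * post + j]);
  return 0;
}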

struct CPUStackGradFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, std::vector<T*>& dx,  // NOLINT
                  const T* dy, int pre, int n, int post) const {
    int total_num = pre * post * n;
    for (int idx = 0; idx < total_num; ++idx) {
      int i = idx / (n * post);
      int which_x = idx / post - i * n;
      int x_index = i * post + idx % post;
      dx[which_x][x_index] = dy[idx];
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
                  ops::StackGradOpDescMaker);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);

REGISTER_OP_CPU_KERNEL(
    stack,
    ops::StackKernel<plat::CPUDeviceContext, float, ops::CPUStackFunctor>,
    ops::StackKernel<plat::CPUDeviceContext, double, ops::CPUStackFunctor>);

REGISTER_OP_CPU_KERNEL(stack_grad,
                       ops::StackGradKernel<plat::CPUDeviceContext, float,
                                            ops::CPUStackGradFunctor>,
                       ops::StackGradKernel<plat::CPUDeviceContext, double,
                                            ops::CPUStackGradFunctor>);

paddle/fluid/operators/stack_op.cu
@@ -0,0 +1,109 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/operators/stack_op.h"

namespace paddle {
namespace operators {

template <typename T, typename VecXType>
__global__ void StackCUDAKernel(VecXType x, T* y, int total_num, int n,
                                int post) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < total_num) {
    int i = idx / (n * post);
    int which_x = idx / post - i * n;
    int x_index = i * post + idx % post;
    y[idx] = x[which_x][x_index];
  }
}

template <typename T, typename VecDxType>
__global__ void StackGradCUDAKernel(VecDxType dx, const T* dy, int total_num,
                                    int n, int post) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < total_num) {
    int i = idx / (n * post);
    int which_x = idx / post - i * n;
    int x_index = i * post + idx % post;
    dx[which_x][x_index] = dy[idx];
  }
}

struct GPUStackFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
                  T* y, int pre, int n, int post) const {
    int total_num = pre * post * n;
    int threads = 512;
    int grid = (total_num + threads - 1) / threads;

    constexpr auto kMaxThreshold = 16;
    if (n <= kMaxThreshold) {
      framework::Array<const T*, kMaxThreshold> arr;
      for (int i = 0; i < n; ++i) arr[i] = x[i];
      StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(arr, y, total_num, n,
                                                          post);
    } else {
      VLOG(10) << "Stack more than " << kMaxThreshold
               << " tensors may be slow on GPU.";
      thrust::device_vector<const T*> dev_x(x);
      StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(dev_x.data().get(), y,
                                                          total_num, n, post);
    }
  }
};

struct GPUStackGradFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, std::vector<T*>& dx,  // NOLINT
                  const T* dy, int pre, int n, int post) const {
    int total_num = pre * post * n;
    int threads = 512;
    int grid = (total_num + threads - 1) / threads;

    constexpr auto kMaxThreshold = 16;
    if (n <= kMaxThreshold) {
      framework::Array<T*, kMaxThreshold> arr;
      for (int i = 0; i < n; ++i) arr[i] = dx[i];
      StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
          arr, dy, total_num, n, post);
    } else {
      VLOG(10) << "Stack more than " << kMaxThreshold
               << " tensors may be slow on GPU.";
      thrust::device_vector<T*> dev_dx(dx);
      StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
          dev_dx.data().get(), dy, total_num, n, post);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    stack,
    ops::StackKernel<plat::CUDADeviceContext, float, ops::GPUStackFunctor>,
    ops::StackKernel<plat::CUDADeviceContext, double, ops::GPUStackFunctor>);

REGISTER_OP_CUDA_KERNEL(stack_grad,
                        ops::StackGradKernel<plat::CUDADeviceContext, float,
                                             ops::GPUStackGradFunctor>,
                        ops::StackGradKernel<plat::CUDADeviceContext, double,
                                             ops::GPUStackGradFunctor>);
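Both GPU functors above choose between two ways of shipping the pointer table to the kernel: for n <= kMaxThreshold (16) the pointers travel by value inside a framework::Array kernel argument, which needs no device allocation; for larger n they are staged through a thrust::device_vector, paying a device allocation plus a host-to-device copy, hence the VLOG warning. The launch shape itself is plain ceiling division; a standalone host-only illustration (not part of the patch):

// launch_math_sketch.cc -- ceiling-division launch math used by the functors.
#include <cstdio>

int main() {
  const int threads = 512;  // threads per block, as in the functors above
  const int sizes[] = {1, 512, 513, 2 * 3 * 4};
  for (int total_num : sizes) {
    // grid = ceil(total_num / threads), so every output element gets a thread
    // and the kernel's "if (idx < total_num)" guard trims the overhang.
    int grid = (total_num + threads - 1) / threads;
    std::printf("total_num=%4d -> grid=%d block(s)\n", total_num, grid);
  }
  return 0;
}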

paddle/fluid/operators/stack_op.h
@@ -0,0 +1,192 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

inline void GetPrePostForStackOp(const framework::DDim &dim, int axis, int *pre,
                                 int *post) {
  *pre = 1;
  for (auto i = 0; i < axis; ++i) (*pre) *= dim[i];
  *post = 1;
  for (auto i = axis; i < dim.size(); ++i) (*post) *= dim[i];
}

class StackOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
                      "Number of Inputs(X) must be larger than 0");
    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");

    auto input_dims = ctx->GetInputsDim("X");
    for (size_t i = 1; i < input_dims.size(); ++i) {
      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
                        "Dims of all Inputs(X) must be the same");
    }

    // Only lod of X[0] would be shared with Y
    ctx->ShareLoD("X", /*->*/ "Y");

    int axis = ctx->Attrs().Get<int>("axis");
    int rank = input_dims[0].size();
    PADDLE_ENFORCE(
        axis >= -(rank + 1) && axis < rank + 1,
        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
    if (axis < 0) axis += (rank + 1);

    auto vec = framework::vectorize2int(input_dims[0]);
    vec.insert(vec.begin() + axis, input_dims.size());
    ctx->SetOutputDim("Y", framework::make_ddim(vec));
  }
};

class StackOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input of stack op.").AsDuplicable();
    AddOutput("Y", "The output of stack op.");
    AddAttr<int>("axis",
                 "The axis along which all of the Inputs(X) should be stacked.")
        .SetDefault(0);
    AddComment(R"DOC(
Stack Operator.

Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
)DOC");
  }
};

template <typename DeviceContext, typename T, typename Functor>
class StackKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto x = ctx.MultiInput<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Y");

    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis += (x[0]->dims().size() + 1);

    int n = static_cast<int>(x.size());
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    std::vector<const T *> x_datas(n);
    for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();

    int pre = 1, post = 1;
    auto &dim = x[0]->dims();
    for (auto i = 0; i < axis; ++i) pre *= dim[i];
    for (auto i = axis; i < dim.size(); ++i) post *= dim[i];

    Functor functor;
    functor(ctx.template device_context<DeviceContext>(), x_datas, y_data, pre,
            n, post);
  }
};

class StackOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@Grad) must exist.");

    int axis = ctx->Attrs().Get<int>("axis");
    auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
    int rank = dy_dim.size();
    PADDLE_ENFORCE(axis >= -rank && axis < rank,
                   "Attr(axis) must be inside [-rank, rank), where rank = %d",
                   rank);
    if (axis < 0) axis += rank;

    PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(),
                      static_cast<size_t>(dy_dim[axis]),
                      "Number of Outputs(X@Grad) is wrong");
    auto vec = framework::vectorize2int(dy_dim);
    vec.erase(vec.begin() + axis);
    ctx->SetOutputsDim(
        framework::GradVarName("X"),
        std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
  }
};

class StackGradOpDescMaker
    : public framework::
          SingleGradOpDescMaker /*framework::GradOpDescMakerBase*/ {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
  /*
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;

  std::vector<std::unique_ptr<framework::OpDesc>> operator ()() const override {
    auto x_grads = InputGrad("X", false);
    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Y");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::string& x_grad) {
                     auto* grad_op = new framework::OpDesc();
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Y", {x_grad});
                     grad_op->SetAttrMap(Attrs());
                     return std::unique_ptr<framework::OpDesc>(grad_op);
                   });
    return grad_ops;
  }
  */

  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("stack_grad");
    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
    op->SetAttrMap(Attrs());
    return op;
  }
};
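Apply() emits a single stack_grad op: its input is the gradient of the forward output Y, its outputs are the gradients of every forward input X (obtained via InputGrad("X", false)), and the axis attribute is copied over. The commented-out GradOpDescMakerBase variant kept above is the equivalent one-op-per-input formulation that this single-op version replaces.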

template <typename DeviceContext, typename T, typename GradFunctor>
class StackGradKernel : public framework::OpKernel<T> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis += dy->dims().size();

    int n = dy->dims()[axis];
    std::vector<T *> dx_datas(n);  // NOLINT
    for (int i = 0; i < n; i++)
      dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
    auto dy_data = dy->data<T>();

    int pre = 1;
    for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
    int post = dy->numel() / (n * pre);
    GradFunctor functor;
    functor(ctx.template device_context<DeviceContext>(), dx_datas, dy_data,
            pre, n, post);
  }
};

}  // namespace operators
}  // namespace paddle
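For reference, the shape rule implemented by StackOp::InferShape and inverted by StackOpGrad::InferShape: stacking n tensors inserts a dimension of size n at the normalized axis, and the gradient erases it again. A hypothetical standalone mirror of the forward rule (illustrative, not part of the patch):

// stack_shape_sketch.cc -- illustrative mirror of StackOp::InferShape's rule.
#include <cstdio>
#include <vector>

std::vector<int> StackInferShape(std::vector<int> in, int n, int axis) {
  int rank = static_cast<int>(in.size());
  if (axis < 0) axis += rank + 1;   // same normalization as the op
  in.insert(in.begin() + axis, n);  // output gains one dim of size n
  return in;
}

int main() {
  // Four (5, 6, 7) inputs stacked at axis -4 (i.e. 0) -> (4, 5, 6, 7).
  for (int d : StackInferShape({5, 6, 7}, 4, -4)) std::printf("%d ", d);
  std::printf("\n");
  return 0;
}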

python/paddle/fluid/tests/unittests/test_stack_op.py
@@ -0,0 +1,92 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from op_test import OpTest
import numpy as np
import unittest


class TestStackOpBase(OpTest):
    def initDefaultParameters(self):
        self.num_inputs = 4
        self.input_dim = (5, 6, 7)
        self.axis = 0
        self.dtype = 'float32'

    def initParameters(self):
        pass

    def get_x_names(self):
        x_names = []
        for i in range(self.num_inputs):
            x_names.append('x{}'.format(i))
        return x_names

    def setUp(self):
        self.initDefaultParameters()
        self.initParameters()
        self.op_type = 'stack'
        self.x = []
        for i in range(self.num_inputs):
            self.x.append(
                np.random.random(size=self.input_dim).astype(self.dtype))

        tmp = []
        x_names = self.get_x_names()
        for i in range(self.num_inputs):
            tmp.append((x_names[i], self.x[i]))

        self.inputs = {'X': tmp}
        self.outputs = {'Y': np.stack(self.x, axis=self.axis)}
        self.attrs = {'axis': self.axis}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(self.get_x_names(), 'Y')


class TestStackOp1(TestStackOpBase):
    def initParameters(self):
        self.num_inputs = 16


class TestStackOp2(TestStackOpBase):
    def initParameters(self):
        self.num_inputs = 20


class TestStackOp3(TestStackOpBase):
    def initParameters(self):
        self.axis = -1


class TestStackOp4(TestStackOpBase):
    def initParameters(self):
        self.axis = -4


class TestStackOp5(TestStackOpBase):
    def initParameters(self):
        self.axis = 1


class TestStackOp6(TestStackOpBase):
    def initParameters(self):
        self.axis = 3


if __name__ == '__main__':
    unittest.main()
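Note that TestStackOp1 (num_inputs = 16) and TestStackOp2 (num_inputs = 20) bracket the GPU kMaxThreshold = 16 boundary, so on CUDA devices the suite exercises both the by-value framework::Array path and the thrust::device_vector fallback; the remaining cases sweep positive and negative axis values.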