parent
6665c49299
commit
bc45335e55
@ -0,0 +1,110 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/math/maxouting.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
// All tensors are in NCHW format
|
||||
template <typename T>
|
||||
class Unpool2d_Max_Functor<platform::CPUPlace, T> {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
const framework::Tensor& indices,
|
||||
framework::Tensor * output) {
|
||||
const int batch_size = input.dims()[0];
|
||||
const int input_height = input.dims()[2];
|
||||
const int input_width = input.dims()[3];
|
||||
const int output_channels = output->dims()[1];
|
||||
const int output_height = output->dims()[2];
|
||||
const int output_width = output->dims()[3];
|
||||
|
||||
int input_feasize = input_height * input_width;
|
||||
int output_feasize = output_height * output_width;
|
||||
const T* input_data = input.data<T>();
|
||||
const T* indices_data = indices.data<T>();
|
||||
T* output_data = output->mutable_data<T>(context.GetPlace());
|
||||
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
for (int c = 0; c < output_channels; ++c) {
|
||||
for (int i = 0; i < input_feasize; ++i) {
|
||||
int index = indices_data[i];
|
||||
if(index > output_feasize) {
|
||||
//抛一个异常!
|
||||
}
|
||||
output_data[index] = input_data[i];
|
||||
}
|
||||
input_data += input_feasize;
|
||||
indices_data += input_feasize;
|
||||
output_data += output_feasize;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <class T>
|
||||
class Unpool2d_MaxGradFunctor<platform::CPUPlace, T> {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
const framework::Tensor& indices,
|
||||
framework::Tensor * input_grad,
|
||||
const framework::Tensor& output,
|
||||
const framework::Tensor& output_grad) {
|
||||
const int batch_size = input.dims()[0];
|
||||
const int input_height = input.dims()[2];
|
||||
const int input_width = input.dims()[3];
|
||||
const int output_channels = output->dims()[1];
|
||||
const int output_height = output->dims()[2];
|
||||
const int output_width = output->dims()[3];
|
||||
|
||||
int input_feasize = input_height * input_width;
|
||||
int output_feasize = output_height * output_width;
|
||||
const T* input_data = input.data<T>();
|
||||
const T* indices_data = indices.data<T>();
|
||||
const T* output_data = output.data<T>();
|
||||
const T* output_grad_data = output_grad.data<T>();
|
||||
|
||||
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
|
||||
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
for (int c = 0; c < output_channels; ++c) {
|
||||
for (int f = 0; f < input_feasize; ++f) {
|
||||
int index = indices_data[i];
|
||||
if(index > output_feasize) {
|
||||
//抛一个异常!
|
||||
}
|
||||
input_grad_data[i] = output_grad_data[index];
|
||||
}
|
||||
input_grad_data += input_feasize;
|
||||
indices_data += input_feasize;
|
||||
output_grad_data += output_feasize;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class Unpool2d_MaxGradFunctor<platform::CPUPlace, float>;
|
||||
template class Unpool2d_MaxGradFunctor<platform::CPUPlace, double>;
|
||||
template class Unpool2d_MaxFunctor<platform::CPUPlace, float>;
|
||||
template class Unpool2d_MaxFunctor<platform::CPUPlace, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,143 @@
|
||||
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/math/maxouting.h"
|
||||
#include "paddle/platform/cuda_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T>
|
||||
__global__ void KernelUnpool2dMax(const int nthreads,
|
||||
const T* input_data,
|
||||
const T* indices_data,
|
||||
const int input_height,
|
||||
const int input_width,
|
||||
T* output_data,
|
||||
const int output_height,
|
||||
const int output_width) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int offset = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < nthreads; i += offset) {
|
||||
int out_offset = i / (input_height * input_width) \
|
||||
* output_height * output_width;
|
||||
int out_index = indices_data[i];
|
||||
output_data[out_offset + out_index] = input_data[i];
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
|
||||
const T* input_data,
|
||||
const int input_height,
|
||||
const int input_width,
|
||||
const T* output_data,
|
||||
const T* output_grad,
|
||||
const int output_height,
|
||||
const int output_width,
|
||||
T* input_grad) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int offset = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < nthreads; i += offset) {
|
||||
int out_offset = i / (input_height * input_width) \
|
||||
* output_height * output_width;
|
||||
int out_index = indices_data[i];
|
||||
input_grad[i] = output_grad[out_offset + out_index];
|
||||
}
|
||||
}
|
||||
/*
|
||||
* All tensors are in NCHW format.
|
||||
*/
|
||||
template <typename T>
|
||||
class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
const framework::Tensor& indices,
|
||||
framework::Tensor * output) {
|
||||
const int batch_size = input.dims()[0];
|
||||
const int input_height = input.dims()[2];
|
||||
const int input_width = input.dims()[3];
|
||||
const int output_channels = output->dims()[1];
|
||||
const int output_height = output->dims()[2];
|
||||
const int output_width = output->dims()[3];
|
||||
int input_feasize = input_height * input_width;
|
||||
int output_feasize = output_height * output_width;
|
||||
const T* input_data = input.data<T>();
|
||||
const T* indices_data = indices.data<T>();
|
||||
T* output_data = output->mutable_data<T>(context.GetPlace());
|
||||
|
||||
int nthreads = output->numel();
|
||||
int blocks = (nthreads + 1024 - 1) / 1024;
|
||||
dim3 threads(1024, 1);
|
||||
dim3 grid(blocks, 1);
|
||||
|
||||
KernelUnpool2dMax<
|
||||
T><<<grid, threads, 0,
|
||||
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
||||
.stream()>>>(nthreads, input_data, indices_data,
|
||||
input_height, input_width,
|
||||
output_data, output_height, output_width);
|
||||
}
|
||||
};
|
||||
/*
|
||||
* All tensors are in NCHW format.
|
||||
*/
|
||||
template <typename T>
|
||||
class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
framework::Tensor * input_grad,
|
||||
const framework::Tensor& output,
|
||||
const framework::Tensor& output_grad,
|
||||
int groups) {
|
||||
const int batch_size = input.dims()[0];
|
||||
const int input_height = input.dims()[2];
|
||||
const int input_width = input.dims()[3];
|
||||
const int output_channels = output.dims()[1];
|
||||
const int output_height = output.dims()[2];
|
||||
const int output_width = output.dims()[3];
|
||||
|
||||
const T* input_data = input.data<T>();
|
||||
const T* indices_data = indices.data<T>();
|
||||
const T* output_data = output.data<T>();
|
||||
const T* output_grad_data = output_grad.data<T>();
|
||||
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
|
||||
int nthreads = output.numel();
|
||||
int blocks = (nthreads + 1024 - 1) / 1024;
|
||||
dim3 threads(1024, 1);
|
||||
dim3 grid(blocks, 1);
|
||||
|
||||
KernelUnpool2dMaxGrad<
|
||||
T><<<grid, threads, 0,
|
||||
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
||||
.stream()>>>(
|
||||
nthreads, input_data, indices_data,
|
||||
input_height, input_width,
|
||||
output_data, output_grad_data,
|
||||
output_height, output_width,
|
||||
input_grad_data);
|
||||
}
|
||||
};
|
||||
|
||||
template class Unpool2d_MaxGradFunctor<platform::GPUPlace, float>;
|
||||
template class Unpool2d_MaxGradFunctor<platform::GPUPlace, double>;
|
||||
|
||||
template class Unpool2d_MaxFunctor<platform::GPUPlace, float>;
|
||||
template class Unpool2d_MaxFunctor<platform::GPUPlace, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,48 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/platform/device_context.h"
|
||||
#include "paddle/platform/hostdevice.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
#define FLT_MAX \
|
||||
__FLT_MAX__
|
||||
|
||||
template <typename Place, typename T>
|
||||
|
||||
class Unpool2d_Max_Functor {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
const framework::Tensor& indices,
|
||||
framework::Tensor * output);
|
||||
};
|
||||
|
||||
template <typename Place, class T>
|
||||
class Unpool2d_Max_GradFunctor {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
framework::Tensor * input_grad,
|
||||
const framework::Tensor& output,
|
||||
const framework::Tensor& output_grad);
|
||||
};
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,116 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/operators/unpool_op.h"
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
|
||||
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
UnpoolOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X",
|
||||
"(Tensor) The input tensor of unpool operator. "
|
||||
"The format of input tensor is NCHW. Where N is batch size, C is the "
|
||||
"number of channels, H and W is the height and width of feature.");
|
||||
AddInput("Y",
|
||||
"(Tensor) The input tensor of the indices given out by MaxPool2d. "
|
||||
"The format of input tensor is NCHW. Where N is batch size, C is the "
|
||||
"number of channels, H and W is the height and width of feature.");
|
||||
AddOutput("Out",
|
||||
"(Tensor) The output tensor of unpool operator."
|
||||
"The format of output tensor is also NCHW."
|
||||
"Where N is batch size, C is "
|
||||
"the number of channels, H and W is the height and "
|
||||
"width of feature.");
|
||||
AddAttr<std::vector<int>>("ksize",
|
||||
"(vector ), the unpooling window size(height, width) "
|
||||
"of unpooling operator.");
|
||||
AddAttr<std::vector<int>>("strides", "(vector, default:{1, 1}), "
|
||||
"strides(height, width) of unpooling operator.")
|
||||
.SetDefault({1, 1});
|
||||
AddAttr<std::vector<int>>("paddings", "(vector defalut:{0,0}), "
|
||||
"paddings(height, width) of unpooling operator.")
|
||||
.SetDefault({0, 0});
|
||||
AddAttr<std::string>("unpoolingType",
|
||||
"(string), unpooling type, can be \"max\" for max-unpooling "
|
||||
"and \"avg\" for average-unpooling.")
|
||||
.InEnum({"max", "avg"});
|
||||
AddComment(R"DOC(
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
int OutputSize(int input_size, int ksize, int padding, int stride) {
|
||||
int output_size = (input_size -1) * stride - 2 * padding + ksize;
|
||||
return output_size;
|
||||
}
|
||||
|
||||
class UnpoolOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp"
|
||||
"should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of UnpoolOp"
|
||||
"should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of UnpoolOp should not be null.");
|
||||
|
||||
auto in_x_dims = ctx->GetInputDim("X");
|
||||
auto in_y_dims = ctx->GetInputDim("Y");
|
||||
std::string unpooling_type = ctx->Attrs().Get<std::string>("unpooling_type");
|
||||
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
|
||||
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
|
||||
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
|
||||
|
||||
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
|
||||
"Unpooling intput should be 4-D or 5-D tensor.");
|
||||
|
||||
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
|
||||
for (size_t i = 0; i < ksize.size(); ++i) {
|
||||
output_shape.push_back(
|
||||
OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
|
||||
}
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
|
||||
}
|
||||
};
|
||||
|
||||
class UnpoolOpGrad : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(X) must not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
||||
"Input(Out@GRAD) should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
||||
"Input(X@GRAD) should not be null.");
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool2d_grad,
|
||||
ops::UnpoolOpGrad);
|
||||
REGISTER_OP_CPU_KERNEL(unpool2d, ops::UnpoolKernel<paddle::platform::CPUPlace,
|
||||
float>);
|
||||
REGISTER_OP_CPU_KERNEL(unpool2d_grad,
|
||||
ops::UnpoolGradKernel<paddle::platform::CPUPlace,
|
||||
float>);
|
@ -0,0 +1,22 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/unpool_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(unpool2d,
|
||||
ops::UnpoolKernel<paddle::platform::GPUPlace, float>);
|
||||
REGISTER_OP_GPU_KERNEL(unpool2d_grad,
|
||||
ops::UnpoolGradKernel<paddle::platform::GPUPlace,
|
||||
float>);
|
@ -0,0 +1,85 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
#include "paddle/operators/math/unpooling.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename Place, typename T>
|
||||
class UnpoolKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
const Tensor* in_x = context.Input<Tensor>("X");
|
||||
const Tensor* in_y = context.Input<Tensor>("Y");
|
||||
Tensor* out = context.Output<Tensor>("Out");
|
||||
std::string pooling_type = context.Attr<std::string>("unpooling_type");
|
||||
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
|
||||
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
||||
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
||||
switch (ksize.size()) {
|
||||
case 2: {
|
||||
if (pooling_type == "max") {
|
||||
math::Unpool2d_Max_Functor<Place, T> unpool2d_max_forward;
|
||||
unpool2d_max_forward(context.device_context(), *in_x, *in_y,
|
||||
ksize, strides, paddings, out);
|
||||
}
|
||||
} break;
|
||||
default: { PADDLE_THROW("Pool op only supports 2D input."); }
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Place, typename T>
|
||||
class UnpoolGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
const Tensor* in_x = context.Input<Tensor>("X");
|
||||
const Tensor* in_y = context.Input<Tensor>("Y");
|
||||
const Tensor* out = context.Input<Tensor>("Out");
|
||||
const Tensor* out_grad =
|
||||
context.Input<Tensor>(framework::GradVarName("Out"));
|
||||
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
|
||||
std::string pooling_type = context.Attr<std::string>("unpooling_type");
|
||||
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
|
||||
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
||||
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
||||
|
||||
auto& device_ctx = context.device_context();
|
||||
math::SetConstant<Place, T> zero;
|
||||
if (in_x_grad) {
|
||||
in_x_grad->mutable_data<T>(context.GetPlace());
|
||||
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
|
||||
}
|
||||
switch (ksize.size()) {
|
||||
case 2: {
|
||||
if (pooling_type == "max") {
|
||||
math::UnpoolGradFunctor<Place, T> maxout_backward;
|
||||
maxout_backward(context.device_context(), *in_x, *in_y, in_x_grad, *out,
|
||||
*out_grad, ksize, strides, paddings);
|
||||
}
|
||||
} break;
|
||||
default: { PADDLE_THROW("Pool op only supports 2D input."); }
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue