parent
d8a21ef6f3
commit
9676ac1c5c
@ -0,0 +1,149 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/flip_op.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using framework::OpKernelType;
|
||||||
|
using framework::Tensor;
|
||||||
|
|
||||||
|
class FlipOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
ctx->HasInput("X"), true,
|
||||||
|
platform::errors::NotFound("Input(X) of FlipOp should not be null."));
|
||||||
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
||||||
|
platform::errors::NotFound(
|
||||||
|
"Output(Out) of FlipOp should not be null."));
|
||||||
|
auto x_dims = ctx->GetInputDim("X");
|
||||||
|
auto flip_dims = ctx->Attrs().Get<std::vector<int>>("dims");
|
||||||
|
size_t flip_dims_size = flip_dims.size();
|
||||||
|
|
||||||
|
// check if dims axis within range
|
||||||
|
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
|
||||||
|
PADDLE_ENFORCE_LT(*min_max_d.first, x_dims.size(),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"min(dims) should be less than the input tensor X's "
|
||||||
|
"dimensions of FlipOp. But received min(dims) = %d, "
|
||||||
|
"X's dimensions = %d, X's shape = [%s]",
|
||||||
|
*min_max_d.first, x_dims.size(), x_dims));
|
||||||
|
PADDLE_ENFORCE_GE(
|
||||||
|
*min_max_d.first, x_dims.size() * -1,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"min(dims) should be greater than or equal to the input tensor X's "
|
||||||
|
"dimensions of FlipOp times -1. But received min(dims) = %d, X's "
|
||||||
|
"dimensions = %d, X's shape = [%s]",
|
||||||
|
*min_max_d.first, x_dims.size() * -1, x_dims));
|
||||||
|
PADDLE_ENFORCE_LT(*min_max_d.second, x_dims.size(),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"max(dims) should be less than the input tensor X's "
|
||||||
|
"dimensions of FlipOp. But received max(dims) = %d, "
|
||||||
|
"X's dimensions = %d, X's shape = [%s]",
|
||||||
|
*min_max_d.second, x_dims.size(), x_dims));
|
||||||
|
PADDLE_ENFORCE_GE(
|
||||||
|
*min_max_d.second, x_dims.size() * -1,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"max(dims) should be greater than or equal to the input tensor X's "
|
||||||
|
"dimensions of FlipOp times -1. But received max(dims) = %d, X's "
|
||||||
|
"dimensions = %d, X's shape = [%s]",
|
||||||
|
*min_max_d.second, x_dims.size() * -1, x_dims));
|
||||||
|
|
||||||
|
// check duplicates in dims
|
||||||
|
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
|
||||||
|
flip_dims.end());
|
||||||
|
PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"dims has duplicates, original flip dims size=%d, "
|
||||||
|
"but unique flip dims size=%d.)",
|
||||||
|
flip_dims_size, flip_dims.size()));
|
||||||
|
|
||||||
|
VLOG(3) << "flip operator x.shape=" << x_dims;
|
||||||
|
|
||||||
|
std::vector<int64_t> output_dims(x_dims.size());
|
||||||
|
for (int i = 0; i < x_dims.size(); ++i) {
|
||||||
|
output_dims[i] = x_dims[i];
|
||||||
|
}
|
||||||
|
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
|
||||||
|
ctx->ShareLoD("X", "Out");
|
||||||
|
}
|
||||||
|
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const {
|
||||||
|
framework::LibraryType library = framework::LibraryType::kPlain;
|
||||||
|
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
|
||||||
|
int customized_type_value =
|
||||||
|
framework::OpKernelType::kDefaultCustomizedTypeValue;
|
||||||
|
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
||||||
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
|
||||||
|
library, customized_type_value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
void Make() override {
|
||||||
|
AddInput("X", "(Tensor), The input tensor of flip op.");
|
||||||
|
AddOutput("Out", "(Tensor), The output tensor of flip op.");
|
||||||
|
AddAttr<std::vector<int>>("dims", "The axes to flip on.");
|
||||||
|
AddComment(R"DOC(
|
||||||
|
Flip Operator.
|
||||||
|
Reverse the order of a n-D tensor along given axis in dims.
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
|
||||||
|
protected:
|
||||||
|
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
|
||||||
|
const override {
|
||||||
|
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
|
||||||
|
public:
|
||||||
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void Apply(GradOpPtr<T> retv) const override {
|
||||||
|
retv->SetType("flip");
|
||||||
|
retv->SetInput("X", this->OutputGrad("Out"));
|
||||||
|
retv->SetOutput("Out", this->InputGrad("X"));
|
||||||
|
retv->SetAttrMap(this->Attrs());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
|
||||||
|
ops::FlipOpGradMaker<paddle::framework::OpDesc>,
|
||||||
|
ops::FlipOpGradMaker<paddle::imperative::OpBase>);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
flip, ops::FlipKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
|
||||||
|
ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
|
||||||
|
ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
|
||||||
|
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>);
|
||||||
@ -0,0 +1,166 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/flip_op.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/memory/malloc.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void kernel_pointwise_flip_apply(const int N, const T* in_data,
|
||||||
|
T* out_data, int dim0, int stride0,
|
||||||
|
int dim1, int flip_dim) {
|
||||||
|
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
|
||||||
|
idx += gridDim.x * blockDim.x) {
|
||||||
|
int dst_offset = 0;
|
||||||
|
if (flip_dim == 0) {
|
||||||
|
// flip 1st dim
|
||||||
|
dst_offset = (dim0 - 1 - idx / stride0) * stride0 + idx % stride0;
|
||||||
|
} else {
|
||||||
|
// flip last dim
|
||||||
|
dst_offset = idx / stride0 * stride0 + (dim1 - 1 - idx % stride0);
|
||||||
|
}
|
||||||
|
out_data[dst_offset] = in_data[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void flip_cuda_kernel(const int N, const T* in_data, T* out_data,
|
||||||
|
int64_t* x_shape, int64_t* x_stride,
|
||||||
|
int* flip_dims, int flip_dims_size,
|
||||||
|
int total_dims) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx >= N) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int cur_indices = idx, rem = 0, dst_offset = 0;
|
||||||
|
for (int i = 0; i < total_dims; ++i) {
|
||||||
|
int64_t temp = cur_indices;
|
||||||
|
cur_indices = cur_indices / x_stride[i];
|
||||||
|
rem = temp - cur_indices * x_stride[i];
|
||||||
|
// flip the indices if it is in flip_dims
|
||||||
|
for (int j = 0; j < flip_dims_size; ++j) {
|
||||||
|
if (i == flip_dims[j]) {
|
||||||
|
cur_indices = x_shape[i] - 1 - cur_indices;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst_offset += cur_indices * x_stride[i];
|
||||||
|
cur_indices = rem;
|
||||||
|
}
|
||||||
|
out_data[idx] = in_data[dst_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class FlipKernel<platform::CUDADeviceContext, T>
|
||||||
|
: public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
|
||||||
|
auto cplace = platform::CPUPlace();
|
||||||
|
auto& dev_ctx = ctx.template device_context<CUDADeviceContext>();
|
||||||
|
|
||||||
|
const Tensor* x = ctx.Input<Tensor>("X");
|
||||||
|
Tensor* out = ctx.Output<Tensor>("Out");
|
||||||
|
auto* in_data = x->data<T>();
|
||||||
|
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
|
||||||
|
|
||||||
|
const int flip_dims_size = static_cast<int>(flip_dims.size());
|
||||||
|
auto x_dims = x->dims();
|
||||||
|
const int total_dims = x_dims.size();
|
||||||
|
const int N = x->numel();
|
||||||
|
|
||||||
|
int block_size = 512;
|
||||||
|
dim3 dim_block(block_size);
|
||||||
|
dim3 dim_grid((N + block_size - 1) / block_size);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < flip_dims.size(); ++i) {
|
||||||
|
if (flip_dims[i] < 0) {
|
||||||
|
flip_dims[i] += total_dims;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto x_stride = framework::stride(x_dims);
|
||||||
|
std::vector<int64_t> x_dims_v = framework::vectorize(x_dims);
|
||||||
|
std::vector<int64_t> x_stride_v = framework::vectorize(x_stride);
|
||||||
|
|
||||||
|
// wrap high-dims to 2-dims
|
||||||
|
if (flip_dims_size == 1 &&
|
||||||
|
(flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
|
||||||
|
int dim0 = 1, dim1 = 1;
|
||||||
|
int stride0 = 1;
|
||||||
|
if (flip_dims[0] == 0) {
|
||||||
|
dim0 = x_dims_v[0];
|
||||||
|
stride0 = x_stride_v[0];
|
||||||
|
for (size_t i = 1; i < total_dims; ++i) {
|
||||||
|
dim1 *= x_dims_v[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dim1 = x_dims_v[total_dims - 1];
|
||||||
|
for (size_t i = 0; i < total_dims - 1; ++i) {
|
||||||
|
dim0 *= x_dims_v[i];
|
||||||
|
}
|
||||||
|
stride0 *= x_dims_v[total_dims - 1];
|
||||||
|
}
|
||||||
|
kernel_pointwise_flip_apply<
|
||||||
|
T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
|
||||||
|
N, in_data, out_data, dim0, stride0, dim1, flip_dims[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
int bytes = total_dims * sizeof(int64_t);
|
||||||
|
auto x_strides_array_tmp = memory::Alloc(dev_ctx, bytes);
|
||||||
|
int64_t* x_strides_array_gpu =
|
||||||
|
reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
|
||||||
|
memory::Copy(gplace, x_strides_array_gpu, cplace, x_stride_v.data(), bytes,
|
||||||
|
dev_ctx.stream());
|
||||||
|
|
||||||
|
auto x_shape_array_tmp = memory::Alloc(dev_ctx, bytes);
|
||||||
|
int64_t* x_shape_array_gpu =
|
||||||
|
reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
|
||||||
|
memory::Copy(gplace, x_shape_array_gpu, cplace, x_dims_v.data(), bytes,
|
||||||
|
dev_ctx.stream());
|
||||||
|
|
||||||
|
bytes = flip_dims_size * sizeof(int);
|
||||||
|
auto flip_dims_array_tmp = memory::Alloc(dev_ctx, bytes);
|
||||||
|
int* flip_dims_array_gpu =
|
||||||
|
reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
|
||||||
|
memory::Copy(gplace, flip_dims_array_gpu, cplace, flip_dims.data(), bytes,
|
||||||
|
dev_ctx.stream());
|
||||||
|
|
||||||
|
flip_cuda_kernel<
|
||||||
|
T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
|
||||||
|
N, in_data, out_data, x_shape_array_gpu, x_strides_array_gpu,
|
||||||
|
flip_dims_array_gpu, flip_dims_size, total_dims);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
namespace plat = paddle::platform;
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
flip, ops::FlipKernel<paddle::platform::CUDADeviceContext, float>,
|
||||||
|
ops::FlipKernel<paddle::platform::CUDADeviceContext, double>,
|
||||||
|
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
|
||||||
|
ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
|
||||||
|
ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
|
||||||
|
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>);
|
||||||
@ -0,0 +1,83 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <bitset>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/framework/operator.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
|
||||||
|
constexpr size_t dim_bitset_size = 64;
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class FlipKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class FlipKernel<platform::CPUDeviceContext, T>
|
||||||
|
: public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
const Tensor* x = ctx.Input<Tensor>("X");
|
||||||
|
Tensor* out = ctx.Output<Tensor>("Out");
|
||||||
|
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
|
||||||
|
|
||||||
|
auto x_dims = x->dims();
|
||||||
|
const int total_dims = x_dims.size();
|
||||||
|
std::bitset<dim_bitset_size> dim_bitset;
|
||||||
|
for (size_t i = 0; i < flip_dims.size(); ++i) {
|
||||||
|
int dim = flip_dims[i];
|
||||||
|
if (flip_dims[i] < 0) {
|
||||||
|
dim += total_dims;
|
||||||
|
}
|
||||||
|
dim_bitset[dim] = true;
|
||||||
|
}
|
||||||
|
auto x_strides = framework::stride(x_dims);
|
||||||
|
auto numel = x->numel();
|
||||||
|
const T* x_data = x->data<T>();
|
||||||
|
T* out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
#ifdef PADDLE_WITH_MKLML
|
||||||
|
#pragma omp parallel for
|
||||||
|
#endif
|
||||||
|
for (int64_t i = 0; i < numel; ++i) {
|
||||||
|
int64_t cur_indices = i;
|
||||||
|
int64_t rem = 0;
|
||||||
|
int64_t dst_offset = 0;
|
||||||
|
|
||||||
|
for (int d = 0; d < total_dims; ++d) {
|
||||||
|
int64_t temp = cur_indices;
|
||||||
|
cur_indices = cur_indices / x_strides[d];
|
||||||
|
rem = temp - cur_indices * x_strides[d];
|
||||||
|
dst_offset += dim_bitset[d]
|
||||||
|
? (x_dims[d] - 1 - cur_indices) * x_strides[d]
|
||||||
|
: cur_indices * x_strides[d];
|
||||||
|
cur_indices = rem;
|
||||||
|
}
|
||||||
|
out_data[i] = x_data[dst_offset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,115 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import paddle.fluid.core as core
|
||||||
|
from paddle.fluid import Program, program_guard
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlipOp_API(unittest.TestCase):
|
||||||
|
"""Test flip api."""
|
||||||
|
|
||||||
|
def test_static_graph(self):
|
||||||
|
startup_program = fluid.Program()
|
||||||
|
train_program = fluid.Program()
|
||||||
|
with fluid.program_guard(train_program, startup_program):
|
||||||
|
dims = [0]
|
||||||
|
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
|
||||||
|
output = paddle.flip(input, dims)
|
||||||
|
place = fluid.CPUPlace()
|
||||||
|
if fluid.core.is_compiled_with_cuda():
|
||||||
|
place = fluid.CUDAPlace(0)
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
exe.run(startup_program)
|
||||||
|
img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
|
||||||
|
res = exe.run(train_program,
|
||||||
|
feed={'input': img},
|
||||||
|
fetch_list=[output])
|
||||||
|
out_np = np.array(res[0])
|
||||||
|
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
|
||||||
|
self.assertTrue(
|
||||||
|
(out_np == out_ref).all(),
|
||||||
|
msg='flip output is wrong, out =' + str(out_np))
|
||||||
|
|
||||||
|
def test_dygraph(self):
|
||||||
|
img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
|
||||||
|
with fluid.dygraph.guard():
|
||||||
|
inputs = fluid.dygraph.to_variable(img)
|
||||||
|
ret = paddle.flip(inputs, [0])
|
||||||
|
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
|
||||||
|
self.assertTrue(
|
||||||
|
(ret.numpy() == out_ref).all(),
|
||||||
|
msg='flip output is wrong, out =' + str(ret.numpy()))
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlipOp(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.op_type = 'flip'
|
||||||
|
self.init_test_case()
|
||||||
|
self.inputs = {'X': np.random.random(self.in_shape).astype('float64')}
|
||||||
|
self.init_attrs()
|
||||||
|
self.outputs = {'Out': self.calc_ref_res()}
|
||||||
|
|
||||||
|
def init_attrs(self):
|
||||||
|
self.attrs = {"dims": self.dims}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
def test_check_grad(self):
|
||||||
|
self.check_grad(["X"], "Out")
|
||||||
|
|
||||||
|
def init_test_case(self):
|
||||||
|
self.in_shape = (6, 4, 2, 3)
|
||||||
|
self.dims = [0, 1]
|
||||||
|
|
||||||
|
def calc_ref_res(self):
|
||||||
|
res = self.inputs['X']
|
||||||
|
for axis in self.dims:
|
||||||
|
res = np.flip(res, axis)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlipOpAxis1(TestFlipOp):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.in_shape = (2, 4, 4)
|
||||||
|
self.dims = [0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlipOpAxis2(TestFlipOp):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.in_shape = (4, 4, 6, 3)
|
||||||
|
self.dims = [0, 2]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlipOpAxis3(TestFlipOp):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.in_shape = (4, 3, 1)
|
||||||
|
self.dims = [0, 1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlipOpAxis4(TestFlipOp):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.in_shape = (6, 4, 2, 2)
|
||||||
|
self.dims = [0, 1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in new issue