commit
9f65b616b2
@@ -0,0 +1,31 @@
set(INFERENCE_URL "http://paddle-inference-dist.cdn.bcebos.com" CACHE STRING "inference download url")
set(INFERENCE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
    "A path setting inference demo download directories.")
function(inference_download install_dir url filename)
  message(STATUS "Downloading inference test data from ${url}/${filename}")
  execute_process(COMMAND bash -c "mkdir -p ${install_dir}")
  execute_process(COMMAND bash -c "cd ${install_dir} && wget -q ${url}/${filename}")
  message(STATUS "Finished downloading ${filename}")
endfunction()

function(inference_download_and_uncompress install_dir url filename)
  inference_download(${install_dir} ${url} ${filename})
  execute_process(COMMAND bash -c "cd ${install_dir} && tar xzf ${filename}")
endfunction()

set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec")
if(NOT EXISTS ${WORD2VEC_INSTALL_DIR})
  inference_download_and_uncompress(${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz")
endif()
set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model")

function(inference_base_test TARGET)
  set(options "")
  set(oneValueArgs "")
  set(multiValueArgs SRCS ARGS DEPS)
  cmake_parse_arguments(base_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
  if(WITH_GPU)
    set(mem_opt "--fraction_of_gpu_memory_to_use=0.5")
  endif()
  cc_test(${TARGET} SRCS ${base_test_SRCS} DEPS ${base_test_DEPS} ARGS ${mem_opt} ${base_test_ARGS})
endfunction()
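With these helpers in place, a test target could be declared as below. This is an illustrative sketch, not part of this commit: the target name, source file, --dirname flag, and paddle_fluid dependency are hypothetical; only the inference_base_test signature and WORD2VEC_MODEL_DIR come from the file above.

inference_base_test(test_word2vec
  SRCS test_word2vec.cc
  ARGS --dirname=${WORD2VEC_MODEL_DIR}
  DEPS paddle_fluid)

Under WITH_GPU, the mem_opt flag is prepended to the test's ARGS automatically.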
@@ -0,0 +1,112 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using ScopedSpatialTransformerDescriptor =
    platform::ScopedSpatialTransformerDescriptor;

template <typename T>
class CUDNNAffineGridOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto* theta = ctx.Input<Tensor>("Theta");
    auto* output = ctx.Output<Tensor>("Output");
    const T* theta_data = theta->data<T>();

    int n = theta->dims()[0];
    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
    Tensor h_sizes;
    int* h_size_data;
    if (size_attr.size() == 0) {
      auto* output_shape = ctx.Input<Tensor>("OutputShape");
      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
      h_size_data = h_sizes.data<int>();
    } else {
      h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
      h_size_data[0] = n;
      h_size_data[1] = size_attr[1];
      h_size_data[2] = size_attr[2];
      h_size_data[3] = size_attr[3];
    }

    T* output_data = output->mutable_data<T>(
        {n, h_size_data[2], h_size_data[3], 2}, ctx.GetPlace());
    ScopedSpatialTransformerDescriptor st_desc;
    cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
        st_desc.descriptor<T>(4, h_size_data);

    PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorForward(
        handle, cudnn_st_desc, theta_data, output_data));
  }
};
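For readers without the cuDNN documentation at hand: assuming the spatial-transformer grid generator follows the same convention as the CPU kernel later in this diff (normalized coordinates in [-1, 1]), the forward call fills Output[n][h][w] with Theta[n] * (x_w, y_h, 1)^T. A minimal host-side reference sketch, hypothetical and not part of this commit:

#include <vector>

// Reference for one batch item: theta points at a row-major [2 x 3] matrix,
// and the result is the [h, w, 2] sampling grid flattened row-major.
std::vector<float> ReferenceAffineGrid(const float* theta, int h, int w) {
  std::vector<float> grid(static_cast<size_t>(h) * w * 2);
  for (int i = 0; i < h; ++i) {
    // Evenly spaced y in [-1, 1], matching the Linspace helper below.
    float y = (h > 1) ? -1.f + 2.f * i / (h - 1) : -1.f;
    for (int j = 0; j < w; ++j) {
      float x = (w > 1) ? -1.f + 2.f * j / (w - 1) : -1.f;
      float* out = &grid[(static_cast<size_t>(i) * w + j) * 2];
      out[0] = theta[0] * x + theta[1] * y + theta[2];  // output x-coordinate
      out[1] = theta[3] * x + theta[4] * y + theta[5];  // output y-coordinate
    }
  }
  return grid;
}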

template <typename T>
class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));

    int n = output_grad->dims()[0];
    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
    Tensor h_sizes;
    int* h_size_data;
    if (size_attr.size() == 0) {
      auto* output_shape = ctx.Input<Tensor>("OutputShape");
      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
      h_size_data = h_sizes.data<int>();
    } else {
      h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
      h_size_data[0] = n;
      h_size_data[1] = size_attr[1];
      h_size_data[2] = size_attr[2];
      h_size_data[3] = size_attr[3];
    }

    ScopedSpatialTransformerDescriptor st_desc;
    cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
        st_desc.descriptor<T>(4, h_size_data);

    const T* output_grad_data = output_grad->data<T>();
    T* theta_grad_data = theta_grad->mutable_data<T>(ctx.GetPlace());

    PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorBackward(
        handle, cudnn_st_desc, output_grad_data, theta_grad_data));
  }
};
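The shape bookkeeping here is identical to the forward kernel; only the cuDNN entry point changes. cudnnSpatialTfGridGeneratorBackward consumes the [n, h, w, 2] grid gradient and writes the [n, 2, 3] theta gradient, which is why theta_grad needs no explicit dims at mutable_data: they were already fixed by AffineGridOpGrad::InferShape further down in this diff.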

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNAffineGridOpKernel<float>,
                   paddle::operators::CUDNNAffineGridOpKernel<double>);
REGISTER_OP_KERNEL(affine_grid_grad, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNAffineGridGradOpKernel<float>,
                   paddle::operators::CUDNNAffineGridGradOpKernel<double>);
@@ -0,0 +1,233 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/affine_grid_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
struct Linspace<paddle::platform::CPUDeviceContext, T> {
  framework::Tensor operator()(T start, T end, int count,
                               const framework::ExecutionContext& ctx) {
    Tensor numbers;
    T* number_data = numbers.mutable_data<T>({count}, platform::CPUPlace());
    T slice = (end - start) / (T)(count - 1);
    for (int i = 0; i < count; ++i) {
      number_data[i] = start + (T)i * slice;
    }
    return numbers;
  }
};
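A quick sanity check of this specialization: Linspace over [-1, 1] with count = 5 gives slice = 0.5 and the values -1, -0.5, 0, 0.5, 1, exactly the relative coordinates that appear in the DOC example further down.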

class AffineGridOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Theta"),
                   "Input(Theta) of AffineGridOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Output"),
                   "Output(Output) of AffineGridOp should not be null.");
    auto theta_dims = ctx->GetInputDim("Theta");
    PADDLE_ENFORCE(theta_dims.size() == 3,
                   "AffineGrid's Input(Theta) should be 3-D tensor.");

    auto output_shape = ctx->Attrs().Get<std::vector<int>>("output_shape");
    if (output_shape.size() == 0) {
      PADDLE_ENFORCE(ctx->HasInput("OutputShape"),
                     "Input(OutputShape) of AffineGridOp should not be null if "
                     "attr(output_shape) is not configured.");
      auto output_shape_dims = ctx->GetInputDim("OutputShape");
      PADDLE_ENFORCE(output_shape_dims.size() == 1,
                     "AffineGrid's Input(OutputShape) should be 1-D tensor.");
    } else {
      PADDLE_ENFORCE(output_shape.size() == 4,
                     "The size of attr(output_shape) should be 4.");
    }

    PADDLE_ENFORCE(theta_dims[1] == 2, "Input(theta) dims[1] should be 2.");
    PADDLE_ENFORCE(theta_dims[2] == 3, "Input(theta) dims[2] should be 3.");
    // N * H * W * 2
    ctx->SetOutputDim("Output",
                      framework::make_ddim({theta_dims[0], -1, -1, 2}));
    ctx->ShareLoD("Theta", "Output");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library = framework::LibraryType::kCUDNN;
    }
#endif
    auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type());
    return framework::OpKernelType(data_type, ctx.GetPlace(),
                                   framework::DataLayout::kAnyLayout, library);
  }
};
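The -1 entries in SetOutputDim are deliberate: when the target size arrives through the OutputShape input tensor, H and W are unknown at graph-construction time, so only the batch dimension (taken from Theta) and the trailing 2 are fixed here; the kernels allocate the output with the concrete [n, h, w, 2] shape once it is known.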

class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "Theta",
        "(Tensor) A batch of affine transform parameters with shape [N, 2, 3]. "
        "It is used to transform coordinate (x_0, y_0) to coordinate (x_1, "
        "y_1).");
    AddInput("OutputShape",
             "(Tensor) The shape of target image with format [N, C, H, W].")
        .AsDispensable();
    AddOutput("Output", "(Tensor) Output Tensor with shape [N, H, W, 2].");
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default true) Only used in cudnn kernel; cuDNN must be "
        "installed.")
        .SetDefault(true);
    AddAttr<std::vector<int>>(
        "output_shape",
        "The target output image shape with format [N, C, H, W].")
        .SetDefault(std::vector<int>());

    AddComment(R"DOC(
    It generates a grid of (x, y) coordinates using the parameters of the
    affine transformation that correspond to a set of points where the input
    feature map should be sampled to produce the transformed output feature map.

    Given:
        Theta = [[[x_11, x_12, x_13]
                  [x_14, x_15, x_16]]
                 [[x_21, x_22, x_23]
                  [x_24, x_25, x_26]]]

        OutputShape = [2, 3, 5, 5]

    Step 1:

        Generate relative coordinates according to OutputShape.
        The values of the relative coordinates are in the interval [-1, 1].
        The shape of the relative coordinates is [2, H, W] as below:

        C = [[[-1.  -1.  -1.  -1.  -1. ]
              [-0.5 -0.5 -0.5 -0.5 -0.5]
              [ 0.   0.   0.   0.   0. ]
              [ 0.5  0.5  0.5  0.5  0.5]
              [ 1.   1.   1.   1.   1. ]]
             [[-1.  -0.5  0.   0.5  1. ]
              [-1.  -0.5  0.   0.5  1. ]
              [-1.  -0.5  0.   0.5  1. ]
              [-1.  -0.5  0.   0.5  1. ]
              [-1.  -0.5  0.   0.5  1. ]]]
        C[0] is the coordinates in height axis and C[1] is the coordinates in
        width axis.

    Step 2:
        Transpose and reshape C to shape [H * W, 2] and append ones to the
        last dimension. Then we get:
        C_ = [[-1.  -1.   1. ]
              [-0.5 -1.   1. ]
              [ 0.  -1.   1. ]
              [ 0.5 -1.   1. ]
              [ 1.  -1.   1. ]
              [-1.  -0.5  1. ]
              [-0.5 -0.5  1. ]
              [ 0.  -0.5  1. ]
              [ 0.5 -0.5  1. ]
              [ 1.  -0.5  1. ]
              [-1.   0.   1. ]
              [-0.5  0.   1. ]
              [ 0.   0.   1. ]
              [ 0.5  0.   1. ]
              [ 1.   0.   1. ]
              [-1.   0.5  1. ]
              [-0.5  0.5  1. ]
              [ 0.   0.5  1. ]
              [ 0.5  0.5  1. ]
              [ 1.   0.5  1. ]
              [-1.   1.   1. ]
              [-0.5  1.   1. ]
              [ 0.   1.   1. ]
              [ 0.5  1.   1. ]
              [ 1.   1.   1. ]]
    Step 3:
        Compute the output by the equation $$Output[i] = C_ * Theta[i]^T$$
    )DOC");
  }
};
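To make Step 3 of the DOC concrete with the numbers above: the first row of C_ is [-1, -1, 1], so the first grid point of item 0 is (-x_11 - x_12 + x_13, -x_14 - x_15 + x_16), i.e. the affine map applied to the top-left corner (-1, -1).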

class AffineGridOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    auto theta_dims = ctx->GetInputDim("Theta");
    if (ctx->HasOutput(framework::GradVarName("Theta"))) {
      ctx->SetOutputDim(framework::GradVarName("Theta"), theta_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library_ = framework::LibraryType::kCUDNN;
    }
#endif
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Theta")->type()),
        ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
  }
};

class AffineGridGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* op = new framework::OpDesc();
    op->SetType("affine_grid_grad");
    op->SetInput("Theta", Input("Theta"));
    op->SetInput("OutputShape", Input("OutputShape"));
    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));

    op->SetAttrMap(Attrs());

    op->SetOutput(framework::GradVarName("Theta"), InputGrad("Theta"));
    return std::unique_ptr<framework::OpDesc>(op);
  }
};
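Worth noting: the maker forwards OutputShape and the attribute map to the grad op even though the gradient's shape only depends on Theta. The grad kernels rebuild the sampling grid, so they need H and W from the same source the forward pass used.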

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(affine_grid, ops::AffineGridOp, ops::AffineGridOpMaker,
                  ops::AffineGridGradMaker);
REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad);

REGISTER_OP_CPU_KERNEL(
    affine_grid,
    ops::AffineGridOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AffineGridOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    affine_grid_grad,
    ops::AffineGridGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AffineGridGradOpKernel<paddle::platform::CPUDeviceContext, double>);
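Dispatch summary: these CPU registrations pair with the CUDNN registrations in the .cu.cc file above. At run time, GetExpectedKernelType selects kCUDNN whenever CanCUDNNBeUsed(ctx) allows it (on a CUDA place with cuDNN available and, presumably, the use_cudnn attribute set) and falls back to the plain CPU kernels otherwise.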
@@ -0,0 +1,190 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;

/**
 * Return a tensor with evenly spaced numbers over a specified interval.
 */
template <typename DeviceContext, typename T>
struct Linspace {
  framework::Tensor operator()(T start, T end, int count,
                               const framework::ExecutionContext& ctx);
};

template <typename DeviceContext, typename T>
class AffineGridOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto* theta = ctx.Input<Tensor>("Theta");
    int n = theta->dims()[0];

    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
    int h = 0;
    int w = 0;
    if (size_attr.size() == 0) {
      auto* output_shape = ctx.Input<Tensor>("OutputShape");
      Tensor h_sizes;
      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
      const int* h_size_data = h_sizes.data<int>();
      h = h_size_data[2];
      w = h_size_data[3];
    } else {
      h = size_attr[2];
      w = size_attr[3];
    }

    auto* output = ctx.Output<Tensor>("Output");
    output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());

    math::SetConstant<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), output,
        static_cast<T>(0));

    Linspace<DeviceContext, T> linspace;
    // Get indexes of height with shape [height, width, 1]
    auto h_idx = linspace((T)-1, (T)1, h, ctx);
    auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
    // Get indexes of width with shape [height, width, 1]
    auto w_idx = linspace((T)-1, (T)1, w, ctx);
    auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
    // Get constant ones tensor with shape [height, width, 1]
    Tensor ones;
    ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
    auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
    // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx
    // and ones
    Tensor grid;
    grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
    auto grid_t = EigenTensor<T, 4>::From(grid);

    grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
                               .broadcast(Array2(h, 1))
                               .reshape(Array3(h, w, 1))
                               .concatenate(h_idx_t.reshape(Array2(1, h))
                                                .broadcast(Array2(w, 1))
                                                .shuffle(Array2(1, 0))
                                                .reshape(Array3(h, w, 1)),
                                            2)
                               .eval()
                               .concatenate(ones_t, 2)
                               .reshape(Array4(1, h, w, 3))
                               .broadcast(Array4(n, 1, 1, 1));

    // output = grid * theta.T
    // TODO(wanghaoshuang): Refine batched matrix multiply
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
    for (int i = 0; i < n; ++i) {
      Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
      Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3});
      Tensor sliced_out = output->Slice(i, i + 1).Resize({h * w, 2});
      blas.MatMul(sliced_grid, false, sliced_theta, true, T(1), &sliced_out,
                  T(0));
    }
  }
};
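The Eigen chain above is dense; an equivalent but unvectorized construction of the same [n, h, w, 3] grid, as a hypothetical standalone sketch (not the kernel's actual code path), may be easier to follow:

#include <vector>

// Channel 0 holds w_idx (varies along columns), channel 1 holds h_idx
// (varies along rows, which is what the shuffle/transpose achieves),
// channel 2 holds the appended ones.
std::vector<float> BuildGrid(int n, int h, int w) {
  std::vector<float> grid(static_cast<size_t>(n) * h * w * 3);
  for (int b = 0; b < n; ++b) {  // the final broadcast over the batch
    for (int i = 0; i < h; ++i) {
      for (int j = 0; j < w; ++j) {
        float* g = &grid[((static_cast<size_t>(b) * h + i) * w + j) * 3];
        g[0] = (w > 1) ? -1.f + 2.f * j / (w - 1) : -1.f;  // w_idx entry
        g[1] = (h > 1) ? -1.f + 2.f * i / (h - 1) : -1.f;  // h_idx entry
        g[2] = 1.f;                                        // ones entry
      }
    }
  }
  return grid;
}

Each [h * w, 3] slice of this grid is what the MatMul loop multiplies against Theta[i]^T to produce the [h * w, 2] output slice.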

template <typename DeviceContext, typename T>
class AffineGridGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));

    int n = output_grad->dims()[0];
    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
    int h = 0;
    int w = 0;
    if (size_attr.size() == 0) {
      auto* output_shape = ctx.Input<Tensor>("OutputShape");
      Tensor h_sizes;
      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
      const int* h_size_data = h_sizes.data<int>();
      h = h_size_data[2];
      w = h_size_data[3];
    } else {
      h = size_attr[2];
      w = size_attr[3];
    }

    theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());

    math::SetConstant<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), theta_grad,
        static_cast<T>(0));

    Linspace<DeviceContext, T> linspace;

    // Get indexes of height with shape [height, width, 1]
    auto h_idx = linspace((T)-1, (T)1, h, ctx);
    auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
    // Get indexes of width with shape [height, width, 1]
    auto w_idx = linspace((T)-1, (T)1, w, ctx);
    auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
    // Get constant ones tensor with shape [height, width, 1]
    Tensor ones;
    ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
    auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
    // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx
    // and ones
    Tensor grid;
    grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
    auto grid_t = EigenTensor<T, 4>::From(grid);
    grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
                               .broadcast(Array2(h, 1))
                               .reshape(Array3(h, w, 1))
                               .concatenate(h_idx_t.reshape(Array2(1, h))
                                                .broadcast(Array2(w, 1))
                                                .shuffle(Array2(1, 0))
                                                .reshape(Array3(h, w, 1)),
                                            2)
                               .eval()
                               .concatenate(ones_t, 2)
                               .reshape(Array4(1, h, w, 3))
                               .broadcast(Array4(n, 1, 1, 1));
    // output = grid * theta.T
    // TODO(wanghaoshuang): Refine batched matrix multiply
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
    for (int i = 0; i < n; ++i) {
      Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
      Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize({h * w, 2});
      Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3});
      blas.MatMul(sliced_out_grad, true, sliced_grid, false, T(1),
                  &sliced_theta_grad, T(0));
    }
  }
};
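The MatMul arguments encode the chain rule directly: since Output[i] = grid[i] * Theta[i]^T with grid[i] of shape [h * w, 3], the gradient is dTheta[i] = dOutput[i]^T * grid[i], a [2, h * w] by [h * w, 3] product yielding the [2, 3] slice written here. The beta argument T(0) overwrites rather than accumulates, which is safe because each Theta slice affects only its own batch item.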

}  // namespace operators
}  // namespace paddle