commit
698698f2fa
@@ -1,207 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/bilinear_interp_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename T>
__global__ void KeBilinearInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratioW) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = ratio_h * out_img_idy;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T h1lambda = ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    int out_img_idx = tid % out_img_w;
    int in_img_idx = ratioW * out_img_idx;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T w1lambda = ratioW * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                          in_img_idy * in_img_w + in_img_idx];

    // bilinear interpolation
    out[out_id_h * output_w + out_id_w] =
        h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
        h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
                    w1lambda * in_pos[h_id * in_img_w + w_id]);
  }
}

template <typename T>
__global__ void KeBilinearInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratioW) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = ratio_h * out_img_idy;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T h1lambda = ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    int out_img_idx = tid % out_img_w;
    int in_img_idx = ratioW * out_img_idx;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T w1lambda = ratioW * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                    in_img_idy * in_img_w + in_img_idx];
    const T* out_pos = &out[out_id_h * output_w + out_id_w];
    atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
    atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
    atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]);
    atomicAdd(&in_pos[h_id * in_img_w + w_id],
              h1lambda * w1lambda * out_pos[0]);
  }
}

template <typename T>
class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_t = ctx.Input<Tensor>("X");      // float tensor
    auto* output_t = ctx.Output<Tensor>("Out");  // float tensor
    auto* input = input_t->data<T>();

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_dims = output_t->dims();
    auto out_size_t = ctx.Input<Tensor>("OutSize");
    if (out_size_t != nullptr) {
      Tensor sizes;
      framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }
    auto* output = output_t->mutable_data<T>(
        {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());

    int batch_size = input_t->dims()[0];
    int channels = input_t->dims()[1];
    int in_h = input_t->dims()[2];
    int in_w = input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(output, input, input_t->numel() * sizeof(T));
    } else {
      int threadNum = batch_size * out_chw;
      int blocks = (threadNum + 1024 - 1) / 1024;

      KeBilinearInterpFw<
          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
          input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
          batch_size, out_chw, channels, ratio_h, ratio_w);
    }
  }
};

template <typename T>
class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_output = d_output_t->data<T>();
    auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());

    auto& device_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    math::SetConstant<platform::CUDADeviceContext, T> zero;
    zero(device_ctx, d_input_t, static_cast<T>(0.0));

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");

    auto out_size_t = ctx.Input<Tensor>("OutSize");
    if (out_size_t != nullptr) {
      Tensor sizes;
      framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
      auto size_data = sizes.data<int>();
      out_h = size_data[0];
      out_w = size_data[1];
    }

    int batch_size = d_input_t->dims()[0];
    int channels = d_input_t->dims()[1];
    int in_h = d_input_t->dims()[2];
    int in_w = d_input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
    } else {
      int threadNum = batch_size * out_chw;
      int blocks = (threadNum + 1024 - 1) / 1024;

      KeBilinearInterpBw<
          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
          d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
          batch_size, out_chw, channels, ratio_h, ratio_w);
    }
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
                        ops::BilinearInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
                        ops::BilinearInterpGradOpCUDAKernel<float>);
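For reference, the removed GPU kernels above flatten the NCHW output into batch_size * out_chw threads and decode each flat thread id back into (batch, channel, y, x) before blending four input pixels. The stand-alone host-side sketch below is not part of the commit; it mirrors that decoding and the lambda weights so the index math can be checked on the CPU, and the sizes in main() are made up for illustration.

#include <cstdio>

// Host-side mirror of the index/weight math in KeBilinearInterpFw for one
// flattened thread id. output_w == channels * out_h * out_w, matching the
// arguments passed by BilinearInterpOpCUDAKernel above.
void DecodeThread(int tid, int output_w, int out_img_w, int in_img_h,
                  int in_img_w, int num_channels, float ratio_h,
                  float ratio_w) {
  int out_id_h = tid / output_w;               // batch index
  int out_id_w = tid % output_w;               // offset inside one CHW slab
  int out_img_size = output_w / num_channels;  // out_h * out_w
  int channel_id = out_id_w / out_img_size;

  int out_img_idy = (out_id_w % out_img_size) / out_img_w;
  int in_img_idy = static_cast<int>(ratio_h * out_img_idy);
  int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;  // clamp at bottom edge
  float h1lambda = ratio_h * out_img_idy - in_img_idy;

  int out_img_idx = tid % out_img_w;
  int in_img_idx = static_cast<int>(ratio_w * out_img_idx);
  int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;  // clamp at right edge
  float w1lambda = ratio_w * out_img_idx - in_img_idx;

  std::printf(
      "tid %3d -> n=%d c=%d out(%d,%d) blends in(%d..%d,%d..%d) w=(%.2f,%.2f)\n",
      tid, out_id_h, channel_id, out_img_idy, out_img_idx, in_img_idy,
      in_img_idy + h_id, in_img_idx, in_img_idx + w_id, h1lambda, w1lambda);
}

int main() {
  const int c = 2, in_h = 4, in_w = 4, out_h = 8, out_w = 8;  // made-up sizes
  const float ratio_h = static_cast<float>(in_h - 1) / (out_h - 1);
  const float ratio_w = static_cast<float>(in_w - 1) / (out_w - 1);
  for (int tid = 0; tid < 16; ++tid)
    DecodeThread(tid, c * out_h * out_w, out_w, in_h, in_w, c, ratio_h,
                 ratio_w);
  return 0;
}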
@@ -1,163 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class BilinearInterpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_t = ctx.Input<Tensor>("X");      // float tensor
    auto* output_t = ctx.Output<Tensor>("Out");  // float tensor
    auto out_dims = output_t->dims();
    auto* input = input_t->data<T>();
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_size_t = ctx.Input<Tensor>("OutSize");
    if (out_size_t != nullptr) {
      auto out_size_data = out_size_t->data<int>();
      out_h = out_size_data[0];
      out_w = out_size_data[1];
    }
    auto* output = output_t->mutable_data<T>(
        {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
    int batch_size = input_t->dims()[0];
    int channels = input_t->dims()[1];
    int in_h = input_t->dims()[2];
    int in_w = input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    float ratio_h =
        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
    float ratio_w =
        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(output, input, input_t->numel() * sizeof(T));
    } else {
      for (int k = 0; k < batch_size; ++k) {  // loop for batches
        for (int i = 0; i < out_h; ++i) {     // loop for images
          int h = ratio_h * i;
          int hid = (h < in_h - 1) ? 1 : 0;
          float h1lambda = ratio_h * i - h;
          float h2lambda = 1.f - h1lambda;

          for (int j = 0; j < out_w; ++j) {
            int w = ratio_w * j;
            int wid = (w < in_w - 1) ? 1 : 0;
            float w1lambda = ratio_w * j - w;
            float w2lambda = 1.f - w1lambda;
            // calculate four position for bilinear interpolation
            const T* in_pos = &input[k * in_chw + h * in_w + w];
            T* out_pos = &output[k * out_chw + i * out_w + j];

            for (int c = 0; c < channels; ++c) {  // loop for channels
              // bilinear interpolation
              out_pos[0] = static_cast<T>(
                  h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
                  h1lambda * (w2lambda * in_pos[hid * in_w] +
                              w1lambda * in_pos[hid * in_w + wid]));
              in_pos += in_hw;
              out_pos += out_hw;
            }
          }
        }
      }
    }
  }
};

template <typename T>
class BilinearInterpGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_output = d_output_t->data<T>();
    auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
    auto& device_ctx =
        ctx.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> zero;
    zero(device_ctx, d_input_t, static_cast<T>(0.0));

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");

    auto out_size_t = ctx.Input<Tensor>("OutSize");
    if (out_size_t != nullptr) {
      auto out_size_data = out_size_t->data<int>();
      out_h = out_size_data[0];
      out_w = out_size_data[1];
    }

    int batch_size = d_input_t->dims()[0];
    int channels = d_input_t->dims()[1];
    int in_h = d_input_t->dims()[2];
    int in_w = d_input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    float ratio_h =
        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
    float ratio_w =
        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
    } else {
      for (int k = 0; k < batch_size; ++k) {  // loop for batches
        for (int i = 0; i < out_h; ++i) {     // loop for images
          int h = ratio_h * i;
          int hid = (h < in_h - 1) ? 1 : 0;
          float h1lambda = ratio_h * i - h;
          float h2lambda = 1 - h1lambda;

          for (int j = 0; j < out_w; ++j) {
            int w = ratio_w * j;
            int wid = (w < in_w - 1) ? 1 : 0;
            float w1lambda = ratio_w * j - w;
            float w2lambda = 1 - w1lambda;
            T* in_pos = &d_input[k * in_chw + h * in_w + w];
            const T* out_pos = &d_output[k * out_chw + i * out_w + j];

            for (int c = 0; c < channels; ++c) {  // loop for channels
              in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
              in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
              in_pos[hid * in_w] +=
                  static_cast<T>(h1lambda * w2lambda * out_pos[0]);
              in_pos[hid * in_w + wid] +=
                  static_cast<T>(h1lambda * w1lambda * out_pos[0]);
              in_pos += in_hw;
              out_pos += out_hw;
            }
          }
        }
      }
    }
  }
};

} // namespace operators
} // namespace paddle
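The removed CPU kernel above uses the same "align corners" mapping as the GPU version: ratio = (in - 1) / (out - 1), the left/top source cell is floor(ratio * i), and the fractional remainder splits the weight between the two neighbors. A minimal 1-D sketch, not taken from the commit and using made-up values, reproduces that weighting:

#include <cstdio>
#include <vector>

// 1-D version of the lambda-weight blend used by BilinearInterpKernel,
// resizing a row of length in_w to out_w with align-corners ratios.
std::vector<float> Resize1D(const std::vector<float>& in, int out_w) {
  const int in_w = static_cast<int>(in.size());
  const float ratio_w =
      (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
  std::vector<float> out(out_w);
  for (int j = 0; j < out_w; ++j) {
    int w = static_cast<int>(ratio_w * j);  // left neighbor
    int wid = (w < in_w - 1) ? 1 : 0;       // 0 at the right border
    float w1lambda = ratio_w * j - w;       // weight of the right neighbor
    float w2lambda = 1.f - w1lambda;        // weight of the left neighbor
    out[j] = w2lambda * in[w] + w1lambda * in[w + wid];
  }
  return out;
}

int main() {
  // Upsampling {0, 3, 6} to five samples yields {0, 1.5, 3, 4.5, 6}: the
  // endpoints are preserved exactly, which is the point of the
  // (in_w - 1) / (out_w - 1) ratio.
  for (float v : Resize1D({0.f, 3.f, 6.f}, 5)) std::printf("%.2f ", v);
  std::printf("\n");
  return 0;
}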
File diff suppressed because it is too large
@@ -0,0 +1,90 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <functional>
#include <mutex>  // for std::mutex / std::lock_guard used below
#include <unordered_map>
#include <vector>

namespace paddle {
namespace operators {

template <typename TAlgorithm>
class AlgorithmsCache {
 public:
  // Caches the best algorithm for a given
  // combination of tensor dimensions & compute data type.
  TAlgorithm GetAlgorithm(
      const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::vector<int>& dilations,
      int algorithmFlags,  // can set for different data type
      std::function<TAlgorithm()> gen_func);

 private:
  std::unordered_map<int64_t, TAlgorithm> hash_;
  std::mutex mutex_;
};

template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
    const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    const std::vector<int>& dilations, int algorithmFlags,
    std::function<TAlgorithm()> gen_func) {
  std::lock_guard<std::mutex> lock(mutex_);
  int64_t seed = 0;
  // Hash all of the inputs, use to try and look up a previously
  // discovered algorithm, or fall back to generating a new one.
  std::hash<int64_t> hashFn;
  // do hash like boost
  // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
  for (const auto num : dims1) {
    seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }

  for (const auto num : dims2) {
    seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
  }

  for (const auto num : strides) {
    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2) + 2;
  }

  for (const auto num : paddings) {
    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2) + 3;
  }

  for (const auto num : dilations) {
    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2) + 4;
  }

  seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
          (seed << 6) + (seed >> 2) + 5;

  if (seed == 0) return gen_func();

  if (hash_.find(seed) == hash_.end()) {
    TAlgorithm value = gen_func();
    hash_[seed] = value;
  }
  return hash_[seed];
}

} // namespace operators
} // namespace paddle
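A hypothetical usage sketch of the AlgorithmsCache added above: the include path, the int "algorithm id", and the shapes are placeholders, and the lambda stands in for an expensive cuDNN algorithm search that should run only once per unique shape/flag combination.

#include <cstdio>
#include <vector>
#include "algorithms_cache.h"  // assumed location of the header added above

int main() {
  paddle::operators::AlgorithmsCache<int> cache;
  std::vector<int64_t> x_dims = {8, 3, 32, 32}, w_dims = {16, 3, 3, 3};
  std::vector<int> strides = {1, 1}, paddings = {1, 1}, dilations = {1, 1};

  int searches = 0;
  auto expensive_search = [&searches]() {
    ++searches;  // a real gen_func would run the cuDNN algorithm search here
    return 7;    // placeholder algorithm id
  };

  // The first call misses the cache and runs the search; the second call
  // with identical shapes/flags returns the memoized result.
  int a = cache.GetAlgorithm(x_dims, w_dims, strides, paddings, dilations, 0,
                             expensive_search);
  int b = cache.GetAlgorithm(x_dims, w_dims, strides, paddings, dilations, 0,
                             expensive_search);
  std::printf("algo=%d/%d, searches run: %d\n", a, b, searches);  // 7/7, 1
  return 0;
}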
File diff suppressed because it is too large
@@ -0,0 +1,236 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;

template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
                                       const float ratio_h,
                                       const float ratio_w, const int n,
                                       const int c, const int out_h,
                                       const int out_w) {
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
  for (int k = 0; k < out_h; k++) {  // loop for images
    int in_k = static_cast<int>(ratio_h * k + 0.5);

    for (int l = 0; l < out_w; l++) {
      int in_l = static_cast<int>(ratio_w * l + 0.5);

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
        }
      }
    }
  }
}

template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
                                  const float ratio_h, const float ratio_w,
                                  const int in_h, const int in_w, const int n,
                                  const int c, const int out_h,
                                  const int out_w) {
  auto input_t = EigenTensor<T, 4>::From(input);
  auto output_t = EigenTensor<T, 4>::From(*output);
  for (int k = 0; k < out_h; k++) {  // loop for images
    int y_n = static_cast<int>(ratio_h * k);
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
    float d_n = ratio_h * k - y_n;
    float d_s = 1.f - d_n;

    for (int l = 0; l < out_w; l++) {
      int x_w = static_cast<int>(ratio_w * l);
      int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
      float d_w = ratio_w * l - x_w;
      float d_e = 1.f - d_w;

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          // bilinear interpolation
          output_t(i, j, k, l) = input_t(i, j, y_n, x_w) * d_s * d_e +
                                 input_t(i, j, y_s, x_w) * d_n * d_e +
                                 input_t(i, j, y_n, x_e) * d_s * d_w +
                                 input_t(i, j, y_s, x_e) * d_n * d_w;
        }
      }
    }
  }
}

template <typename T>
static void NearestNeighborInterpolateGrad(const Tensor& output_grad,
                                           Tensor* input_grad,
                                           const float ratio_h,
                                           const float ratio_w, const int n,
                                           const int c, const int out_h,
                                           const int out_w) {
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
  for (int k = 0; k < out_h; k++) {  // loop for images
    int in_k = static_cast<int>(ratio_h * k + 0.5);

    for (int l = 0; l < out_w; l++) {
      int in_l = static_cast<int>(ratio_w * l + 0.5);

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
        }
      }
    }
  }
}

template <typename T>
static void BilinearInterpolationGrad(const Tensor& output_grad,
                                      Tensor* input_grad, const float ratio_h,
                                      const float ratio_w, const int in_h,
                                      const int in_w, const int n, const int c,
                                      const int out_h, const int out_w) {
  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
  for (int k = 0; k < out_h; k++) {  // loop for images
    int y_n = static_cast<int>(ratio_h * k);
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
    float d_n = ratio_h * k - y_n;
    float d_s = 1.f - d_n;

    for (int l = 0; l < out_w; l++) {
      int x_w = static_cast<int>(ratio_w * l);
      int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
      float d_w = ratio_w * l - x_w;
      float d_e = 1.f - d_w;

      for (int i = 0; i < n; i++) {    // loop for batches
        for (int j = 0; j < c; j++) {  // loop for channels
          // bilinear interpolation grad
          const T grad = output_grad_t(i, j, k, l);
          input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
          input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
          input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
          input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
        }
      }
    }
  }
}

template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");

    std::string interp_method = ctx.Attr<std::string>("interp_method");
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      auto out_size_data = out_size->data<int>();
      out_h = out_size_data[0];
      out_w = out_size_data[1];
    }

    const int n = input->dims()[0];
    const int c = input->dims()[1];
    const int in_h = input->dims()[2];
    const int in_w = input->dims()[3];

    output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
    auto& device_ctx =
        ctx.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> zero;
    zero(device_ctx, output, static_cast<T>(0.0));

    if (in_h == out_h && in_w == out_w) {
      framework::TensorCopy(*input, ctx.GetPlace(), output);
      return;
    }

    float ratio_h =
        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
    float ratio_w =
        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

    if ("bilinear" == interp_method) {
      BilinearInterpolation<T>(*input, output, ratio_h, ratio_w, in_h, in_w, n,
                               c, out_h, out_w);
    } else if ("nearest" == interp_method) {
      NearestNeighborInterpolate<T>(*input, output, ratio_h, ratio_w, n, c,
                                    out_h, out_w);
    }
  }
};

template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));

    std::string interp_method = ctx.Attr<std::string>("interp_method");
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (out_size != nullptr) {
      auto out_size_data = out_size->data<int>();
      out_h = out_size_data[0];
      out_w = out_size_data[1];
    }

    const int n = input->dims()[0];
    const int c = input->dims()[1];
    const int in_h = input->dims()[2];
    const int in_w = input->dims()[3];

    input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
    auto& device_ctx =
        ctx.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> zero;
    zero(device_ctx, input_grad, static_cast<T>(0.0));

    if (in_h == out_h && in_w == out_w) {
      framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
      return;
    }

    float ratio_h =
        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
    float ratio_w =
        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

    if ("bilinear" == interp_method) {
      BilinearInterpolationGrad<T>(*output_grad, input_grad, ratio_h, ratio_w,
                                   in_h, in_w, n, c, out_h, out_w);
    } else if ("nearest" == interp_method) {
      NearestNeighborInterpolateGrad<T>(*output_grad, input_grad, ratio_h,
                                        ratio_w, n, c, out_h, out_w);
    }
  }
};

} // namespace operators
} // namespace paddle
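The "nearest" path above picks the source index by truncating ratio * l + 0.5 to an int. A small stand-alone sketch, not part of the commit and using illustrative sizes only, shows which input column each output column reads for a 4 to 8 upscale along one axis:

#include <cstdio>

// Reproduces the source-index selection of NearestNeighborInterpolate
// along one axis: in_l = (int)(ratio_w * l + 0.5).
int main() {
  const int in_w = 4, out_w = 8;  // illustrative sizes
  const float ratio_w =
      (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
  for (int l = 0; l < out_w; ++l) {
    int in_l = static_cast<int>(ratio_w * l + 0.5);
    std::printf("out col %d <- in col %d\n", l, in_l);
  }
  return 0;
}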
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff