parent
d206582337
commit
6a62b9d8a0
@ -0,0 +1,115 @@
|
|||||||
|
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/temporal_shift_op.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using framework::Tensor;
|
||||||
|
|
||||||
|
class TemporalShiftOp: public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||||
|
"Input(X) of TemporalShiftOp should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||||
|
"Output(Out) of TemporalShiftOp should not be null.");
|
||||||
|
|
||||||
|
auto dim_x = ctx->GetInputDim("X");
|
||||||
|
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
|
||||||
|
"Input(X) rank should be 4 in shape of [N*T, C, H, W].");
|
||||||
|
|
||||||
|
int seg_num = ctx->Attrs().Get<int>("seg_num");
|
||||||
|
PADDLE_ENFORCE_GT(seg_num, 0,
|
||||||
|
"Attr(seg_num) should be greater then 0.");
|
||||||
|
|
||||||
|
if (ctx->IsRuntime()) {
|
||||||
|
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0,
|
||||||
|
"Input(X) dims[0] should be divided exactly by Attr(seg_num).");
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->SetOutputDim("Out", dim_x);
|
||||||
|
ctx->ShareLoD("X", "Out");
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override {
|
||||||
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
||||||
|
ctx.GetPlace());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Declares the proto (inputs, outputs, attributes and user-facing
// documentation) of the `temporal_shift` operator.
class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of temporal shift operator. "
             "This is a 4-D tensor with shape of [N*T, C, H, W]. "
             "While N is the batch size, T is the temporal segment "
             "number, C is the channel number, H is the height of "
             "features and W is the width of features.");
    AddOutput("Out",
              "The output tensor of temporal shift operator. "
              "This is a 4-D tensor in the same shape with Input(X).");

    AddAttr<int>("seg_num",
                 "The temporal segment number, this should be a positive "
                 "integer.");

    // The paper reference below points at the Temporal Shift Module paper
    // (arXiv:1811.08383).
    AddComment(R"DOC(
          This operator calculates the temporal shift features for Input(X).

          For details of temporal shift, please refer to paper:
          `Temporal Shift Module <https://arxiv.org/abs/1811.08383>`_ .

         )DOC");
  }
};
|
||||||
|
|
||||||
|
class TemporalShiftOpGrad: public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
||||||
|
"Input(Out@GRAD) should not be null");
|
||||||
|
auto dim_x = ctx->GetInputDim("X");
|
||||||
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
||||||
|
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override {
|
||||||
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
||||||
|
ctx.GetPlace());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, ops::TemporalShiftOpMaker,
|
||||||
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
||||||
|
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
|
||||||
|
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
|
||||||
|
ops::TemporalShiftKernel<double>);
|
||||||
|
REGISTER_OP_CPU_KERNEL(temporal_shift_grad, ops::TemporalShiftGradKernel<float>,
|
||||||
|
ops::TemporalShiftGradKernel<double>);
|
@ -0,0 +1,151 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/temporal_shift_op.h"
|
||||||
|
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using framework::Tensor;
|
||||||
|
|
||||||
|
|
||||||
|
// Forward temporal-shift kernel.
//
// Every output element at flat index `idx` is decomposed into
// (in, it, ic, ih, iw) using the strides tchw/chw/hw/w. Depending on the
// channel index, the element is read from a neighbouring time step:
//   ic <  c/4          -> time step it - 1
//   c/4 <= ic < c/2    -> time step it + 1
//   otherwise          -> time step it
// When the source time step is outside [0, t) the output is set to 0.
// Threads walk the ntchw elements with a grid-stride loop.
template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
                                  const int tchw, const int chw, const int hw,
                                  const int w, const int t, const int c) {
  const int step = blockDim.x * gridDim.x;
  const int c1 = c / 4;
  const int c2 = c / 2;

  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < ntchw;
       idx += step) {
    const int in = idx / tchw;
    const int it = (idx % tchw) / chw;
    const int ic = (idx % chw) / hw;
    const int ih = (idx % hw) / w;
    const int iw = idx % w;

    // Temporal offset selected by the channel group.
    int offset = 0;
    if (ic < c1) {
      offset = -1;
    } else if (ic < c2) {
      offset = 1;
    }
    const int src_it = it + offset;

    if (src_it >= 0 && src_it < t) {
      const int src_idx =
          GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      output[idx] = input[src_idx];
    } else {
      output[idx] = 0;
    }
  }
}
|
||||||
|
|
||||||
|
// Backward temporal-shift kernel.
//
// Written as a gather over `input_grad`: every element of `input_grad` is
// assigned exactly once, either from the single `output_grad` element that
// read it in the forward pass or 0 when no output element read it. (A scatter
// from `output_grad` would leave those unread positions -- e.g. the last time
// step of the first channel quarter -- unwritten.)
//
// The forward pass reads, for an output at time step `it`:
//   ic <  c/4          -> input time step it - 1
//   c/4 <= ic < c/2    -> input time step it + 1
//   otherwise          -> input time step it
// so the gradient of an input at time step `it` comes from the output at:
//   ic <  c/4          -> time step it + 1
//   c/4 <= ic < c/2    -> time step it - 1
//   otherwise          -> time step it
// Threads walk the ntchw elements with a grid-stride loop.
template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
                                  const int ntchw, const int tchw,
                                  const int chw, const int hw, const int w,
                                  const int t, const int c) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int dst_it = 0;
  for (; tid < ntchw; tid += stride) {
    int in = tid / tchw;
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;
    int ih = (tid % hw) / w;
    int iw = tid % w;

    // Time step of the forward output that consumed this input element.
    if (ic < c / 4) {
      dst_it = it + 1;
    } else if (ic < c / 2) {
      dst_it = it - 1;
    } else {
      dst_it = it;
    }

    if (dst_it >= 0 && dst_it < t) {
      int dst_idx = GetEntryIndex(in, dst_it, ic, ih, iw, tchw, chw, hw, w);
      input_grad[tid] = output_grad[dst_idx];
    } else {
      // No forward output read this element, so its gradient is zero.
      input_grad[tid] = 0;
    }
  }
}
|
||||||
|
|
||||||
|
// GPU forward kernel wrapper for `temporal_shift`.
//
// Reads Input(X) of dims [nt, c, h, w] and Attr(seg_num), allocates
// Output(Out) with the same dims on the current place and launches
// KeTemporalShiftFw on the device context's stream.
template <typename T>
class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    int t = ctx.Attr<int>("seg_num");

    const auto& x_dims = x->dims();
    const int nt = x_dims[0];
    const int c = x_dims[1];
    const int h = x_dims[2];
    const int w = x_dims[3];

    // Strides used to decompose a flat index into (n, t, c, h, w).
    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* x_data = x->data<T>();
    T* out_data = out->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    // 512 threads per block; the number of blocks is capped at 8 and the
    // kernel's loop covers the remaining elements.
    constexpr int kThreads = 512;
    constexpr int kMaxBlocks = 8;
    int blocks = (ntchw + kThreads - 1) / kThreads;
    if (blocks > kMaxBlocks) {
      blocks = kMaxBlocks;
    }

    KeTemporalShiftFw<
        T><<<blocks, kThreads, 0, ctx.cuda_device_context().stream()>>>(
        x_data, out_data, ntchw, tchw, chw, hw, w, t, c);
  }
};
|
||||||
|
|
||||||
|
// GPU backward kernel wrapper for `temporal_shift`.
//
// Reads Input(Out@GRAD) of dims [nt, c, h, w] and Attr(seg_num), allocates
// Output(X@GRAD) with the same dims on the current place and launches
// KeTemporalShiftBw on the device context's stream.
template <typename T>
class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int t = ctx.Attr<int>("seg_num");

    const auto& dout_dims = dout->dims();
    const int nt = dout_dims[0];
    const int c = dout_dims[1];
    const int h = dout_dims[2];
    const int w = dout_dims[3];

    // Strides used to decompose a flat index into (n, t, c, h, w).
    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* dout_data = dout->data<T>();
    T* dx_data = dx->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    // 512 threads per block; the number of blocks is capped at 8 and the
    // kernel's loop covers the remaining elements.
    constexpr int kThreads = 512;
    constexpr int kMaxBlocks = 8;
    int blocks = (ntchw + kThreads - 1) / kThreads;
    if (blocks > kMaxBlocks) {
      blocks = kMaxBlocks;
    }

    KeTemporalShiftBw<
        T><<<blocks, kThreads, 0, ctx.cuda_device_context().stream()>>>(
        dout_data, dx_data, ntchw, tchw, chw, hw, w, t, c);
  }
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;

// CUDA kernels for the forward and gradient ops, instantiated for float and
// double.
REGISTER_OP_CUDA_KERNEL(temporal_shift,
                        ops::TemporalShiftOpCUDAKernel<float>,
                        ops::TemporalShiftOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
                        ops::TemporalShiftGradOpCUDAKernel<float>,
                        ops::TemporalShiftGradOpCUDAKernel<double>);
|
@ -0,0 +1,117 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/operators/math/math_function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
|
||||||
|
// Returns the flat offset of element (in, it, ic, ih, iw), where tchw, chw,
// hw and w are the strides of the first four indices and the last index has
// stride 1.
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
                                           int iw, const int tchw,
                                           const int chw, const int hw,
                                           const int w) {
  int offset = in * tchw;
  offset += it * chw;
  offset += ic * hw;
  offset += ih * w;
  offset += iw;
  return offset;
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class TemporalShiftKernel: public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* input = ctx.Input<Tensor>("X");
|
||||||
|
auto* output = ctx.Output<Tensor>("Out");
|
||||||
|
int t = ctx.Attr<int>("seg_num");
|
||||||
|
|
||||||
|
const int nt = input->dims()[0];
|
||||||
|
const int c = input->dims()[1];
|
||||||
|
const int h = input->dims()[2];
|
||||||
|
const int w = input->dims()[3];
|
||||||
|
|
||||||
|
const int hw = h * w;
|
||||||
|
const int chw = c * hw;
|
||||||
|
const int tchw = t * chw;
|
||||||
|
|
||||||
|
const T* input_data = input->data<T>();
|
||||||
|
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
|
||||||
|
|
||||||
|
int src_it = 0;
|
||||||
|
for (int i = 0; i < output->numel(); i++) {
|
||||||
|
int in = i / tchw;
|
||||||
|
int it = (i % tchw) / chw;
|
||||||
|
int ic = (i % chw) / hw;
|
||||||
|
int ih = (i % hw) / w;
|
||||||
|
int iw = i % w;
|
||||||
|
|
||||||
|
if (ic < c / 4) {
|
||||||
|
src_it = it - 1;
|
||||||
|
} else if (ic < c / 2) {
|
||||||
|
src_it = it + 1;
|
||||||
|
} else {
|
||||||
|
src_it = it;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src_it < 0 || src_it >= t) {
|
||||||
|
output_data[i] = 0;
|
||||||
|
} else {
|
||||||
|
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
|
||||||
|
output_data[i] = input_data[src_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class TemporalShiftGradKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||||
|
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||||
|
int t = ctx.Attr<int>("seg_num");
|
||||||
|
|
||||||
|
const int nt = output_grad->dims()[0];
|
||||||
|
const int c = output_grad->dims()[1];
|
||||||
|
const int h = output_grad->dims()[2];
|
||||||
|
const int w = output_grad->dims()[3];
|
||||||
|
|
||||||
|
const int hw = h * w;
|
||||||
|
const int chw = c * hw;
|
||||||
|
const int tchw = t * chw;
|
||||||
|
|
||||||
|
const T* output_grad_data = output_grad->data<T>();
|
||||||
|
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
|
||||||
|
|
||||||
|
int src_it = 0;
|
||||||
|
for (int i = 0; i < output_grad->numel(); i++) {
|
||||||
|
int in = i / tchw;
|
||||||
|
int it = (i % tchw) / chw;
|
||||||
|
int ic = (i % chw) / hw;
|
||||||
|
int ih = (i % hw) / w;
|
||||||
|
int iw = i % w;
|
||||||
|
|
||||||
|
if (ic < c / 4) {
|
||||||
|
src_it = it - 1;
|
||||||
|
} else if (ic < c / 2) {
|
||||||
|
src_it = it + 1;
|
||||||
|
} else {
|
||||||
|
src_it = it;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src_it >= 0 && src_it < t) {
|
||||||
|
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
|
||||||
|
input_grad_data[src_idx] = output_grad_data[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,77 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
from paddle.fluid import core
|
||||||
|
|
||||||
|
|
||||||
|
def temporal_shift(x, seg_num):
    """NumPy reference implementation of the temporal shift.

    `x` is a 4-D array whose first axis is reshaped into (-1, seg_num).
    Along the segment axis, the first quarter of the channels takes its
    value from the previous segment, the second quarter from the next
    segment, and the remaining channels are copied unchanged. Positions
    with no source segment are zero. The result has the same shape as `x`.
    """
    shape = x.shape
    channels = shape[1]
    c1 = channels // 4
    c2 = channels // 2

    grouped = x.reshape((-1, seg_num, channels, shape[2], shape[3]))
    shifted = np.zeros(grouped.shape, dtype=grouped.dtype)

    # First quarter: segment t reads segment t - 1 (segment 0 stays zero).
    shifted[:, 1:, :c1, :, :] = grouped[:, :-1, :c1, :, :]
    # Second quarter: segment t reads segment t + 1 (last segment stays zero).
    shifted[:, :-1, c1:c2, :, :] = grouped[:, 1:, c1:c2, :, :]
    # Remaining channels are passed through.
    shifted[:, :, c2:, :, :] = grouped[:, :, c2:, :, :]

    return shifted.reshape(shape)
|
||||||
|
|
||||||
|
class TestTemporalShift(OpTest):
    """Checks the `temporal_shift` op against the NumPy reference.

    Subclasses override `initTestCase` to supply a different input shape
    and segment number.
    """

    def initTestCase(self):
        self.x_shape = (6, 4, 4, 4)
        self.seg_num = 3

    def setUp(self):
        # The shape and seg_num must be set before the input is generated.
        self.initTestCase()
        self.op_type = 'temporal_shift'

        data = np.random.random(self.x_shape).astype('float32')
        expected = temporal_shift(data, self.seg_num)

        self.attrs = {"seg_num": self.seg_num}
        self.inputs = {"X": data}
        self.outputs = {"Out": expected}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_ignore_uv(self):
        self.check_grad(['X'], 'Out', max_relative_error=0.01)
|
||||||
|
|
||||||
|
class TestTemporalShift2(TestTemporalShift):
    """Channel count not divisible by 4, with two segments."""

    def initTestCase(self):
        self.x_shape = (4, 9, 7, 7)
        self.seg_num = 2


# This class must not reuse the name TestTemporalShift2: a second class with
# the same name rebinds the module attribute, and the case above would then
# never be collected.
class TestTemporalShift3(TestTemporalShift):
    """Single segment: every shifted channel has no source segment."""

    def initTestCase(self):
        self.x_shape = (3, 10, 5, 5)
        self.seg_num = 1
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue