fix comments

add_depthwiseConv_op_gpu
xzl 7 years ago
parent 6e17babe49
commit 84ded49d66

@ -361,6 +361,9 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_EQ(
output->dims()[1] % input->dims()[1], 0,
"The output channels must be a multiple of the input channels");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

@ -203,8 +203,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& filter, std::vector<int>& strides,
std::vector<int>& paddings, framework::Tensor* output) {
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
@ -244,7 +245,8 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& input,
const framework::Tensor& filter,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
@ -284,7 +286,8 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* filter_grad) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];

@ -29,8 +29,9 @@ template <typename DeviceContext, typename T>
class DepthwiseConvFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter, std::vector<int>& strides,
std::vector<int>& paddings, framework::Tensor* output);
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output);
};
template <typename DeviceContext, typename T>
@ -39,7 +40,8 @@ class DepthwiseConvInputGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad);
};
@ -48,7 +50,8 @@ class DepthwiseConvFilterGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* filter_grad);
};

Loading…
Cancel
Save