modify for code review by qingqing

release/0.11.0
sweetsky0901 8 years ago
parent ee4a5d2117
commit c218961a6b

@ -60,9 +60,9 @@ public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad) {
const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];

@ -114,9 +114,9 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad) {
const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];

@ -14,8 +14,6 @@ limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
@ -37,9 +35,9 @@ class Unpool2dMaxGradFunctor {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad);
const framework::Tensor& output_grad,
framework::Tensor * input_grad);
};
} // namespace math
} // namespace operators

@ -78,7 +78,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y");
std::string unpoolingtype =
std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpoolingtype");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");

Loading…
Cancel
Save