fix problems

revert-3824-remove_grad_op_type
zchen0211 8 years ago
parent 6f235553fd
commit bfeecfd3d2

@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/scatter_op.h" #include "paddle/operators/scatter_op.h"
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"

@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/scatter_op.h" #include "paddle/operators/scatter_op.h"

@ -82,7 +82,7 @@ def get_numeric_gradient(op,
def product(dim): def product(dim):
return reduce(lambda a, b: a * b, dim, 1) return reduce(lambda a, b: a * b, dim, 1)
def copy_tensor(): def restore_inputs():
for var_name in input_values: for var_name in input_values:
tensor_ = local_scope.find_var(var_name).get_tensor() tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace()) tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
@ -97,7 +97,7 @@ def get_numeric_gradient(op,
# we use a for loop to compute the gradient of every element. # we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size): for i in xrange(tensor_size):
if in_place: if in_place:
copy_tensor() restore_inputs()
# get one input element throw it's index i. # get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i) origin = tensor_to_check.get_float_element(i)
@ -108,7 +108,7 @@ def get_numeric_gradient(op,
# plus delta to this element, run op and get the sum of the result tensor. # plus delta to this element, run op and get the sum of the result tensor.
if in_place: if in_place:
copy_tensor() restore_inputs()
x_neg = origin - delta x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg) tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output() y_neg = get_output()

Loading…
Cancel
Save