Merge pull request #7027 from reyoung/feature/rnn_gradient_check

Feature/rnn gradient check
del_some_in_makelist
Yu Yang, 7 years ago, committed by GitHub
commit 8b91174c83

@@ -116,9 +116,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
     auto height = dout_tensor.dims()[0];
     auto slice = dx_tensor.Slice(0, static_cast<int>(height));
     framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
-    if (dx_tensor.dims()[0] < height) {
+    if (dx_tensor.dims()[0] > height) {
       auto rest_tensor = dx_tensor.Slice(
-          static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
+          static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
       math::set_constant(dev_ctx, &rest_tensor, 0.0f);
     }
   }
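
Reviewer note: with the corrected comparison and slice bound, the gradient of shrink_rnn_memory copies d(Out) into the first `height` rows of d(X) and zero-fills whatever rows of d(X) lie beyond `height`. A rough numpy sketch of that behaviour (the function and variable names are illustrative, not the operator's real interface):

import numpy as np

def shrink_rnn_memory_grad(d_out, x_rows):
    # d_out covers only the first `height` rows of x; the remaining rows of
    # d_x received no gradient and must be explicitly zeroed.
    height = d_out.shape[0]
    d_x = np.empty((x_rows,) + d_out.shape[1:], dtype=d_out.dtype)
    d_x[:height] = d_out                 # CopyFrom(dout_tensor, ..., &slice)
    if x_rows > height:                  # the corrected comparison
        d_x[height:] = 0.0               # set_constant on the rest of d_x
    return d_x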

@@ -37,11 +37,11 @@ class SumKernel : public framework::OpKernel<T> {
     bool in_place = out_var == in_vars[0];

     if (out_var->IsType<framework::LoDTensor>()) {
-      auto *out = context.Output<Tensor>("Out");
+      auto *out = context.Output<LoDTensor>("Out");
       if (!in_place) {
         out->mutable_data<T>(context.GetPlace());
       }
       auto result = EigenVector<T>::Flatten(*out);
       if (!in_place) {
         math::SetConstant<DeviceContext, T> constant_functor;
         constant_functor(context.template device_context<DeviceContext>(), out,
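
Reviewer note: fetching the output as LoDTensor rather than the plain Tensor alias keeps the variable's actual type (and its LoD) intact. For orientation only, a small Python sketch of the control flow around the changed line, assuming the usual convention that in-place summation reuses the first input as the output:

import numpy as np

def sum_kernel(inputs, in_place=False):
    if in_place:
        out = inputs[0]                  # out_var == in_vars[0]
        start = 1                        # inputs[0] already lives in out
    else:
        out = np.zeros_like(inputs[0])   # SetConstant(..., out, 0) path
        start = 0
    for x in inputs[start:]:
        out += x
    return out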

@@ -130,9 +130,9 @@ class ReadFromArrayOp : public ArrayOp {
     auto &x_array = x->Get<framework::LoDTensorArray>();
     auto *out = scope.FindVar(Output("Out"));
     PADDLE_ENFORCE(out != nullptr, "Out must be set");
-    auto *out_tensor = out->GetMutable<framework::LoDTensor>();
     size_t offset = GetOffset(scope, place);
     if (offset < x_array.size()) {
+      auto *out_tensor = out->GetMutable<framework::LoDTensor>();
       platform::DeviceContextPool &pool =
           platform::DeviceContextPool::Instance();
       auto &dev_ctx = *pool.Get(place);
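
Reviewer note: the output tensor is now only materialized after the offset has been checked against the array size, so an out-of-range read no longer touches `Out`. A minimal Python sketch of that guard, with a plain list standing in for the LoDTensorArray:

def read_from_array(x_array, offset):
    # Only touch the output once the offset is known to be in range.
    if offset < len(x_array):
        return list(x_array[offset])   # stand-in for CopyFrom into out_tensor
    return None                        # out-of-range read writes nothing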

@@ -77,10 +77,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
       } else if (paddle::platform::is_cpu_place(tensor.place())) {
         dst_tensor = tensor;
       }
-      return py::buffer_info(
-          dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.place()),
-          sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
-          (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
+      return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
+                             py::format_descriptor<CUR_TYPE>::format(),
+                             (size_t)framework::arity(dst_tensor.dims()),
+                             dims_outside, strides);
     } else {
       constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
       return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
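
Reviewer note: switching from mutable_data<CUR_TYPE>(place) to data<CUR_TYPE>() means exposing a tensor through the Python buffer protocol can no longer allocate memory as a side effect; the tensor must already hold data. Assuming the usual fluid test idiom (core.LoDTensor, set(), and numpy's buffer-based conversion), usage looks roughly like:

import numpy as np
import paddle.v2.fluid as fluid

place = fluid.CPUPlace()
t = fluid.core.LoDTensor()
t.set(np.arange(6, dtype='float32').reshape(2, 3), place)  # data now exists
arr = np.array(t)   # conversion goes through the py::buffer_info built above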

@@ -1,12 +1,31 @@
 import numpy as np
+import contextlib
-from framework import Program, default_main_program
 from . import core
+from framework import Program, default_main_program, Parameter, Variable

-__all__ = ['Executor', 'g_scope']
+__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope']

 g_scope = core.Scope()


+def global_scope():
+    return g_scope
+
+
+def switch_scope(scope):
+    global g_scope
+    ex = g_scope
+    g_scope = scope
+    return ex
+
+
+@contextlib.contextmanager
+def scope_guard(scope):
+    ex = switch_scope(scope)
+    yield
+    switch_scope(ex)
+
+
 def as_numpy(tensor):
     if isinstance(tensor, list):
         return [as_numpy(t) for t in tensor]
@@ -117,7 +136,7 @@ class Executor(object):
             raise TypeError()

         if scope is None:
-            scope = g_scope
+            scope = global_scope()

         program = program.clone()
         global_block = program.global_block()
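
Reviewer note: `global_scope()`, `switch_scope()` and `scope_guard()` replace direct use of the module-level `g_scope`, so code can temporarily run against a private scope and have the previous one restored automatically. A small hypothetical usage sketch, assuming core.Scope's `var()` binding:

import paddle.v2.fluid as fluid

my_scope = fluid.core.Scope()
assert fluid.global_scope() is not my_scope
with fluid.scope_guard(my_scope):
    # Inside the guard this scope is the default one, e.g. for Executor.run.
    assert fluid.global_scope() is my_scope
    my_scope.var("tmp")        # variables created here stay local to my_scope
assert fluid.global_scope() is not my_scope   # previous scope is restored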

@@ -170,7 +170,7 @@ def main():
     exe.run(fluid.default_startup_program())
-    embedding_param = fluid.g_scope.find_var(embedding_name).get_tensor()
+    embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
     embedding_param.set(
         load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place)

@@ -0,0 +1,29 @@
+import paddle.v2.fluid as fluid
+
+__all__ = ['many_times', 'prog_scope']
+
+
+def many_times(times):
+    def __impl__(fn):
+        def __fn__(*args, **kwargs):
+            for _ in range(times):
+                fn(*args, **kwargs)
+
+        return __fn__
+
+    return __impl__
+
+
+def prog_scope():
+    def __impl__(fn):
+        def __fn__(*args, **kwargs):
+            prog = fluid.Program()
+            startup_prog = fluid.Program()
+            scope = fluid.core.Scope()
+            with fluid.scope_guard(scope):
+                with fluid.program_guard(prog, startup_prog):
+                    fn(*args, **kwargs)
+
+        return __fn__
+
+    return __impl__
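
Reviewer note: the two decorators compose. `prog_scope()` gives each invocation a fresh Program, startup Program and Scope, and `many_times(n)` repeats the wrapped test n times, which is useful for randomized gradient checks. A hypothetical usage sketch (the `decorators` import path and the layer calls are assumptions, not part of this diff):

import unittest
import paddle.v2.fluid as fluid
import decorators   # assumed to sit next to the test files, like this new module

class TestWithHelpers(unittest.TestCase):
    @decorators.many_times(4)    # repeat the randomized check four times
    @decorators.prog_scope()     # each repetition gets its own Program and Scope
    def test_random_case(self):
        # Whatever is built here never leaks into the default program/scope.
        x = fluid.layers.data(name='x', shape=[3], dtype='float32')
        fluid.layers.fc(input=x, size=4)

if __name__ == '__main__':
    unittest.main()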

File diff suppressed because it is too large.