Merge pull request #5419 from reyoung/feature/shrink_memory_op
Feature/shrink memory op
mobile_baidu
commit
2a76b42e44
@ -0,0 +1,50 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/lod_tensor_array.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
class ArrayOp : public framework::OperatorBase {
|
||||
public:
|
||||
ArrayOp(const std::string &type, const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
protected:
|
||||
size_t GetOffset(const framework::Scope &scope,
|
||||
const platform::DeviceContext &dev_ctx) const {
|
||||
auto *i = scope.FindVar(Input("I"));
|
||||
PADDLE_ENFORCE(i != nullptr, "I must be set");
|
||||
auto &i_tensor = i->Get<framework::LoDTensor>();
|
||||
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
|
||||
size_t offset;
|
||||
if (platform::is_gpu_place(i_tensor.place())) {
|
||||
// FIXME: Avoid copy from GPU to CPU
|
||||
framework::Tensor t;
|
||||
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
|
||||
dev_ctx.Wait();
|
||||
offset = static_cast<size_t>(*t.data<int64_t>());
|
||||
} else {
|
||||
offset = static_cast<size_t>(*i_tensor.data<int64_t>());
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,149 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
#include "paddle/framework/lod_rank_table.h"
|
||||
#include "paddle/operators/array_operator.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class ShrinkRNNMemoryOp : public ArrayOp {
|
||||
public:
|
||||
ShrinkRNNMemoryOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: ArrayOp(type, inputs, outputs, attrs) {}
|
||||
|
||||
void Run(const framework::Scope &scope,
|
||||
const platform::DeviceContext &dev_ctx) const override {
|
||||
auto *x_var = scope.FindVar(Input("X"));
|
||||
PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
|
||||
auto &x_tensor = x_var->Get<framework::LoDTensor>();
|
||||
size_t offset = this->GetOffset(scope, dev_ctx);
|
||||
auto *rank_table_var = scope.FindVar(Input("RankTable"));
|
||||
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
|
||||
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
|
||||
|
||||
auto &rank_items = rank_table.items();
|
||||
int dst_num_rows =
|
||||
std::lower_bound(rank_items.begin(), rank_items.end(), offset,
|
||||
[](const framework::LoDRankTable::TableItem &a,
|
||||
size_t b) { return a.length > b; }) -
|
||||
rank_items.begin();
|
||||
|
||||
auto *out_var = scope.FindVar(Output("Out"));
|
||||
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
|
||||
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
|
||||
if (dst_num_rows != 0) {
|
||||
out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, output and documentation of the shrink_rnn_memory
// operator. The descriptions were previously empty strings, which left the
// operator proto undocumented.
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
                              framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(LoDTensor) The RNN step memory to be shrunk.");
    AddInput("RankTable",
             "(LoDRankTable) The rank table used to decide how many rows of "
             "X are kept.");
    AddInput("I",
             "(LoDTensor) A tensor holding exactly one int64 element: the "
             "step index.");
    AddOutput("Out",
              "(LoDTensor) The shrunk memory. It shares its data with the "
              "leading rows of X.");
    AddComment(
        "Shrink the RNN step memory X to its leading rows. The number of "
        "rows kept equals the number of leading RankTable items whose length "
        "is greater than I. Out shares its data with X; if no row is kept, "
        "Out is left unchanged.");
  }
};
|
||||
|
||||
// Shape inference for shrink_rnn_memory: requires inputs "X", "I" and
// "RankTable", and gives "Out" the same dimensions as "X".
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    // All three inputs are mandatory.
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasInput("I"));
    PADDLE_ENFORCE(context->HasInput("RankTable"));

    // The output dimensions are set to those of X.
    auto x_dims = context->GetInputDim("X");
    context->SetOutputDim("Out", x_dims);
  }
};
|
||||
|
||||
class ShrinkRNNMemoryGradOp : public ArrayOp {
|
||||
public:
|
||||
ShrinkRNNMemoryGradOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: ArrayOp(type, inputs, outputs, attrs) {}
|
||||
|
||||
void Run(const framework::Scope &scope,
|
||||
const platform::DeviceContext &dev_ctx) const override {
|
||||
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
|
||||
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
|
||||
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
|
||||
auto *x_var = scope.FindVar(Input("X"));
|
||||
PADDLE_ENFORCE(x_var != nullptr);
|
||||
|
||||
auto &x_tensor = x_var->Get<framework::LoDTensor>();
|
||||
auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
|
||||
dx_tensor.Resize(x_tensor.dims());
|
||||
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
|
||||
|
||||
if (dout_var == nullptr) { // dx_tensor fill zero
|
||||
math::set_constant(dev_ctx, &dx_tensor, 0.0f);
|
||||
} else {
|
||||
auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
|
||||
auto height = dout_tensor.dims()[0];
|
||||
dx_tensor.Slice(0, static_cast<int>(height))
|
||||
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
|
||||
if (dx_tensor.dims()[0] < height) {
|
||||
auto rest_tensor = dx_tensor.Slice(
|
||||
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
|
||||
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Shape inference for shrink_rnn_memory_grad: requires input "X" and the
// gradient output of "X", and gives that gradient the dimensions of "X".
class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));

    // The gradient dimensions are set to those of X.
    auto x_dims = context->GetInputDim("X");
    context->SetOutputDim(framework::GradVarName("X"), x_dims);
  }
};
|
||||
|
||||
class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
|
||||
public:
|
||||
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<framework::OpDescBind> Apply() const override {
|
||||
auto *op = new framework::OpDescBind();
|
||||
op->SetType("shrink_rnn_memory_grad");
|
||||
op->SetInput("X", Input("X"));
|
||||
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
||||
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
||||
op->SetAttrMap(Attrs());
|
||||
return std::unique_ptr<framework::OpDescBind>(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
|
||||
ops::ShrinkRNNMemoryInferShape,
|
||||
ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker);
|
||||
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
|
||||
ops::ShrinkRNNMemoryGradInferShape);
|
@ -0,0 +1,47 @@
|
||||
import unittest
|
||||
import paddle.v2.framework.core as core
|
||||
from paddle.v2.framework.executor import Executor
|
||||
import paddle.v2.framework.layers as layers
|
||||
from paddle.v2.framework.backward import append_backward_ops
|
||||
from paddle.v2.framework.framework import g_main_program
|
||||
import numpy
|
||||
|
||||
|
||||
class TestShrinkRNNMemory(unittest.TestCase):
    """Checks the forward outputs and the input gradient of shrink_memory."""

    def test_shrink_rnn_memory(self):
        # Build a chain of three shrink_memory layers driven by the same rank
        # table, with the index i incremented between them (0, 1, 2).
        x = layers.data('x', shape=[100], data_type='float32')
        x.stop_gradient = False
        table = layers.lod_rank_table(x=x)
        i = layers.zeros(dtype='int64', shape=[1])
        mem1 = layers.shrink_memory(x=x, i=i, table=table)
        i = layers.increment(x=i)
        i.stop_gradient = True
        mem2 = layers.shrink_memory(x=mem1, i=i, table=table)
        i = layers.increment(x=i)
        i.stop_gradient = True
        mem3 = layers.shrink_memory(x=mem2, i=i, table=table)

        # Feed a random 3 x 100 float32 tensor with LoD offsets [0, 2, 5, 6].
        cpu = core.CPUPlace()
        tensor = core.LoDTensor()
        tensor.set_lod([[0, 2, 5, 6]])
        tensor_np = numpy.random.random(size=(3, 100)).astype('float32')
        tensor.set(tensor_np, cpu)
        exe = Executor(cpu)

        # A list comprehension is used instead of map(): the results are
        # indexed below, and map() returns a non-subscriptable iterator on
        # Python 3.
        outs = [
            numpy.array(out)
            for out in exe.run(feed={'x': tensor},
                               fetch_list=[mem1, mem2, mem3])
        ]

        # The three outputs are expected to keep 3, 2 and 1 leading rows.
        self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0]))
        self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1]))
        self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2]))

        # Backward pass: take the mean of mem3 as the loss and fetch the
        # gradient of x. Its sum is expected to be about 1.0.
        mem3_mean = layers.mean(x=mem3)
        append_backward_ops(loss=mem3_mean)
        x_grad_var = g_main_program.global_block().var('x@GRAD')
        x_grad = numpy.array(
            exe.run(feed={'x': tensor}, fetch_list=[x_grad_var])[0])
        self.assertAlmostEqual(1.0, x_grad.sum(), delta=0.1)
|
||||
|
||||
|
||||
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in new issue