parent
8a31b2eb75
commit
eb8252466b
@@ -0,0 +1,21 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/reference_count_pass_helper.h"

namespace paddle {
namespace framework {
namespace details {}  // namespace details
}  // namespace framework
}  // namespace paddle
|
@@ -0,0 +1,89 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/framework/garbage_collector.h"

namespace paddle {
namespace framework {

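// The base collector clamps the memory threshold to at least 1 byte (a
// max_memory_size of 0 would otherwise make every size comparison trivially
// true), allocates the garbage queue, and caches the device context for the
// given place. The extra parentheses in (std::max) are likely there to dodge
// the Windows max() macro.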
GarbageCollector::GarbageCollector(const platform::Place &place,
                                   size_t max_memory_size)
    : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
  garbages_.reset(new GarbageQueue());
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}

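// CPUGarbageCollector releases memory synchronously: the callback that frees
// the underlying buffers runs immediately on the calling thread.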
CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
                                         size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {}

void CPUGarbageCollector::ClearCallback(const std::function<void()> &callback) {
  callback();
}

#ifdef PADDLE_WITH_CUDA
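// UnsafeFastGPUGarbageCollector also frees immediately, without synchronizing
// with any CUDA stream; presumably this is what makes it "unsafe" when
// in-flight kernels may still reference the memory, and "fast" otherwise.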
UnsafeFastGPUGarbageCollector::UnsafeFastGPUGarbageCollector(
    const platform::CUDAPlace &place, size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {}

void UnsafeFastGPUGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback();
}

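// DefaultStreamGarbageCollector defers the free to the device context's
// stream: the callback is enqueued via AddStreamCallback, so memory is
// released only after previously issued work on that stream has completed.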
DefaultStreamGarbageCollector::DefaultStreamGarbageCollector(
    const platform::CUDAPlace &place, size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {}

void DefaultStreamGarbageCollector::Wait() const {
  static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
      ->WaitStreamCallback();
}

void DefaultStreamGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
      ->AddStreamCallback(callback);
}

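// StreamGarbageCollector owns a dedicated CUDA stream plus a
// StreamCallbackManager bound to it; the destructor synchronizes the stream
// before destroying it, so no pending deallocation callback is lost.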
StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
                                               size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {
  platform::CUDADeviceGuard guard(place.device);
  PADDLE_ENFORCE(cudaStreamCreate(&stream_));
  callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}

StreamGarbageCollector::~StreamGarbageCollector() {
  auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
  platform::CUDADeviceGuard guard(place.device);
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

cudaStream_t StreamGarbageCollector::stream() const { return stream_; }

void StreamGarbageCollector::Wait() const { callback_manager_->Wait(); }

void StreamGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  callback_manager_->AddCallback(callback);
}
#endif
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,49 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from test_eager_deletion_lstm_net import TestBase
import paddle.fluid as fluid

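# Same text-classification topology as the LSTM test this file reuses, but
# with a GRU cell. fc0 projects the embedding to hid_dim * 3 because
# dynamic_gru expects its input to be three times the hidden size (update
# gate, reset gate, and candidate state).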
def gru_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2,
            emb_lr=400.0):
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
    gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
    gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
    gru_max_tanh = fluid.layers.tanh(gru_max)
    fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    return avg_cost


class GRUTest(TestBase):
    def setUp(self):
        self.net = gru_net


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,111 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
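# Both variables must be set before paddle.fluid is imported so they are
# picked up at module load. A non-negative FLAGS_eager_delete_tensor_gb
# enables eager deletion, and 0.0 frees garbage tensors as soon as they
# become unused; CPU_NUM sets how many CPU devices ParallelExecutor uses.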
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['CPU_NUM'] = '2'

import six
import unittest

import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid

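# Builds an IMDB sentiment model from the given network function, then runs a
# short training loop (at most 17 batches per pass) under either a plain
# Executor or a ParallelExecutor. The cost is fetched only on every 4th
# batch, presumably to exercise eager deletion both with and without fetched
# variables.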
def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
    if use_cuda and not core.is_compiled_with_cuda():
        print('Skip use_cuda=True because Paddle is not compiled with cuda')
        return

    word_dict = paddle.dataset.imdb.word_dict()
    train_reader = paddle.batch(
        paddle.dataset.imdb.train(word_dict), batch_size=batch_size)

    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)

    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    cost = network(data, label, len(word_dict))
    optimizer = fluid.optimizer.Adagrad(learning_rate=0.2)
    optimizer.minimize(cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
    reader = feeder.decorate_reader(
        train_reader, multi_devices=use_parallel_executor)

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if use_parallel_executor:
        train_exe = fluid.ParallelExecutor(
            use_cuda=use_cuda, loss_name=cost.name)
        fetch_list = [cost.name]
    else:
        train_exe = exe
        fetch_list = [cost]

    for pass_id in six.moves.xrange(pass_num):
        batch_id = 0
        for data in reader():
            train_exe.run(feed=data,
                          fetch_list=fetch_list if batch_id % 4 == 0 else [])
            batch_id += 1
            if batch_id > 16:
                break


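# fc0 projects the embedding to hid_dim * 4 because dynamic_lstm expects its
# input (and its size argument) to be four times the hidden size: input,
# forget, and output gates plus the cell candidate.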
def lstm_net(data,
             label,
             dict_dim,
             emb_dim=128,
             hid_dim=128,
             hid_dim2=96,
             class_dim=2,
             emb_lr=30.0):
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
    lstm_h, c = fluid.layers.dynamic_lstm(
        input=fc0, size=hid_dim * 4, is_reverse=False)
    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
    lstm_max_tanh = fluid.layers.tanh(lstm_max)
    fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    return avg_cost


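# Sweeps every combination of device (CUDA/CPU) and executor type, running
# each configuration inside fresh programs and a fresh scope so the runs stay
# isolated. Subclasses (e.g. the GRU test above) override setUp to swap in a
# different network.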
class TestBase(unittest.TestCase):
    def setUp(self):
        self.net = lstm_net

    def test_network(self):
        for use_cuda in [True, False]:
            for use_parallel_executor in [False, True]:
                print('network: {}, use_cuda: {}, use_parallel_executor: {}'.
                      format(self.net.__name__, use_cuda,
                             use_parallel_executor))
                with fluid.program_guard(fluid.Program(), fluid.Program()):
                    with fluid.scope_guard(core.Scope()):
                        train(self.net, use_cuda, use_parallel_executor)


if __name__ == "__main__":
    unittest.main()