commit
54c160aa72
@ -0,0 +1,87 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
template <typename T>
|
||||
class Channel {
|
||||
public:
|
||||
explicit Channel(std::size_t capacity) : capacity_(capacity) {}
|
||||
|
||||
void Send(T* channel_element) {
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
|
||||
if (IsBounded()) {
|
||||
full_cond_var_.wait(lock, [this]() {
|
||||
bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;
|
||||
return capacity_valid;
|
||||
});
|
||||
}
|
||||
channel_.push_back(std::move(*channel_element));
|
||||
|
||||
lock.unlock();
|
||||
empty_cond_var_.notify_one();
|
||||
}
|
||||
|
||||
T* Receive() {
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
|
||||
|
||||
T* channel_element = std::move(channel_.front());
|
||||
channel_.pop_front();
|
||||
|
||||
NotifyAllSenders(&lock);
|
||||
return channel_element;
|
||||
}
|
||||
|
||||
size_t Size() {
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
return channel_.size();
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
channel_.clear();
|
||||
|
||||
NotifyAllSenders(&lock);
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t capacity_;
|
||||
std::mutex mu_;
|
||||
std::condition_variable empty_cond_var_;
|
||||
std::condition_variable full_cond_var_;
|
||||
std::deque<T> channel_;
|
||||
|
||||
private:
|
||||
void NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
|
||||
if (IsBounded()) {
|
||||
lock->unlock();
|
||||
full_cond_var_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
bool IsBounded() const { return capacity_ > 0; }
|
||||
|
||||
bool IsCapacityFull() const { return channel_.size() >= capacity_; }
|
||||
};
|
||||
|
||||
} // namespace operator
|
||||
} // namespace paddle
|
@ -0,0 +1,56 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/feed_fetch_method.h"
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/framework/variable.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
void SetFeedVariable(Scope* scope, const LoDTensor& input,
|
||||
const std::string& var_name, size_t index) {
|
||||
// If var_name Variable is not found in GlobalScope, a new variable will
|
||||
// be created.
|
||||
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
|
||||
Variable* g_feed_value = scope->Var(var_name);
|
||||
auto& feed_inputs =
|
||||
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
|
||||
if (index >= feed_inputs.size()) {
|
||||
feed_inputs.resize(index + 1);
|
||||
}
|
||||
// shared data with input tensor
|
||||
feed_inputs[index].ShareDataWith(input);
|
||||
// set lod
|
||||
feed_inputs[index].set_lod(input.lod());
|
||||
}
|
||||
|
||||
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
|
||||
size_t index) {
|
||||
// Since we want to fetch LodTensor from a variable, the variable must
|
||||
// be created alreadly.
|
||||
Variable* g_fetch_value = scope.FindVar(var_name);
|
||||
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
|
||||
"Only %s can be invoked by GetFetchVariable",
|
||||
typeid(FeedFetchList).name());
|
||||
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
|
||||
auto& tensor = fetch_outputs[index];
|
||||
VLOG(3) << "Fetch " << var_name << " with index " << index
|
||||
<< " shape= " << tensor.dims();
|
||||
PADDLE_ENFORCE_LT(index, fetch_outputs.size());
|
||||
return tensor;
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -1,24 +1,93 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/threadpool.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
|
||||
std::once_flag ThreadPool::init_flag;
|
||||
std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
|
||||
std::once_flag ThreadPool::init_flag_;
|
||||
|
||||
ThreadPool* ThreadPool::GetInstance() {
|
||||
std::call_once(init_flag_, &ThreadPool::Init);
|
||||
return threadpool_.get();
|
||||
}
|
||||
|
||||
void ThreadPool::Init() {
|
||||
if (threadpool_.get() == nullptr) {
|
||||
// TODO(Yancey1989): specify the max threads number
|
||||
int num_threads = std::thread::hardware_concurrency();
|
||||
PADDLE_ENFORCE_GT(num_threads, 0);
|
||||
threadpool_.reset(new ThreadPool(num_threads));
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPool::ThreadPool(int num_threads)
|
||||
: total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
|
||||
threads_.resize(num_threads);
|
||||
for (auto& thread : threads_) {
|
||||
// TODO(Yancey1989): binding the thread on the specify CPU number
|
||||
thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPool::~ThreadPool() {
|
||||
{
|
||||
// notify all threads to stop running
|
||||
running_ = false;
|
||||
scheduled_.notify_all();
|
||||
}
|
||||
|
||||
for (auto& t : threads_) {
|
||||
t->join();
|
||||
t.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPool::Wait() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
completed_.wait(lock, [=] { return Done() == true; });
|
||||
}
|
||||
|
||||
void ThreadPool::TaskLoop() {
|
||||
while (running_) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
|
||||
|
||||
if (!running_) {
|
||||
break;
|
||||
}
|
||||
// pop a task from the task queue
|
||||
auto task = std::move(tasks_.front());
|
||||
tasks_.pop();
|
||||
|
||||
--idle_threads_;
|
||||
lock.unlock();
|
||||
|
||||
// run the task
|
||||
task();
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
++idle_threads_;
|
||||
if (Done()) {
|
||||
completed_.notify_all();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
@ -0,0 +1,95 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/operators/one_hot_op.h"
|
||||
#include "paddle/framework/framework.pb.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class OneHotOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of OneHotOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of OneHotOp should not be null.");
|
||||
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
||||
"Rank of Input(X) should be at least 2.");
|
||||
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U,
|
||||
"Last dimension of Input(X) should be 1.");
|
||||
|
||||
int depth = ctx->Attrs().Get<int>("depth");
|
||||
|
||||
PADDLE_ENFORCE_GT(depth, 0, "Should provide a positive depth (%d).", depth);
|
||||
|
||||
framework::DDim out_dims(x_dims);
|
||||
out_dims[out_dims.size() - 1] = depth;
|
||||
ctx->SetOutputDim("Out", out_dims);
|
||||
ctx->ShareLoD("X", /* --> */ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
OneHotOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X",
|
||||
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. "
|
||||
"The last dimension of X should be 1. Each value of X is an index "
|
||||
"to indicate the position.");
|
||||
AddOutput("Out",
|
||||
"(Tensor, Tensor<float>) Output tensor with same rank as X. "
|
||||
"The tensor consists of one-hot representations of values in X.");
|
||||
AddAttr<int>("depth",
|
||||
"A positive integer to specify the length of one-hot vector.");
|
||||
AddAttr<int>("dtype",
|
||||
"An integer to specify the data type of one-hot "
|
||||
"vector. The default value is FP32.")
|
||||
.SetDefault(paddle::framework::proto::DataType::FP32);
|
||||
AddComment(R"DOC(
|
||||
One Hot Operator. This operator creates the one-hot representations for input
|
||||
index values. The following example will help to explain the function of this
|
||||
operator:
|
||||
|
||||
X is a LoDTensor:
|
||||
X.lod = [[0, 1, 4]]
|
||||
X.shape = [4, 1]
|
||||
X.data = [[1], [1], [3], [0]]
|
||||
|
||||
set depth = 4
|
||||
|
||||
Out is a LoDTensor:
|
||||
Out.lod = [[0, 1, 4]]
|
||||
Out.shape = [4, 4]
|
||||
Out.data = [[0., 1., 0., 0.],
|
||||
[0., 1., 0., 0.],
|
||||
[0., 0., 0., 1.],
|
||||
[1., 0., 0., 0.]]
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(one_hot, ops::OneHotOp, ops::OneHotOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
one_hot, ops::OneHotKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::OneHotKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
@ -0,0 +1,80 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/operators/one_hot_op.h"
|
||||
#include "paddle/platform/cuda_helper.h"
|
||||
#include "paddle/platform/gpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
using platform::PADDLE_CUDA_NUM_THREADS;
|
||||
|
||||
template <typename InT, typename OutT>
|
||||
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
|
||||
const int64_t numel, const int depth) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < numel) {
|
||||
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename InT>
|
||||
struct OneHotOpCUDAFunctor {
|
||||
const framework::LoDTensor* in_;
|
||||
framework::LoDTensor* out_;
|
||||
const DeviceContext& ctx_;
|
||||
int depth_;
|
||||
|
||||
OneHotOpCUDAFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
|
||||
int depth, const DeviceContext& ctx)
|
||||
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
|
||||
|
||||
template <typename OutT>
|
||||
void operator()() const {
|
||||
auto* p_in_data = in_->data<InT>();
|
||||
auto numel = in_->numel();
|
||||
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
|
||||
auto stream = ctx_.stream();
|
||||
math::set_constant(ctx_, out_, 0.0);
|
||||
|
||||
FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
|
||||
PADDLE_CUDA_NUM_THREADS,
|
||||
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
|
||||
p_in_data, p_out_data, numel, depth_);
|
||||
}
|
||||
};
|
||||
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
template <typename DeviceContext, typename T>
|
||||
class OneHotCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* in = context.Input<LoDTensor>("X");
|
||||
auto* out = context.Output<LoDTensor>("Out");
|
||||
int depth = context.Attr<int>("depth");
|
||||
|
||||
framework::VisitDataType(
|
||||
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
|
||||
OneHotOpCUDAFunctor<DeviceContext, T>(
|
||||
in, out, depth, context.template device_context<DeviceContext>()));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
one_hot, ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int>,
|
||||
ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename InT>
|
||||
struct OneHotOpFunctor {
|
||||
const framework::LoDTensor* in_;
|
||||
framework::LoDTensor* out_;
|
||||
int depth_;
|
||||
const DeviceContext& ctx_;
|
||||
|
||||
OneHotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
|
||||
int depth, const DeviceContext& ctx)
|
||||
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
|
||||
|
||||
template <typename OutT>
|
||||
void operator()() const {
|
||||
auto* p_in_data = in_->data<InT>();
|
||||
auto numel = in_->numel();
|
||||
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
|
||||
math::set_constant(ctx_, out_, 0.0);
|
||||
|
||||
for (int i = 0; i < numel; ++i) {
|
||||
PADDLE_ENFORCE_GE(p_in_data[i], 0,
|
||||
"Illegal index value, should be at least 0.");
|
||||
PADDLE_ENFORCE_LT(p_in_data[i], depth_,
|
||||
"Illegal index value, should be less than depth (%d).",
|
||||
depth_);
|
||||
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
template <typename DeviceContext, typename T>
|
||||
class OneHotKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* in = context.Input<LoDTensor>("X");
|
||||
auto* out = context.Output<LoDTensor>("Out");
|
||||
int depth = context.Attr<int>("depth");
|
||||
|
||||
framework::VisitDataType(
|
||||
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
|
||||
OneHotOpFunctor<DeviceContext, T>(
|
||||
in, out, depth, context.template device_context<DeviceContext>()));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,135 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
|
||||
|
||||
def stacked_lstm_net(data,
|
||||
label,
|
||||
input_dim,
|
||||
class_dim=2,
|
||||
emb_dim=128,
|
||||
hid_dim=512,
|
||||
stacked_num=3):
|
||||
assert stacked_num % 2 == 1
|
||||
|
||||
emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim])
|
||||
# add bias attr
|
||||
|
||||
# TODO(qijun) linear act
|
||||
fc1 = fluid.layers.fc(input=emb, size=hid_dim)
|
||||
lstm1, cell1 = fluid.layers.dynamic_lstm(input=fc1, size=hid_dim)
|
||||
|
||||
inputs = [fc1, lstm1]
|
||||
|
||||
for i in range(2, stacked_num + 1):
|
||||
fc = fluid.layers.fc(input=inputs, size=hid_dim)
|
||||
lstm, cell = fluid.layers.dynamic_lstm(
|
||||
input=fc, size=hid_dim, is_reverse=(i % 2) == 0)
|
||||
inputs = [fc, lstm]
|
||||
|
||||
fc_last = fluid.layers.sequence_pool(input=inputs[0], pool_type='max')
|
||||
lstm_last = fluid.layers.sequence_pool(input=inputs[1], pool_type='max')
|
||||
|
||||
prediction = fluid.layers.fc(input=[fc_last, lstm_last],
|
||||
size=class_dim,
|
||||
act='softmax')
|
||||
cost = fluid.layers.cross_entropy(input=prediction, label=label)
|
||||
avg_cost = fluid.layers.mean(x=cost)
|
||||
adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002)
|
||||
optimize_ops, params_grads = adam_optimizer.minimize(avg_cost)
|
||||
accuracy = fluid.evaluator.Accuracy(input=prediction, label=label)
|
||||
return avg_cost, accuracy, accuracy.metrics[0], optimize_ops, params_grads
|
||||
|
||||
|
||||
def to_lodtensor(data, place):
|
||||
seq_lens = [len(seq) for seq in data]
|
||||
cur_len = 0
|
||||
lod = [cur_len]
|
||||
for l in seq_lens:
|
||||
cur_len += l
|
||||
lod.append(cur_len)
|
||||
flattened_data = np.concatenate(data, axis=0).astype("int64")
|
||||
flattened_data = flattened_data.reshape([len(flattened_data), 1])
|
||||
res = fluid.LoDTensor()
|
||||
res.set(flattened_data, place)
|
||||
res.set_lod([lod])
|
||||
return res
|
||||
|
||||
|
||||
def main():
|
||||
BATCH_SIZE = 100
|
||||
PASS_NUM = 5
|
||||
|
||||
word_dict = paddle.dataset.imdb.word_dict()
|
||||
print "loaded word dict successfully"
|
||||
dict_dim = len(word_dict)
|
||||
class_dim = 2
|
||||
|
||||
data = fluid.layers.data(
|
||||
name="words", shape=[1], dtype="int64", lod_level=1)
|
||||
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
|
||||
cost, accuracy, acc_out, optimize_ops, params_grads = stacked_lstm_net(
|
||||
data, label, input_dim=dict_dim, class_dim=class_dim)
|
||||
|
||||
train_data = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.imdb.train(word_dict), buf_size=1000),
|
||||
batch_size=BATCH_SIZE)
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
|
||||
|
||||
t = fluid.DistributeTranspiler()
|
||||
# all parameter server endpoints list for spliting parameters
|
||||
pserver_endpoints = os.getenv("PSERVERS")
|
||||
# server endpoint for current node
|
||||
current_endpoint = os.getenv("SERVER_ENDPOINT")
|
||||
# run as trainer or parameter server
|
||||
training_role = os.getenv(
|
||||
"TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver
|
||||
t.transpile(
|
||||
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
|
||||
|
||||
if training_role == "PSERVER":
|
||||
if not current_endpoint:
|
||||
print("need env SERVER_ENDPOINT")
|
||||
exit(1)
|
||||
pserver_prog = t.get_pserver_program(current_endpoint)
|
||||
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
|
||||
exe.run(pserver_startup)
|
||||
exe.run(pserver_prog)
|
||||
elif training_role == "TRAINER":
|
||||
exe.run(fluid.default_startup_program())
|
||||
trainer_prog = t.get_trainer_program()
|
||||
for pass_id in xrange(PASS_NUM):
|
||||
accuracy.reset(exe)
|
||||
for data in train_data():
|
||||
cost_val, acc_val = exe.run(trainer_prog,
|
||||
feed=feeder.feed(data),
|
||||
fetch_list=[cost, acc_out])
|
||||
pass_acc = accuracy.eval(exe)
|
||||
print("cost=" + str(cost_val) + " acc=" + str(acc_val) +
|
||||
" pass_acc=" + str(pass_acc))
|
||||
if cost_val < 1.0 and acc_val > 0.8:
|
||||
exit(0)
|
||||
else:
|
||||
print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue