Merge branch 'develop' into mkldnn

wangkuiyi-patch-1
Luo Tao 7 years ago
commit 79d555b9f2

@ -106,7 +106,7 @@ PaddlePaddle需要使用Docker环境完成编译这样可以免去单独安
- 学习 Docker 有多难?
理解 Docker 并不难,大概花十分钟看一下 `这篇文章 <https://zhuanlan.zhihu.com/p/19902938>`_ 。这可以帮您省掉花一小时安装和配置各种开发工具,以及切换机器时需要新安装的辛苦。别忘了 PaddlePaddle 更新可能导致需要新的开发工具。更别提简化问题复现带来的好处了。
理解 Docker 并不难,大概花十分钟看一下 `如何使用Docker <https://zhuanlan.zhihu.com/p/19902938>`_ 。这可以帮您省掉花一小时安装和配置各种开发工具,以及切换机器时需要新安装的辛苦。别忘了 PaddlePaddle 更新可能导致需要新的开发工具。更别提简化问题复现带来的好处了。
- 我可以用 IDE 吗?
@ -123,7 +123,7 @@ PaddlePaddle需要使用Docker环境完成编译这样可以免去单独安
- 可以并行编译吗?
是的。我们的 Docker image 运行一个 `Bash脚本 <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh>`_ 。这个脚本调用 `make -j$(nproc)` 来启动和 CPU 核一样多的进程来并行编译。
是的。我们的 Docker image 运行一个 `Paddle编译Bash脚本 <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh>`_ 。这个脚本调用 `make -j$(nproc)` 来启动和 CPU 核一样多的进程来并行编译。
- Docker 需要 sudo
@ -131,11 +131,11 @@ PaddlePaddle需要使用Docker环境完成编译这样可以免去单独安
- 在 Windows/MacOS 上编译很慢
Docker 在 Windows 和 MacOS 都可以运行。不过实际上是运行在一个 Linux 虚拟机上。可能需要注意给这个虚拟机多分配一些 CPU 和内存,以保证编译高效。具体做法请参考 `这个issue <https://github.com/PaddlePaddle/Paddle/issues/627>`_
Docker 在 Windows 和 MacOS 都可以运行。不过实际上是运行在一个 Linux 虚拟机上。可能需要注意给这个虚拟机多分配一些 CPU 和内存,以保证编译高效。具体做法请参考 `如何为Windows/Mac计算机上的Docker增加内存和虚拟机 <https://github.com/PaddlePaddle/Paddle/issues/627>`_
- 磁盘不够
本文中的例子里,`docker run` 命令里都用了 `--rm` 参数,这样保证运行结束之后的 containers 不会保留在磁盘上。可以用 `docker ps -a` 命令看到停止后但是没有删除的 containers。`docker build` 命令有时候会产生一些中间结果,是没有名字的 images也会占用磁盘。可以参考 `这篇文章 <https://zaiste.net/posts/removing_docker_containers/>`_ 来清理这些内容。
本文中的例子里,`docker run` 命令里都用了 `--rm` 参数,这样保证运行结束之后的 containers 不会保留在磁盘上。可以用 `docker ps -a` 命令看到停止后但是没有删除的 containers。`docker build` 命令有时候会产生一些中间结果,是没有名字的 images也会占用磁盘。可以参考 `如何删除Docker Container <https://zaiste.net/posts/removing_docker_containers/>`_ 来清理这些内容。
.. _compile_deps:
@ -195,7 +195,7 @@ BLAS
PaddlePaddle支持 `MKL <https://software.intel.com/en-us/intel-mkl>`_
`OpenBlAS <http://www.openblas.net/>`_ 两种BLAS库。默认使用MKL。如果使用MKL并且机器含有AVX2指令集
还会下载MKL-DNN数学库详细参考 `这里 <https://github.com/PaddlePaddle/Paddle/tree/develop/doc/design/mkldnn#cmake>`_
还会下载MKL-DNN数学库详细参考 `mkldnn设计文档 <https://github.com/PaddlePaddle/Paddle/tree/develop/doc/design/mkldnn#cmake>`_
如果关闭MKL则会使用OpenBLAS作为BLAS库。

@ -28,6 +28,9 @@ struct DataTypeMap {
};
static DataTypeMap* InitDataTypeMap();
// C++11 removes the need for manual locking. Concurrent execution shall wait if
// a static local variable is already being initialized.
// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex
static DataTypeMap& gDataTypeMap() {
static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
return *g_data_type_map_;

@ -42,7 +42,7 @@ void FuseVarsOpHandle::RunImpl() {
out_t->ShareDataWith(out_tensor->Slice(s, s + numel));
s += numel;
}
this->RunAndRecordEvent([this] {});
this->RunAndRecordEvent([] {});
}
std::string FuseVarsOpHandle::Name() const { return "fuse vars"; }

@ -151,7 +151,8 @@ class TRTConvertValidation {
// Compare two output
ASSERT_FALSE(fluid_out.empty());
for (size_t i = 0; i < fluid_out.size(); i++) {
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 1e-6);
// Loose the threshold for CI in different machine model.
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 2e-5);
}
}
}

@ -24,12 +24,12 @@ namespace operators {
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddComment(#OP_COMMENT); \
AddComment(OP_COMMENT); \
} \
}

@ -48,6 +48,13 @@ class CropOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", y_dim);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
@ -60,13 +67,19 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
"The input used as reference for cropping, "
"which is of the same dimensions as X.")
.AsDispensable();
AddInput("Offsets",
"The input used to describe offsets in runtime, which is a "
"1-D vector whose size equals to the rank of input 'X'. The "
"elements data type must be int.")
.AsDispensable();
AddOutput("Out",
"The output of crop op, "
"which is of the same dimensions as X.");
AddAttr<std::vector<int>>("offsets",
"A list<int> describing offsets to be cropped. "
"The size of offsets list should be the same as "
"the dimension size of input X.");
"the dimension size of input X.")
.SetDefault(std::vector<int>());
AddAttr<std::vector<int>>("shape",
"A list<int> describing the shape of output. "
"The size of shape list should be the same as "
@ -77,6 +90,17 @@ Crop Operator.
Crop input into output, as specified by offsets and shape.
There are two ways to set the offsets:
1. In runtime: Using the input 'Offsets', which is a Vairbale and can be
output of other operators. This way is suitable for
dynamic offsets.
2. In network configuration: Using the attribute 'offsets', which will be
set in Python configure script. This way is
suitable for fixed offsets.
You CANNOT use these two ways at the same time. An exception will be raised
if input 'Offset' is configured and meanwhile the attribute 'offsets' is
not empty.
There are two ways to set shape:
1. reference input: crop input X into the same shape as reference input.
The dimension of reference input should
@ -146,6 +170,15 @@ class CropOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context());
}
};
} // namespace operators

@ -27,6 +27,37 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor;
static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
std::vector<int> res;
int rank = ctx.Input<Tensor>("X")->dims().size();
if (ctx.HasInput("Offsets")) {
PADDLE_ENFORCE(ctx.Attr<std::vector<int>>("offsets").empty(),
"Input 'Offsets' and attribute 'offsets' should not be used "
"at the same time.");
const auto* offsets_tensor = ctx.Input<Tensor>("Offsets");
PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1);
PADDLE_ENFORCE_EQ(
rank, offsets_tensor->dims()[0],
"Offsets size should be equal to dimension size of input tensor.");
const int* offsets_data;
framework::Tensor cpu_tmp_tensor;
if (platform::is_cpu_place(offsets_tensor->place())) {
offsets_data = offsets_tensor->data<int>();
} else {
framework::TensorCopySync(*offsets_tensor, platform::CPUPlace(),
&cpu_tmp_tensor);
offsets_data = cpu_tmp_tensor.data<int>();
}
res = std::vector<int>(offsets_data, offsets_data + rank);
} else {
res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
rank, res.size(),
"Offsets size should be equal to dimension size of input tensor.");
}
return res;
}
template <typename T>
class CropKernel : public framework::OpKernel<T> {
public:
@ -37,10 +68,7 @@ class CropKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto out_stride = framework::stride(out->dims());
auto offsets = context.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
x->dims().size(), static_cast<int64_t>(offsets.size()),
"Offsets size should be equal to dimension size of input tensor.");
auto offsets = GetOffsets(context);
int64_t offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]);
@ -56,7 +84,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
if (d_x != nullptr) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(context.GetPlace());
auto offsets = context.Attr<std::vector<int>>("offsets");
auto offsets = GetOffsets(context);
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < D; ++i) {
paddings[i].first = offsets[i];

@ -80,7 +80,6 @@ class RequestHandler {
}
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
@ -113,13 +112,7 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable*> sparse_vars_;
RPCServer* rpc_server_;
std::mutex sparse_var_mutex_;
};
} // namespace detail

@ -63,16 +63,22 @@ bool RequestSendHandler::Handle(const std::string& varname,
PADDLE_THROW("sync: Can not find server side var");
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
}
return true;
}
void RequestSendHandler::ResetSparseVarRecorder() {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,

@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
void ResetSparseVarRecorder();
private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
};
class RequestGetHandler final : public RequestHandler {

@ -60,6 +60,7 @@ class RPCServer {
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter();
protected:

@ -43,7 +43,8 @@ TEST(Gather, GatherData) {
auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
paddle::operators::CPUGather<int>(ctx, *src, *index, output);
delete cpu_place;
cpu_place = NULL;
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);

@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
@ -146,18 +143,12 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (framework::Variable *var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter();
// reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
} // while(true)
}

@ -77,6 +77,8 @@ TEST(math_function, gemm_trans_clbas) {
paddle::platform::CPUDeviceContext context(*cpu_place);
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3,
input2_ptr + 3, 3, 1, input3_ptr + 1, 4);
delete cpu_place;
cpu_place = NULL;
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);

@ -20,7 +20,6 @@ class RandomCropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
@ -36,11 +35,11 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Seed", "The random seed.");
AddOutput("Out", "The cropped instance batch.");
AddOutput("SeedOut", "The random seed after random cropping.")
.AsDispensable();
.AsIntermediate();
AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
AddComment(R"DOC(
This operator takes a batch of instance, and do random cropping on each instance.
It means that cropping positions differs on each instance, which is determined
This operator takes a batch of instance, and do random cropping on each instance.
It means that cropping positions differs on each instance, which is determined
by an uniform random generator. All cropped instances have the same shape, which
is determined by the operator's attribute 'shape'.
)DOC");

@ -26,6 +26,7 @@ from trainer import BeginEpochEvent
from trainer import EndEpochEvent
from trainer import BeginStepEvent
from trainer import EndStepEvent
from trainer import CheckpointConfig
import inferencer
from inferencer import Inferencer

@ -363,6 +363,13 @@ class OpProtoHolder(object):
raise ValueError("Operator \"%s\" has not been registered." % type)
return self.op_proto_map[type]
@staticmethod
def generated_op_attr_names():
return {
core.op_proto_and_checker_maker.kOpRoleAttrName(),
core.op_proto_and_checker_maker.kOpRoleVarAttrName()
}
class Operator(object):
"""

@ -56,6 +56,8 @@ class Inferencer(object):
else:
self.exe = executor.Executor(self.place)
self.inference_program = self.inference_program.clone(for_test=True)
def infer(self, inputs, return_numpy=True):
"""
:param inputs: a map of {"input_name": input_var} that will be feed into the inference program

File diff suppressed because it is too large Load Diff

@ -15,16 +15,13 @@ import re
import cStringIO
import functools
import warnings
import string
from ..proto import framework_pb2
from ..framework import OpProtoHolder, Variable
from ..layer_helper import LayerHelper
__all__ = [
'deprecated',
'generate_layer_fn',
'autodoc',
]
__all__ = ['deprecated', 'generate_layer_fn', 'autodoc', 'templatedoc']
def _convert_(name):
@ -43,6 +40,10 @@ def _convert_(name):
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp)
def _generate_doc_string_(op_proto):
"""
Generate docstring by OpProto
@ -54,9 +55,6 @@ def _generate_doc_string_(op_proto):
str: the document string
"""
def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp)
if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("OpProto should be `framework_pb2.OpProto`")
@ -75,7 +73,11 @@ def _generate_doc_string_(op_proto):
buf.write(str(each_input.dispensable))
buf.write('\n')
skip_attrs = OpProtoHolder.generated_op_attr_names()
for each_attr in op_proto.attrs:
if each_attr.name in skip_attrs:
continue
buf.write(' ')
buf.write(each_attr.name)
buf.write(' (')
@ -220,3 +222,49 @@ def autodoc(comment=""):
return func
return __impl__
def templatedoc():
"""
Decorator of layer function. It will use the docstring from the layer
function as the template. The template arguments are:
* ${comment}: The operator comment written in CPP.
* ${{name}_comment}: The comment of ${name} written with AddAttr, AddOutput,
and AddInput. The ${name} is Python snake style. i.e., xxx_xxx.
* ${{name}_type}: The type of ${name}.
Returns:
Decorated function.
"""
def __impl__(func):
op_proto = OpProtoHolder.instance().get_op_proto(func.__name__)
tmpl = string.Template(func.__doc__)
comment_lines = op_proto.comment.split("\n")
comment = ""
for line in comment_lines:
line = line.lstrip()
comment += line
comment += "\n"
args = {"comment": comment}
for each_input in op_proto.inputs:
input_name = _convert_(each_input.name)
args["{0}_comment".format(input_name)] = each_input.comment
args["{0}_type".format(input_name)] = "Variable"
for each_attr in op_proto.attrs:
input_name = _convert_(each_attr.name)
args["{0}_comment".format(input_name)] = each_attr.comment
args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type)
for each_opt in op_proto.outputs:
output_name = _convert_(each_opt.name)
args["{0}_comment".format(output_name)] = each_opt.comment
args["{0}_type".format(output_name)] = "Variable"
func.__doc__ = tmpl.substitute(args)
return func
return __impl__

@ -64,10 +64,6 @@ def auc(input, label, curve='ROC', num_thresholds=200):
topk_indices = helper.create_tmp_variable(dtype="int64")
topk_out, topk_indices = nn.topk(input, k=k)
auc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
if total is None:
total = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="accuracy",
inputs={

@ -19,9 +19,10 @@ from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
from layer_function_generator import autodoc
from layer_function_generator import autodoc, templatedoc
from tensor import concat
import utils
import random
__all__ = [
'fc',
@ -801,7 +802,22 @@ def gru_unit(input,
return updated_hidden, reset_hidden_pre, gate
@templatedoc()
def linear_chain_crf(input, label, param_attr=None):
"""
Linear Chain CRF.
${comment}
Args:
input(${emission_type}): ${emission_comment}
label(${label_type}): ${label_comment}
param_attr(ParamAttr): The attribute of the learnable parameter.
Returns:
${log_likelihood_comment}
"""
helper = LayerHelper('linear_chain_crf', **locals())
size = input.shape[1]
transition = helper.create_parameter(
@ -827,7 +843,19 @@ def linear_chain_crf(input, label, param_attr=None):
return log_likelihood
@templatedoc()
def crf_decoding(input, param_attr, label=None):
"""
${comment}
Args:
input(${emission_type}): ${emission_comment}
param_attr(ParamAttr): The parameter attribute for training.
label(${label_type}): ${label_comment}
Returns:
${viterbi_path_comment}
"""
helper = LayerHelper('crf_decoding', **locals())
transition = helper.get_parameter(param_attr.name)
viterbi_path = helper.create_tmp_variable(dtype=helper.input_dtype())
@ -4107,10 +4135,31 @@ def gather(input, index):
return out
def random_crop(input, shape, seed=1):
@templatedoc()
def random_crop(x, shape, seed=None):
"""
${comment}
Examples:
>>> img = fluid.layers.data("img", [3, 256, 256])
>>> cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224])
Args:
x(${x_type}): ${x_comment}
shape(${shape_type}): ${shape_comment}
seed(int|${seed_type}|None): ${seed_comment} By default, the seed will
get from `random.randint(-65536, 65535)`.
Returns:
${out_comment}
"""
helper = LayerHelper("random_crop", **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
if seed is None:
seed = random.randint(-65536, 65535)
if isinstance(seed, int):
seed_value = seed
seed = helper.create_tmp_variable(dtype="int64")

@ -73,6 +73,7 @@ __all__ = [
'sum',
'polygon_box_transform',
'shape',
'maxout',
] + __activations__
for _OP in set(__all__):

@ -38,7 +38,7 @@ def inference_program():
return y_predict
def linear():
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = inference_program()
@ -104,7 +104,7 @@ def main(use_cuda):
# Directory for saving the trained model
params_dirname = "fit_a_line.inference.model"
train(use_cuda, linear, params_dirname)
train(use_cuda, train_program, params_dirname)
infer(use_cuda, inference_program, params_dirname)

@ -0,0 +1,75 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import os
import tempfile
class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.dirname = tempfile.mktemp()
self.max_num_checkpoints = 3
self.epoch_interval = 1
self.step_interval = 1
self.trainer_id = 0
self.chief = self.trainer_id == 0
self.place = fluid.CPUPlace()
self.epoch_id = 100
self.step_id = 20
def test_checkpoint(self):
self.save_checkpoint()
serial = fluid.io.get_latest_checkpoint_serial(self.dirname)
self.assertTrue(serial >= 0)
trainer_args = ["epoch_id", "step_id"]
epoch_id, step_id = fluid.io.load_trainer_args(
self.dirname, serial, self.trainer_id, trainer_args)
self.assertEqual(self.step_id, int(step_id))
self.assertEqual(self.epoch_id, int(epoch_id))
program = fluid.Program()
with fluid.program_guard(program):
exe = fluid.Executor(self.place)
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
self.assertFalse(os.path.isdir(self.dirname))
def save_checkpoint(self):
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
self.epoch_interval, self.step_interval)
trainer_args = {}
trainer_args["epoch_id"] = self.epoch_id
trainer_args["step_id"] = self.step_id
program = fluid.Program()
with fluid.program_guard(program):
program.global_block().create_var(
name="scale_0",
psersistable=True,
dtype="float32",
shape=[32, 32])
exe = fluid.Executor(self.place)
for i in xrange(10):
fluid.io.save_checkpoint(exe, config.checkpoint_dir,
self.trainer_id, trainer_args, program,
config.max_num_checkpoints)
if __name__ == '__main__':
unittest.main()

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save