parent
bcb80756af
commit
72be7a6151
@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||
#include "paddle/fluid/recordio/scanner.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
class RecordIOFileReader : public framework::FileReader {
|
||||
public:
|
||||
RecordIOFileReader(const std::string& filename,
|
||||
const std::vector<framework::DDim>& shapes)
|
||||
: FileReader(shapes),
|
||||
scanner_(filename),
|
||||
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
|
||||
platform::CPUPlace())) {}
|
||||
|
||||
void ReadNext(std::vector<framework::LoDTensor>* out) override {
|
||||
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
|
||||
}
|
||||
|
||||
bool HasNext() const override { return scanner_.HasNext(); }
|
||||
|
||||
void ReInit() override { scanner_.Reset(); }
|
||||
|
||||
private:
|
||||
recordio::Scanner scanner_;
|
||||
const platform::DeviceContext& dev_ctx_;
|
||||
};
|
||||
|
||||
class CreateRecordIOReaderOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
|
||||
const auto& ranks = Attr<std::vector<int>>("ranks");
|
||||
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
|
||||
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
||||
int(shape_concat.size()),
|
||||
"The accumulate of all ranks should be equal to the "
|
||||
"shape concat's length.");
|
||||
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
|
||||
std::string filename = Attr<std::string>("filename");
|
||||
|
||||
auto* out = scope.FindVar(Output("Out"))
|
||||
->template GetMutable<framework::ReaderHolder>();
|
||||
out->Reset(new RecordIOFileReader(filename, shapes));
|
||||
}
|
||||
};
|
||||
|
||||
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
|
||||
public:
|
||||
CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
|
||||
: FileReaderMakerBase(op_proto, op_checker) {
|
||||
AddAttr<std::string>("filename", "The filename of record io reader");
|
||||
AddComment(R"DOC(
|
||||
CreateRecordIOReader Operator
|
||||
|
||||
Create a reader from a record io file
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace reader = paddle::operators::reader;
|
||||
|
||||
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
|
||||
reader::CreateRecordIOReaderOp,
|
||||
reader::CreateRecordIOReaderOpMaker);
|
@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/pybind/recordio.h"
|
||||
#include <fstream>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/recordio/writer.h"
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
class RecordIOWriter {
|
||||
public:
|
||||
RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
|
||||
size_t max_num_record)
|
||||
: stream_(filename), writer_(&stream_, compressor, max_num_record) {}
|
||||
|
||||
void AppendTensor(const framework::LoDTensor& tensor) {
|
||||
tensors_.push_back(tensor);
|
||||
}
|
||||
|
||||
void CompleteAppendTensor() {
|
||||
auto& ctx =
|
||||
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
||||
framework::WriteToRecordIO(writer_, tensors_, ctx);
|
||||
tensors_.clear();
|
||||
}
|
||||
|
||||
void Close() {
|
||||
PADDLE_ENFORCE(tensors_.empty());
|
||||
writer_.Flush();
|
||||
stream_.close();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<framework::LoDTensor> tensors_;
|
||||
std::ofstream stream_;
|
||||
recordio::Writer writer_;
|
||||
};
|
||||
|
||||
void BindRecordIOWriter(py::module& m) {
|
||||
py::class_<RecordIOWriter> writer(m, "RecordIOWriter", "");
|
||||
py::enum_<recordio::Compressor>(writer, "Compressor", "")
|
||||
.value("Snappy", recordio::Compressor::kSnappy)
|
||||
.value("NoCompress", recordio::Compressor::kNoCompress);
|
||||
|
||||
writer
|
||||
.def("__init__",
|
||||
[](RecordIOWriter& self, const std::string& filename,
|
||||
recordio::Compressor compressor, size_t max_num_record) {
|
||||
new (&self) RecordIOWriter(filename, compressor, max_num_record);
|
||||
})
|
||||
.def("append_tensor", &RecordIOWriter::AppendTensor)
|
||||
.def("complete_append_tensor", &RecordIOWriter::CompleteAppendTensor)
|
||||
.def("close", &RecordIOWriter::Close);
|
||||
}
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -0,0 +1,26 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
extern void BindRecordIOWriter(py::module& m);
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import core
|
||||
|
||||
|
||||
class RecordIOWriter(object):
|
||||
def __init__(self,
|
||||
filename,
|
||||
compressor=core.RecordIOWriter.Compressor.Snappy,
|
||||
max_num_records=1000):
|
||||
self.filename = filename
|
||||
self.compressor = compressor
|
||||
self.max_num_records = max_num_records
|
||||
self.writer = None
|
||||
|
||||
def __enter__(self):
|
||||
self.writer = core.RecordIOWriter(self.filename, self.compressor,
|
||||
self.max_num_records)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is not None:
|
||||
return False
|
||||
else:
|
||||
self.writer.close()
|
||||
|
||||
def append_tensor(self, tensor):
|
||||
self.writer.append_tensor(tensor)
|
||||
|
||||
def complete_append_tensor(self):
|
||||
self.writer.complete_append_tensor()
|
||||
|
||||
|
||||
def convert_reader_to_recordio_file(
|
||||
filename,
|
||||
reader_creator,
|
||||
feeder,
|
||||
compressor=core.RecordIOWriter.Compressor.Snappy,
|
||||
max_num_records=1000,
|
||||
feed_order=None):
|
||||
writer = RecordIOWriter(filename, compressor, max_num_records)
|
||||
with writer:
|
||||
for batch in reader_creator():
|
||||
res = feeder.feed(batch)
|
||||
if feed_order is None:
|
||||
for each in res:
|
||||
writer.append_tensor(res[each])
|
||||
else:
|
||||
for each in feed_order:
|
||||
writer.append_tensor(res[each])
|
||||
writer.complete_append_tensor()
|
@ -0,0 +1,56 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import paddle.fluid as fluid
|
||||
import paddle.v2.dataset.mnist as mnist
|
||||
import paddle.v2 as paddle
|
||||
|
||||
|
||||
class TestRecordIO(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with fluid.program_guard(fluid.Program()):
|
||||
reader = paddle.batch(mnist.train(), batch_size=32)
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[
|
||||
fluid.layers.data(
|
||||
name='image', shape=[784]), fluid.layers.data(
|
||||
name='label', shape=[1], dtype='int64')
|
||||
],
|
||||
place=fluid.CPUPlace())
|
||||
fluid.recordio_writer.convert_reader_to_recordio_file(
|
||||
'./mnist.recordio',
|
||||
reader,
|
||||
feeder,
|
||||
feed_order=['image', 'label'])
|
||||
|
||||
def testMain(self):
|
||||
data_file = fluid.layers.open_recordio_file(
|
||||
'./mnist.recordio',
|
||||
shapes=[[-1, 784], [-1, 1]],
|
||||
lod_levels=[0, 0],
|
||||
dtypes=['float32', 'int64'])
|
||||
img, label = fluid.layers.read_file(data_file)
|
||||
|
||||
hidden = fluid.layers.fc(input=img, size=100, act='tanh')
|
||||
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
|
||||
loss = fluid.layers.cross_entropy(input=prediction, label=label)
|
||||
avg_loss = fluid.layers.mean(loss)
|
||||
|
||||
fluid.optimizer.SGD(learning_rate=1e-3).minimize(avg_loss)
|
||||
|
||||
exe = fluid.Executor(fluid.CPUPlace())
|
||||
exe.run(fluid.default_startup_program())
|
||||
avg_loss_np, = exe.run(fetch_list=[avg_loss])
|
||||
print avg_loss_np
|
Loading…
Reference in new issue