You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
4.4 KiB
162 lines
4.4 KiB
7 years ago
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing, software
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
// See the License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include "paddle/framework/ddim.h"
|
||
7 years ago
|
#include "paddle/framework/lod_tensor_array.h"
|
||
7 years ago
|
|
||
|
namespace paddle {
|
||
|
namespace framework {
|
||
|
|
||
7 years ago
|
class ReaderBase {
|
||
7 years ago
|
public:
|
||
7 years ago
|
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
|
||
|
PADDLE_ENFORCE(!shapes_.empty());
|
||
|
}
|
||
|
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
|
||
7 years ago
|
virtual bool HasNext() const = 0;
|
||
|
|
||
7 years ago
|
virtual void ReInit() = 0;
|
||
|
|
||
7 years ago
|
DDim shape(size_t idx) const;
|
||
|
std::vector<DDim> shapes() const { return shapes_; }
|
||
|
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
|
||
7 years ago
|
|
||
7 years ago
|
virtual ~ReaderBase() {}
|
||
7 years ago
|
|
||
|
protected:
|
||
|
std::vector<DDim> shapes_;
|
||
7 years ago
|
};
|
||
7 years ago
|
|
||
7 years ago
|
class FileReader : public ReaderBase {
|
||
|
public:
|
||
7 years ago
|
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
|
||
7 years ago
|
};
|
||
|
|
||
7 years ago
|
class DecoratedReader : public ReaderBase {
|
||
7 years ago
|
public:
|
||
7 years ago
|
explicit DecoratedReader(ReaderBase* reader)
|
||
|
: ReaderBase(reader->shapes()), reader_(reader) {
|
||
7 years ago
|
PADDLE_ENFORCE_NOT_NULL(reader_);
|
||
|
}
|
||
|
|
||
7 years ago
|
bool HasNext() const override { return reader_->HasNext(); }
|
||
|
|
||
7 years ago
|
void ReInit() override { reader_->ReInit(); }
|
||
|
|
||
7 years ago
|
protected:
|
||
|
ReaderBase* reader_;
|
||
|
};
|
||
|
|
||
7 years ago
|
// file readers
|
||
|
|
||
7 years ago
|
template <typename T>
|
||
7 years ago
|
class RandomDataGenerator : public FileReader {
|
||
7 years ago
|
public:
|
||
7 years ago
|
RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max)
|
||
7 years ago
|
: FileReader(shapes), min_(min), max_(max) {
|
||
7 years ago
|
PADDLE_ENFORCE_LE(
|
||
|
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
|
||
7 years ago
|
unsigned int seed = std::random_device()();
|
||
|
engine_.seed(seed);
|
||
|
dist_ = std::uniform_real_distribution<float>(min_, max_);
|
||
7 years ago
|
}
|
||
|
|
||
7 years ago
|
void ReadNext(std::vector<LoDTensor>* out) override {
|
||
|
out->clear();
|
||
|
out->reserve(shapes_.size());
|
||
7 years ago
|
for (const DDim& shape : shapes_) {
|
||
7 years ago
|
PADDLE_ENFORCE_GE(
|
||
|
shape.size(), 2,
|
||
7 years ago
|
"The rank of reader's output data should be 2 at least.(Now it's %d)",
|
||
7 years ago
|
shape.size());
|
||
7 years ago
|
LoDTensor out_tensor;
|
||
|
out_tensor.Resize(shape);
|
||
|
T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
|
||
7 years ago
|
int64_t numel = product(shape);
|
||
|
for (int64_t i = 0; i < numel; ++i) {
|
||
7 years ago
|
data[i] = dist_(engine_);
|
||
7 years ago
|
}
|
||
7 years ago
|
out->push_back(out_tensor);
|
||
7 years ago
|
}
|
||
|
}
|
||
|
|
||
|
bool HasNext() const override { return true; }
|
||
7 years ago
|
|
||
7 years ago
|
void ReInit() override { return; }
|
||
|
|
||
7 years ago
|
private:
|
||
|
float min_;
|
||
|
float max_;
|
||
7 years ago
|
std::minstd_rand engine_;
|
||
|
std::uniform_real_distribution<float> dist_;
|
||
7 years ago
|
};
|
||
|
|
||
7 years ago
|
// decorated readers
|
||
7 years ago
|
|
||
7 years ago
|
class ShuffleReader : public DecoratedReader {
|
||
7 years ago
|
public:
|
||
7 years ago
|
ShuffleReader(ReaderBase* reader, int buffer_size)
|
||
7 years ago
|
: DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
|
||
7 years ago
|
buffer_.reserve(buffer_size);
|
||
|
}
|
||
7 years ago
|
|
||
7 years ago
|
void ReadNext(std::vector<LoDTensor>* out) override;
|
||
7 years ago
|
|
||
|
private:
|
||
7 years ago
|
int buffer_size_;
|
||
7 years ago
|
std::vector<std::vector<LoDTensor>> buffer_;
|
||
7 years ago
|
size_t iteration_pos_;
|
||
7 years ago
|
};
|
||
|
|
||
7 years ago
|
class BatchReader : public DecoratedReader {
|
||
7 years ago
|
public:
|
||
7 years ago
|
BatchReader(ReaderBase* reader, int batch_size)
|
||
7 years ago
|
: DecoratedReader(reader), batch_size_(batch_size) {
|
||
7 years ago
|
buffer_.reserve(batch_size_);
|
||
|
}
|
||
|
|
||
7 years ago
|
void ReadNext(std::vector<LoDTensor>* out) override;
|
||
7 years ago
|
|
||
|
private:
|
||
7 years ago
|
int batch_size_;
|
||
7 years ago
|
std::vector<std::vector<LoDTensor>> buffer_;
|
||
7 years ago
|
};
|
||
7 years ago
|
|
||
7 years ago
|
// The ReaderHolder is used as readers' unified wrapper,
|
||
|
// making it easier to access different type readers in Variables.
|
||
7 years ago
|
class ReaderHolder {
|
||
|
public:
|
||
|
void Reset(ReaderBase* reader) { reader_.reset(reader); }
|
||
|
|
||
|
ReaderBase* Get() const { return reader_.get(); }
|
||
|
|
||
7 years ago
|
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
|
||
7 years ago
|
bool HasNext() const { return reader_->HasNext(); }
|
||
7 years ago
|
void ReInit() { reader_->ReInit(); }
|
||
7 years ago
|
|
||
|
DDim shape(size_t idx) const { return reader_->shape(idx); }
|
||
|
std::vector<DDim> shapes() const { return reader_->shapes(); }
|
||
7 years ago
|
void set_shapes(const std::vector<DDim>& shapes) {
|
||
|
reader_->set_shapes(shapes);
|
||
|
}
|
||
7 years ago
|
|
||
|
private:
|
||
|
std::unique_ptr<ReaderBase> reader_;
|
||
|
};
|
||
|
|
||
7 years ago
|
} // namespace framework
|
||
|
} // namespace paddle
|