Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into blocking_queue_for_reader
commit
e057ba6877
@ -0,0 +1,10 @@
|
||||
==================================
|
||||
Data Reader Interface and DataSets
|
||||
==================================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
data/data_reader.rst
|
||||
data/image.rst
|
||||
data/dataset.rst
|
@ -0,0 +1,72 @@
|
||||
=====================
|
||||
Data Reader Interface
|
||||
=====================
|
||||
|
||||
|
||||
DataTypes
|
||||
=========
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.dense_array
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.integer_value
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.integer_value_sequence
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.integer_value_sub_sequence
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_binary_vector
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_binary_vector_sequence
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_binary_vector_sub_sequence
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_float_vector
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_float_vector_sequence
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_float_vector_sub_sequence
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_non_value_slot
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: paddle.v2.data_type.sparse_value_slot
|
||||
:noindex:
|
||||
|
||||
.. autoclass:: paddle.v2.data_type.InputType
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
DataFeeder
|
||||
==========
|
||||
|
||||
.. automodule:: paddle.v2.data_feeder
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
Reader
|
||||
======
|
||||
|
||||
.. automodule:: paddle.v2.reader
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
.. automodule:: paddle.v2.reader.creator
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
minibatch
|
||||
=========
|
||||
|
||||
.. automodule:: paddle.v2.minibatch
|
||||
:members:
|
||||
:noindex:
|
@ -0,0 +1,82 @@
|
||||
Dataset
|
||||
=======
|
||||
|
||||
.. automodule:: paddle.dataset
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
mnist
|
||||
+++++
|
||||
|
||||
.. automodule:: paddle.dataset.mnist
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
cifar
|
||||
+++++
|
||||
|
||||
.. automodule:: paddle.dataset.cifar
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
conll05
|
||||
+++++++
|
||||
|
||||
.. automodule:: paddle.dataset.conll05
|
||||
:members: get_dict,get_embedding,test
|
||||
:noindex:
|
||||
|
||||
imdb
|
||||
++++
|
||||
|
||||
.. automodule:: paddle.dataset.imdb
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
imikolov
|
||||
++++++++
|
||||
|
||||
.. automodule:: paddle.dataset.imikolov
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
movielens
|
||||
+++++++++
|
||||
|
||||
.. automodule:: paddle.dataset.movielens
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
.. autoclass:: paddle.dataset.movielens.MovieInfo
|
||||
:noindex:
|
||||
|
||||
.. autoclass:: paddle.dataset.movielens.UserInfo
|
||||
:noindex:
|
||||
|
||||
sentiment
|
||||
+++++++++
|
||||
|
||||
.. automodule:: paddle.dataset.sentiment
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
uci_housing
|
||||
+++++++++++
|
||||
|
||||
.. automodule:: paddle.dataset.uci_housing
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
wmt14
|
||||
+++++
|
||||
|
||||
.. automodule:: paddle.dataset.wmt14
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
wmt16
|
||||
+++++
|
||||
|
||||
.. automodule:: paddle.dataset.wmt16
|
||||
:members:
|
||||
:noindex:
|
@ -0,0 +1,5 @@
|
||||
Image Interface
|
||||
===============
|
||||
|
||||
.. automodule:: paddle.v2.image
|
||||
:members:
|
@ -0,0 +1,175 @@
|
||||
# Varient Length supported RNN Design
|
||||
For the learning of variable length sequences, the existing mainstream frameworks such as tensorflow, pytorch, caffe2, mxnet and so on all use padding.
|
||||
|
||||
Different-length sequences in a mini-batch will be padded with zeros and transformed to same length.
|
||||
|
||||
The existing RNN implementations of the PaddlePaddle is `RecurrentLayerGroup`,
|
||||
which supports the variable length sequences without padding.
|
||||
This doc will design fluid's RNN based on this idea.
|
||||
|
||||
## Multi-layer sequence data format `LODTensor`
|
||||
At present, Paddle stores data in one mini-batch in one-dimensional array.
|
||||
|
||||
`Argument.sequenceStartPositions` is used to store information for each sentence.
|
||||
|
||||
In Paddle, `Argument.subSequenceStartPositions` is used to store 2 levels of sequence information, while higher dimensional sequences can not be supported.
|
||||
|
||||
In order to support the storage of `N-level` sequences, we define sequence information as the following data structure.
|
||||
|
||||
|
||||
```c++
|
||||
std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
|
||||
```
|
||||
|
||||
Or more clearly defined here
|
||||
|
||||
```c++
|
||||
typedef std::vector<int> level_t;
|
||||
std::vector<level_t> lod_start_pos;
|
||||
```
|
||||
Each `level_t` here stores a level of offset information consistent with paddle's current practice.
|
||||
|
||||
In order to transmit sequence information more transparently, we have introduced a new tensor called `LODTensor`[1].
|
||||
Its tensor-related interfaces all inherit directly from `Tensor`, but it also adds serial-related interfaces.
|
||||
Thus, when working with a `LODTensor`, ordinary `Op` is used directly as `Tensor`.
|
||||
The `Op` of the operation sequence will additionally operate the relevant interface of the `LODTensor` variable-length sequence operation.
|
||||
|
||||
The definition of `LODTensor` is as follows:
|
||||
|
||||
|
||||
```c++
|
||||
class LODTensor : public Tensor {
|
||||
public:
|
||||
size_t Levels() const { return seq_start_positions_.size(); }
|
||||
size_t Elements(int level = 0) const {
|
||||
return seq_start_positions_[level].size();
|
||||
}
|
||||
// slice of level[elem_begin: elem_end]
|
||||
// NOTE low performance in slice seq_start_positions_.
|
||||
// TODO should call Tensor's Slice.
|
||||
LODTensor LODSlice(int level, int elem_begin, int elem_end) const;
|
||||
|
||||
// slice with tensor's data shared with this.
|
||||
LODTensor LODSliceShared(int level, int elem_begin, int elem_end) const;
|
||||
|
||||
// copy other's lod_start_pos_, to share LOD info.
|
||||
// NOTE the LOD info sould not be changed.
|
||||
void ShareConstLODFrom(const LODTensor &other) {
|
||||
lod_start_pos_ = other.lod_start_pos_;
|
||||
}
|
||||
// copy other's lod_start_pos_'s content, free to mutate.
|
||||
void ShareMutableLODFrom(const LODTensor &other) {
|
||||
lod_start_pos_ = std::make_shared <
|
||||
std::vector<std::vector<int>>(other.lod_start_pos_.begin(),
|
||||
other.lod_start_pos_.end());
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
|
||||
};
|
||||
```
|
||||
Among them, `lod_start_pos_` uses `shared_ptr` to reduce the cost of storage and replication.
|
||||
`LODTensor` can be thought as an extension of `Tensor`, which is almost completely compatible with the original `Tensor`.
|
||||
|
||||
## How to support the framework
|
||||
### Replace `Tensor` with `LoDTensor`
|
||||
To implement the passing of `LODTensor`, most `Tensor` in the framework need to be replaced with `LODTensor`.
|
||||
Simple implementation, directly **replace all previous `Tensor` with `LODTensor`** , where you can directly modify the `Tensor` interface created in `pybind.cc`.
|
||||
|
||||
In addition, the user may need to perceive the existence of a sequence (such as the sequence of the visualization needs to parse the output sequence in the model), so some of the serial operation APIs also need to be exposed to the python layer.
|
||||
|
||||
### Transmit `lod_start_pos` along with the Op call chain
|
||||
`lod_start_pos` is passed along with the Op call chain
|
||||
The framework needs to support the following features to implement the transmit of `lod_start_pos`:
|
||||
|
||||
1. Implement the transfer as `shared_ptr`
|
||||
- Do not modify the contents of `lod_start_pos` as a consumer
|
||||
- Modify producer of `lod_start_pos` as producer
|
||||
- Conventions consumer only needs to copy `shared_ptr` passed over
|
||||
- producer needs to create its own independent memory to store its own independent modifications and expose `shared_ptr` to subsequent consumer
|
||||
- Since the transfer process is implemented by copying `shared_ptr`, the framework only needs to pass `lod_start_pos` once.
|
||||
|
||||
2. Op is transparent enough not to sense `lod_start_pos`
|
||||
3. Producer Op that needs to modify `lod_start_pos` can update its `lod_start_pos` data when `Run`
|
||||
|
||||
## sorted by length
|
||||
After sorting by length, the batch size from the forward time step will naturally decrement, and you can directly plug it into Net to do the batch calculation.
|
||||
|
||||
For example, the original input:
|
||||
|
||||
```
|
||||
origin:
|
||||
xxxx
|
||||
xx
|
||||
xxx
|
||||
|
||||
-> sorted:
|
||||
xxxx
|
||||
xxx
|
||||
xx
|
||||
```
|
||||
|
||||
After `SegmentInputs`, there will be 4 time steps, the input of each time step is as follows (vertical arrangement)
|
||||
|
||||
```
|
||||
0 1 2 3
|
||||
x x x x
|
||||
x x x
|
||||
x x
|
||||
```
|
||||
|
||||
In order to track the changes before and after sorting, use here
|
||||
|
||||
```c++
|
||||
struct SortedSeqItem {
|
||||
void *start{nullptr};
|
||||
void *end{nullptr};
|
||||
};
|
||||
|
||||
std::vector<SortedSeqItem> sorted_seqs;
|
||||
```
|
||||
To track the position of the sequence after sorting, and add a new interface
|
||||
|
||||
```c++
|
||||
std::vector<SortedSeqItem> SortBySeqLen(const LODTensor& tensor);
|
||||
```
|
||||
Due to the sequence of input sequences, the following existing interfaces need to be modified:
|
||||
|
||||
- InitMemories, memory needs to be rearranged according to `sorted_seqs`
|
||||
- SetmentInputs
|
||||
- ConcatOutputs
|
||||
|
||||
In addition, because `sorted_seqs` needs to be multiplexed with `RecurrentGradientOp`, it will become a new output of `RecurrentOp`.
|
||||
It is passed in as an input to `RecurrentGradientOp`.
|
||||
|
||||
## InitMemories
|
||||
Due to the sequence change, the order of the elements on the `boot_memories` batch also needs to be rearranged accordingly.
|
||||
|
||||
## SegmentInputs
|
||||
|
||||
`SegmentInputs` relies on the information of `sorted_seqs` to cut the original sequence from the horizontal to the input of each step in the sorted sequence order.
|
||||
|
||||
the transition is as follows:
|
||||
```
|
||||
origin:
|
||||
xxxx
|
||||
xx
|
||||
xxx
|
||||
|
||||
|
|
||||
|
|
||||
\ /
|
||||
!
|
||||
0 1 2 3
|
||||
x x x x
|
||||
x x x
|
||||
x x
|
||||
```
|
||||
## ConcatOutputs
|
||||
`ConcatOutputs` needs
|
||||
|
||||
- Restore the output of each time step back to the original input sequence order (to prevent the order of Infer phase from being upset)
|
||||
- Concat each sequence as a regular mini-batch representation
|
||||
|
||||
## references
|
||||
1. [Level of details](https://en.wikipedia.org/wiki/Level_of_detail)
|
After Width: | Height: | Size: 29 KiB |
@ -0,0 +1,131 @@
|
||||
# Background
|
||||
|
||||
[ONNX (Open Neural Network Exchange)](https://github.com/onnx/onnx) bridges different deep learning frameworks by providing an open source graph format for models. The models trained in other frameworks can be converted into the ONNX format to execute inference by utilizing the built-in operators in ONNX - this is called a **frontend**. With the inverse conversion (called a **backend**), different frameworks can share any models supported by ONNX in principle. Now most mainstream frameworks have joined the ONNX community, e.g. Caffe2, PyTorch, and MXNet etc. And there is a momentum driving more and more vendors to begin supporting ONNX or even choose ONNX as the only machine learning runtime in their devices.
|
||||
|
||||
Therefore, it is necessary to enable the conversion between PaddlePaddle and ONNX. This design doc is aimed at implementing a convertor, mainly for converting between **Fluid** models and ONNX (it is very likely that we may support older v2 models in the future). A complete convertor should be bidirectional - with a frontend AND a backend, but considering the importance, the we will start with the frontend i.e. Fluid models to ONNX models.
|
||||
|
||||
|
||||
# How it works
|
||||
|
||||
ONNX has a [working list of operators](https://github.com/onnx/onnx/blob/master/docs/Operators.md) which is versioned.
|
||||
|
||||
When prioritizing implementation of a frontend over a backend, choice of coverage of Fluid -> ONNX operators comes down to choices of models to be supported (see section `Supported models`). Eventually, this will allow us to reach a really-wide coverage of all operators.
|
||||
|
||||
Here are a few major considerations when it comes to converting models:
|
||||
|
||||
- **Op-level conversion**: How to map the inputs, attributes, and outputs of each Paddle operator to those of the ONNX operator. In several cases, these require transformations. For each direction (frontend vs. backend), a different conversion mapping is needed.
|
||||
- **Parameters (weights) initialization**: Setting initial parameters on different nodes.
|
||||
- **Tensor data type mapping** (Note: Some ONNX data types are not supported in Fluid)
|
||||
- **Network representation adaption**: Fluid `ProgramDesc` include nested blocks. Since ONNX is free of nesting, the `ProgramDesc` ops need to be traversed to only include ops from the global scope in the root block. The variables used as inputs and outputs should also be in this scope.
|
||||
- **Model validation**: There are two kinds of validations that are necessary:
|
||||
1. We need to ensure that the inference outputs of the ops in run inside a model are the same as those when running the ONNX converted ops through an alternative ONNX backend.
|
||||
2. Checking to see if the generated nodes on the graph are validated by the internal ONNX checkers.
|
||||
- **Versioning**: ONNX versions its op listing over versions. In fact, it has versioning on 3 different levels: ops, graphs, and ONNX models. This requires that we are conscious about versioning the convertor and updating tests and op convertor logic for each release. It also implies that we release pre-trained ONNX models upon each version release.
|
||||
|
||||
One thing that makes this conversion more feasible in Fluid's case is the use of a static IR - the `ProgramDesc` - as opposed to a dynamic graph, as created in the cases of frameworks like PyTorch.
|
||||
|
||||
|
||||
# Project structure
|
||||
|
||||
<p align="center">
|
||||
<img src="./images/project_structure.png"/>
|
||||
</p>
|
||||
|
||||
The project contains four important parts:
|
||||
|
||||
* **fluid**: The directory that contains wrappers for fluid related APIs. Fluid has provided some low-level APIs to parse or generate the inference model. However, directly using these low-level APIs makes the code tediously long. This module wraps low-level APIs to provide simplified interfaces.
|
||||
|
||||
* **onnx**: This is a Python package provided by ONNX containing helpers for creating nodes, graphs, and eventually binary protobuf models with initializer parameters.
|
||||
|
||||
* **onnx_fluid**: Contains two-way mapping (Fluid -> ONNX ops and ONNX -> Fluid ops). Called from `convert.py`, the program uses this mapping along with modifier functions to construct ONNX nodes with the help of ONNX's `make_node` helper. It also contains mapping between datatypes and tensor deprecation / amplification logic.
|
||||
|
||||
* **convert.py**: The interface exposed to users. This will traverse the global program blocks/variables and construct the write-able model.
|
||||
|
||||
|
||||
# Usage
|
||||
The converter should be designed to very easy-to-use. Bidirectional conversion between a Fluid inference model and an ONNX binary model will be supported. Model validation will also provided to verify the correctness of converted model.
|
||||
|
||||
* Convert Fluid inference model to ONNX binary model
|
||||
|
||||
```
|
||||
python convert.py --fluid_model <fluid inference model> --onnx_model <ONNX model> validate True
|
||||
```
|
||||
|
||||
* Validate the converted model
|
||||
|
||||
```
|
||||
python validate.py --fluid_model <fluid inference model> --onnx_model <ONNX model>
|
||||
```
|
||||
|
||||
The conversion and model validation will be completed consecutively, finally output a readable model structure description. And for the converse conversion, users only need to exchange the input and output.
|
||||
|
||||
|
||||
# Challenges and mitigation
|
||||
|
||||
## Cycles
|
||||
|
||||
Cycles are unsupported in ONNX. In Paddle, the `while` op is the most prominent example of a cycle.
|
||||
|
||||
*Resolution*: We won't support models with `while`s which can't be substituted until ONNX adds support for such ops.
|
||||
|
||||
## Sequences
|
||||
|
||||
Sequence processing operators like `sequence_expand`, `sequence_reshape`, `sequence_concat`, and `sequence_pool` are not supported by ONNX as well, because they do not support non-padded datatypes like LoDTensors.
|
||||
|
||||
*Resolution*: Since the runtimes using our ONNX exported graphs won't be using LoDTensors in the first place, such sequence operators should be mapped to ONNX ops that will do the necessary transposing ops with the knowledge of the padding and shape of the Tensors.
|
||||
|
||||
## Ops that can't easily be mapped
|
||||
|
||||
There are ops that just aren't possible to map today:
|
||||
|
||||
**Control flow operators**
|
||||
|
||||
Paddle supports control flow ops like `If/Else` and `Switch` (if we ignore the CSP operations like `select` for now). ONNX has `If` support in the experimental phase.
|
||||
|
||||
*Resolution*: Map Paddle's `If/Else` to ONNX's `If`, but ignore other control flow operators until ONNX brings support for them.
|
||||
|
||||
|
||||
**Non-existent in Fluid**
|
||||
|
||||
There are several ONNX operators that are not available in Fluid today, e.g. `InstanceNormalization`, `RandomUniform`, `Unsqueeze`, etc.
|
||||
|
||||
*Resolution*: For the initial phase, we can choose to not support ops that our models don't care for and are subsequently not available in Fluid. However, for ops that we think might be necessary for Fluid users also, we must implement them on our side and support the ONNX conversion to them. This list is TBD.
|
||||
|
||||
|
||||
**Concurrency**
|
||||
|
||||
ONNX does not have any considerations for concurrency right now.
|
||||
|
||||
*Resolution*: There are two ways to approach this:
|
||||
|
||||
a. We choose to not support concurrent models.
|
||||
b. We only support `go_op`s (basically threads) shallowly. This could mean that we enqueue `go_op` ops prior to gradient calculations OR even prior to the entire graph, and that's it - since `go_op`s do not have support for backprop anyways. One of the core target use cases of `go_op`: batch reading - can be handled through this approach.
|
||||
|
||||
|
||||
**Overloaded in Fluid**
|
||||
|
||||
There are ops in ONNX whose job can't be accomplished by a single corresponding Paddle operator (e.g. ), but a collection of operators.
|
||||
|
||||
*Resolution*: Chain multiple Paddle operators.
|
||||
|
||||
|
||||
## Lack of LoDTensors
|
||||
|
||||
As stated above, ONNX only supports simple Tensor values.
|
||||
|
||||
*Resolution*: Deprecate to plain old numpy-able tensors.
|
||||
|
||||
|
||||
## Reconstruction from deprecated ONNX ops
|
||||
|
||||
For higher-level Fluid ops, such as a few offered by the `nn` layer that do not have direct corresponding mappings but can be converted to ONNX by chaining a series of ops without cycles, it would be useful to map them back to the higher-level Fluid ops once converted back from the deprecated ONNX graphs.
|
||||
|
||||
*Resolution*: Graphs that have the deprecation from Paddle -> ONNX. When converting back from ONNX, if we encounter the identical graphs by doing a forward search, we can replace the subgraphs with the matching ONNX op.
|
||||
|
||||
|
||||
# Supported models
|
||||
|
||||
As mentioned above, potential risks may come from the conversion of sequence-related models, including the LodTensor, ```if/else``` and ```while``` operator. So a good choice is to focus on some important feedforward models first, then implement some simple recurrent models.
|
||||
|
||||
- Feedforward models: common models selected in PaddleBook, e.g. VGG, ResNet and some other models proposed by application teams.
|
||||
- Recurrent models: language model, stacked LSTMs etc.
|
@ -0,0 +1,74 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <condition_variable> // NOLINT
|
||||
#include <deque>
|
||||
#include <mutex> // NOLINT
|
||||
#include <utility>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
template <typename T>
|
||||
class BlockingQueue {
|
||||
public:
|
||||
void Push(const T &item) {
|
||||
{
|
||||
std::lock_guard<std::mutex> g(mutex_);
|
||||
q_.emplace_back(item);
|
||||
}
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
void Extend(const U &items) {
|
||||
{
|
||||
std::lock_guard<std::mutex> g(mutex_);
|
||||
for (auto &item : items) {
|
||||
q_.emplace_back(item);
|
||||
}
|
||||
}
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
std::deque<T> PopAll(size_t ms, bool *timeout) {
|
||||
auto time =
|
||||
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
*timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); });
|
||||
std::deque<T> ret;
|
||||
if (!*timeout) {
|
||||
std::swap(ret, q_);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
T Pop() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cv_.wait(lock, [=] { return !q_.empty(); });
|
||||
T rc(std::move(q_.front()));
|
||||
q_.pop_front();
|
||||
return rc;
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex mutex_;
|
||||
std::condition_variable cv_;
|
||||
std::deque<T> q_;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue