commit
d1f5f49826
@ -0,0 +1,99 @@
|
||||
# Design Doc: Functions, Operators, and Layers
|
||||
|
||||
In a DL system, we can compose one or more fine-grained operators into a coarse-grained one. For example, the FC layer can be composed of a multiplication operator and an add operator.
|
||||
|
||||
Historically, some fine-grained operations are known as operators, and some coarse-grained ones are known as layers. But we need a well-defined separation.
|
||||
|
||||
In general, operators are those very fine grained operations, e.g., mul and add. In the implementation, we can write them as C++ functions:
|
||||
|
||||
```c++
|
||||
template <typename T> T add(T x, T y) { return x + y; }
|
||||
template <typename T> T mul(T x, T y) { return x * y; }
|
||||
```
|
||||
|
||||
Then we can wrap them into operators which are C++ classes and can be created from Python bindings by name. A C macro can do this. For example, the following macro invocation
|
||||
|
||||
```c++
|
||||
MAKE_FUNCTION_OPERATOR(mul);
|
||||
```
|
||||
|
||||
generates
|
||||
|
||||
```c++
|
||||
template <typename T> class mulOp : public OperatorBase {...};
|
||||
REGISTER_OP(mulOp<float32>, "mul");
|
||||
```
|
||||
|
||||
so that in Python we can create operator mul by:
|
||||
|
||||
```python
|
||||
X1 = Var()
|
||||
X2 = Var()
|
||||
Y = Var()
|
||||
paddle.cpp.create_operator("mul", input=[X1, X2], output=Y)
|
||||
```
|
||||
|
||||
At the same time, we can build a coarse-grained C++ operator class by composing the functions `mul` and `add`:
|
||||
|
||||
```c++
template <typename T>
class FCOp : public OperatorBase {
 public:
  void Run(...) {
    add(mul(Input<T>("X"), Input<T>("W")), Input<T>("b"));
  }
};
REGISTER_OP(FCOp, "fc");
```
|
||||
|
||||
We need to support such composition in Python as well. To do so, we need a higher level Python wrapping of operator creation than `paddle.cpp.create_operator`. This higher level operator API should be compatible with the layer API.
|
||||
|
||||
Let's explain using an example. Suppose that we are going to compose the FC layer from mul and add in Python. We'd like to have Python functions `mul` and `add` defined in module `operator`:
|
||||
|
||||
```python
# In module `operator`:
def mul(X1, X2):
    O = Var()
    paddle.cpp.create_operator("mul", input=[X1, X2], output=O)
    return O

def add(X1, X2):
    O = Var()
    paddle.cpp.create_operator("add", input=[X1, X2], output=O)
    return O
```
|
||||
|
||||
The code snippets above are automatically generated. Given them, users can define
|
||||
|
||||
```python
# In module `layer`:
def fc(X):
    W = Var()
    b = Var()
    return operator.add(operator.mul(X, W), b)
```
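
To make the "automatically generated" step concrete, here is a minimal sketch of how such wrappers could be produced from a list of operator names. It assumes a stand-in `create_operator` that merely records what it was asked to create. `Var`, `make_wrapper`, and `created_ops` are hypothetical names for this illustration, and the local `operator` namespace deliberately mirrors the module name used in this document rather than Python's standard `operator` module.

```python
import types


class Var(object):
    """Placeholder for a graph variable (illustrative only)."""
    _count = 0

    def __init__(self):
        Var._count += 1
        self.name = "var_%d" % Var._count


# Stand-in for the C++ side: records (type, input names, output name).
created_ops = []


def create_operator(op_type, input, output):
    created_ops.append((op_type, [v.name for v in input], output.name))


def make_wrapper(op_type):
    """Generate a Python function that creates one operator of `op_type`."""
    def wrapper(*inputs):
        out = Var()
        create_operator(op_type, input=list(inputs), output=out)
        return out
    wrapper.__name__ = op_type
    return wrapper


# Generate one wrapper per operator name instead of writing them by hand.
operator = types.SimpleNamespace()
for op_type in ("mul", "add"):
    setattr(operator, op_type, make_wrapper(op_type))


def fc(X):
    W = Var()
    b = Var()
    return operator.add(operator.mul(X, W), b)


X = Var()
Y = fc(X)
for op in created_ops:
    print(op)
```

Calling `fc` creates two operators in order, a `mul` followed by an `add`, and the intermediate variable is managed by the wrappers rather than by the author of the layer.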
|
||||
|
||||
If we don't have `operator.mul` and `operator.add`, the definition of `layer.fc` would be more complicated:
|
||||
|
||||
```python
# In module `layer`:
def fc(X):
    W = Var()
    b = Var()
    O1 = Var()
    paddle.cpp.create_operator("mul", input=[X, W], output=O1)
    O2 = Var()
    paddle.cpp.create_operator("add", input=[O1, b], output=O2)
    return O2
```
|
||||
|
||||
We'd like to have Python bindings to operators in package `paddle.operator`, and Python compositions of operators in package `paddle.layer`. So we have the following concepts in the illustrative example above:
|
||||
|
||||
```
| C++ functions/functors | mul          | add          |             |          |
| C++ operator class     | mulOp        | addOp        | FCOp        |          |
| Python binding         | operator.mul | operator.add | operator.fc |          |
| Python function        |              |              |             | layer.fc |
```
|
||||
|
||||
This is how we differentiate layers and operators in PaddlePaddle:
|
||||
|
||||
- those defined in C++ with a lightweight Python wrapper in module `operator` are operators; whereas
- those that have no C++ implementation but a Python implementation that composes C++ operators are known as layers.
|
@ -0,0 +1,59 @@
|
||||
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponding to a true element in `cond`.
|
||||
|
||||
```python
import paddle as pd

x = var()
y = var()
cond = var()

b = pd.create_ifop(inputs=[x], output_num=1)
with b.true_block():
    x = b.inputs(0)
    z = operator.add(x, y)
    b.set_output(0, operator.softmax(z))

out = b(cond)
```
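
The N-in, M-out behavior described above can be illustrated with plain Python lists standing in for mini-batches. This is a sketch of the intended semantics only, not of the proposed API: the function `if_op` and its return of the selected indices are assumptions made for this example.

```python
def if_op(cond, inputs, true_block):
    """Apply true_block to the instances whose cond entry is True.

    Returns the M outputs together with the indices they came from.
    """
    if len(cond) != len(inputs):
        raise ValueError("cond and inputs must both have N instances")
    selected = [i for i, c in enumerate(cond) if c]
    outputs = [true_block(inputs[i]) for i in selected]
    return outputs, selected


cond = [True, False, True, False]   # N = 4
x = [1.0, 2.0, 3.0, 4.0]
out, idx = if_op(cond, x, lambda v: v * 10)
print(out)  # [10.0, 30.0], so M = 2
print(idx)  # [0, 2]
```

Because only the true instances are computed, the output is shorter than the input whenever `cond` contains a false element.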
|
||||
|
||||
If we want the output to still have N instances, we can use IfElseOp with a default value, whose minibatch size must be N:
|
||||
|
||||
```python
import paddle as pd

x = var()
y = var()
cond = var()
default_value = var()
b = pd.create_ifelseop(inputs=[x], output_num=1)
with b.true_block():
    x = b.inputs(0)
    z = operator.add(x, y)
    b.set_output(0, operator.softmax(z))

with b.false_block():
    x = b.inputs(0)
    z = layer.fc(x)
    b.set_output(0, operator.softmax(z))

out = b(cond)
```
|
||||
|
||||
If only true_block is set in an IfElseOp, we can provide a default value for the false case:
|
||||
```python
import paddle as pd

x = var()
y = var()
cond = var()
default_value = var()
# default_value is passed by keyword here, because Python does not allow
# a positional argument after keyword arguments.
b = pd.create_ifelseop(inputs=[x], output_num=1, default_value=default_value)

with b.true_block():
    x = b.inputs(0)
    z = operator.add(x, y)
    b.set_output(0, operator.softmax(z))

out = b(cond)
```
|
||||
where `default_value` is a list of vars used for the instances where `cond` is False.
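
The difference from IfOp is that the output keeps all N instances. A minimal sketch of that merge, again with Python lists standing in for mini-batches, could look like the following. The function name `if_else_op` and its signature are hypothetical and only illustrate the semantics described above.

```python
def if_else_op(cond, inputs, true_block, false_block=None, default_value=None):
    """Return N outputs: true_block where cond is True, otherwise either
    false_block or the matching entry of default_value."""
    n = len(cond)
    if len(inputs) != n:
        raise ValueError("cond and inputs must both have N instances")
    if false_block is None:
        if default_value is None or len(default_value) != n:
            raise ValueError("default_value must have N instances "
                             "when false_block is not set")
    out = []
    for i, c in enumerate(cond):
        if c:
            out.append(true_block(inputs[i]))
        elif false_block is not None:
            out.append(false_block(inputs[i]))
        else:
            out.append(default_value[i])
    return out


cond = [True, False, True]
x = [1, 2, 3]

with_default = if_else_op(cond, x, lambda v: v + 1, default_value=[0, 0, 0])
with_false = if_else_op(cond, x, lambda v: v + 1, false_block=lambda v: -v)
print(with_default)  # [2, 0, 4]
print(with_false)    # [2, -2, 4]
```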
|
@ -0,0 +1,122 @@
|
||||
# Design Doc: LoD (Level-of-Detail) Tensor
|
||||
|
||||
PaddlePaddle's RNN doesn't require that all instances have the same length. To support this, we introduce an extension to Tensor, namely, LoD Tensor.
|
||||
|
||||
## Challenge of Variable-length Inputs
|
||||
|
||||
People usually represent a mini-batch by a Tensor. For example, a mini-batch of 10 images, each of size 32x32, is a 10x32x32 Tensor. So a transformation, T, of all images can be a multiplication of the 10x32x32 Tensor with the 32x32xO-dimensional tensor T.
|
||||
|
||||
Another example is that each mini-batch contains 32 sentences, where each word is a D-dimensional one-hot vector. If all sentences have the same length L, we can represent this mini-batch by a 32xLxD tensor. However, in most cases, sentences have variable lengths, and we will need an index data structure to record these variable lengths.
|
||||
|
||||
## LoD as a Solution
|
||||
|
||||
### Mini-Batch of variable-length sentences
|
||||
|
||||
Let's imagine a mini-batch of 3 variable-length sentences, containing 3, 1, and 2 words respectively. We can represent it by a (3+1+2)xD tensor plus some index information:
|
||||
|
||||
```
|
||||
3
|
||||
3 1 2
|
||||
||| | ||
|
||||
```
|
||||
|
||||
Each `|` represents a D-dimensional word vector. The number 3 on top indicates 3 sentences, and the numbers 3, 1, and 2 on the second level represent the number of words in each sentence.
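
The packing described above can be sketched with nested Python lists. Here the word vectors are short lists with D = 2, the LoD is stored as nested lengths exactly as drawn in the diagram, and the helper names `pack` and `unpack` are assumptions made for this illustration.

```python
def pack(sentences):
    """Concatenate variable-length sentences into one flat list of word
    vectors plus a 2-level LoD of lengths."""
    data = [word for sentence in sentences for word in sentence]
    lod = [[len(sentences)], [len(sentence) for sentence in sentences]]
    return data, lod


def unpack(data, lod):
    """Recover the sentences from the flat data using the last LoD level."""
    sentences, start = [], 0
    for length in lod[-1]:
        sentences.append(data[start:start + length])
        start += length
    return sentences


sentences = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # 3 words
    [[0.7, 0.8]],                          # 1 word
    [[0.9, 1.0], [1.1, 1.2]],              # 2 words
]
data, lod = pack(sentences)
print(len(data))  # 6, i.e. the (3+1+2) rows of the underlying tensor
print(lod)        # [[3], [3, 1, 2]]
```

No padding is stored: the underlying data has exactly one row per word, and the lengths alone are enough to recover the sentence boundaries.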
|
||||
|
||||
### Mini-Batch of variable-length videos
|
||||
|
||||
This approach generalizes to the case where elements are not words, but higher-dimensional objects, like images. Suppose that a mini-batch contains 3 videos of the same frame size 640x480, with 3, 1, and 2 frames respectively. The underlying tensor is of size (3+1+2)x640x480. The index information is illustrated as:
|
||||
|
||||
```
|
||||
3
|
||||
3 1 2
|
||||
口口口 口 口口
|
||||
```
|
||||
|
||||
where each `口` represents an image.
|
||||
|
||||
### Mini-Batch of fixed-size images
|
||||
|
||||
Let's get back to a typical example, image classification, where each mini-batch has M fixed-sized images. The LoD Tensor representation is
|
||||
|
||||
```
|
||||
M
|
||||
1 1 1 1 1
|
||||
口口口口 ... 口
|
||||
```
|
||||
|
||||
The many 1's on the second level seem redundant. For this particular case, where there are 2 levels and every element of the second level has length 1, we can ignore the LoD index.
|
||||
|
||||
### Design and summarization
|
||||
|
||||
In summary, as long as the essential elements (words or images) have the same size, we can represent mini-batches by a LoD Tensor:
|
||||
|
||||
- The underlying tensor has size LxD1xD2x..., where D1xD2... is the size of the essential elements, and
|
||||
- the first dimension size L has an additional property -- a LoD index as a nested vector:
|
||||
|
||||
```c++
|
||||
typedef std::vector<std::vector<int> > LoD;
|
||||
```
|
||||
|
||||
- The LoD index is not necessary when there are only two levels and all elements of the second level have length 1.
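
The properties listed above imply two simple checks, sketched below with the LoD stored as nested lengths, as in the diagrams. Both function names are hypothetical, and treating an empty LoD as consistent is an assumption of this sketch.

```python
def lod_is_consistent(lod, num_rows):
    """Each level must describe exactly the entries of the level below it,
    and the last level must add up to the first dimension size L."""
    if not lod:
        return True  # assumption: no index means a plain tensor
    for upper, lower in zip(lod, lod[1:]):
        if sum(upper) != len(lower):
            return False
    return sum(lod[-1]) == num_rows


def lod_is_optional(lod):
    """True when there are only two levels and every element of the second
    level has length 1, so the index carries no information."""
    return len(lod) == 2 and all(n == 1 for n in lod[1])


print(lod_is_consistent([[3], [3, 1, 2]], 6))  # True
print(lod_is_optional([[3], [3, 1, 2]]))       # False
print(lod_is_optional([[4], [1, 1, 1, 1]]))    # True
```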
|
||||
|
||||
## Slicing of LoD Tensor
|
||||
|
||||
Consider a network with three levels of RNN: the top level one handles articles, the second level one handles sentences, and the basic level one handles words. This network requires that mini-batches be represented by a 4-level LoD Tensor, for example,
|
||||
|
||||
```
|
||||
3
|
||||
3 1 2
|
||||
3 2 4 1 2 3
|
||||
||| || |||| | || |||
|
||||
```
|
||||
|
||||
To allow each level of RNN to handle its input, we define **the slicing of a LoD Tensor as getting the j-th sequence on level i, or the <i,j>-slice**.
|
||||
|
||||
For example, the <2,1>-slice of the above example is
|
||||
|
||||
```
|
||||
2
|
||||
||
|
||||
```
|
||||
|
||||
and the <1,2>-slice of above example is
|
||||
|
||||
```
|
||||
2
|
||||
2 3
|
||||
|| |||
|
||||
```
|
||||
|
||||
Let's go on slicing this slice. Its <1,1>-slice is
|
||||
|
||||
```
|
||||
3
|
||||
|||
|
||||
```
|
||||
|
||||
### The General Slicing Algorithm
|
||||
|
||||
The algorithm, with an over-simplified data structure, is defined as
|
||||
|
||||
```c++
|
||||
typedef vector<vector<int> > LoD;
|
||||
|
||||
struct LoDTensor {
|
||||
LoD lod_;
|
||||
float* tensor_;
|
||||
};
|
||||
|
||||
LoDTensor Slice(const LoDTensor& lodt, int level, int sequence) {
|
||||
|
||||
}
|
||||
```
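
The body of `Slice` is left open above. One possible way to compute an <i,j>-slice is sketched below in Python. It assumes the LoD is stored as nested lengths, as in the diagrams, that both `level` and `sequence` are 0-based, which is the convention the examples in the previous section are consistent with, and that the tensor can be modelled as a flat list of rows. The name `slice_lod` is hypothetical.

```python
def slice_lod(lod, data, level, sequence):
    """Return (new_lod, new_data) for the <level, sequence>-slice."""
    count = lod[level][sequence]
    # Range of children on the next level covered by the chosen sequence.
    begin = sum(lod[level][:sequence])
    end = begin + count
    new_lod = [[count]]
    for lengths in lod[level + 1:]:
        new_lod.append(lengths[begin:end])
        # Map the range on this level to the range on the level below it.
        begin, end = sum(lengths[:begin]), sum(lengths[:end])
    # After the last level, [begin, end) is a range of rows in the data.
    return new_lod, data[begin:end]


lod = [[3], [3, 1, 2], [3, 2, 4, 1, 2, 3]]
data = list(range(15))  # 15 words, one row each

print(slice_lod(lod, data, 2, 1))  # ([[2]], [3, 4])
print(slice_lod(lod, data, 1, 2))  # ([[2], [2, 3]], [10, 11, 12, 13, 14])

sub_lod, sub_data = slice_lod(lod, data, 1, 2)
print(slice_lod(sub_lod, sub_data, 1, 1))  # ([[3]], [12, 13, 14])
```

The three calls reproduce the <2,1>-slice, the <1,2>-slice, and the <1,1>-slice of that slice from the examples above. A real implementation would more likely store offsets instead of lengths to avoid the repeated sums, but that is a design choice this document does not settle.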
|
||||
|
||||
### Slicing the Top Level
|
||||
|
||||
Please be aware that an RNN operator only slices the top level of a LoD Tensor to get the step inputs.
|
||||
|
||||
```c++
|
||||
LoDTensor Slice(const LoDTensor& lodt, int sequence) {
|
||||
|
||||
}
|
||||
```
|
@ -0,0 +1,244 @@
|
||||
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "Conv3DLayer.h"
|
||||
#include "paddle/utils/Logging.h"
|
||||
#include "paddle/utils/Stat.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
REGISTER_LAYER(conv3d, Conv3DLayer);
|
||||
|
||||
bool Conv3DLayer::init(const LayerMap &layerMap,
|
||||
const ParameterMap &parameterMap) {
|
||||
if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
|
||||
int index = 0;
|
||||
for (auto &inputConfig : config_.inputs()) {
|
||||
const ConvConfig &conf = inputConfig.conv_conf();
|
||||
M_.push_back(numFilters_ / conf.groups());
|
||||
K_.push_back(filterPixels_[index] * filterChannels_[index]);
|
||||
|
||||
// create a new weight
|
||||
size_t height, width;
|
||||
width = filterPixels_[index] * filterChannels_[index];
|
||||
height = numFilters_;
|
||||
CHECK_EQ(parameters_[index]->getSize(), width * height);
|
||||
Weight *w = new Weight(height, width, parameters_[index]);
|
||||
weights_.emplace_back(w);
|
||||
++index;
|
||||
}
|
||||
if (biasParameter_.get()) {
|
||||
if (sharedBiases_) {
|
||||
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
|
||||
biases_ =
|
||||
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
|
||||
} else {
|
||||
biases_ =
|
||||
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t Conv3DLayer::getSize() {
|
||||
CHECK_NE(inputLayers_.size(), 0UL);
|
||||
outputH_.clear();
|
||||
outputW_.clear();
|
||||
outputD_.clear();
|
||||
N_.clear();
|
||||
size_t layerSize = 0;
|
||||
for (size_t i = 0; i < inputLayers_.size(); ++i) {
|
||||
outputW_.push_back(outputSize(
|
||||
imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
|
||||
outputH_.push_back(outputSize(
|
||||
imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
|
||||
outputD_.push_back(outputSize(
|
||||
imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
|
||||
|
||||
N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
|
||||
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
|
||||
layerSize += N_[i] * numFilters_;
|
||||
}
|
||||
getOutput().setFrameHeight(outputH_[0]);
|
||||
getOutput().setFrameWidth(outputW_[0]);
|
||||
getOutput().setFrameDepth(outputD_[0]);
|
||||
return layerSize;
|
||||
}
|
||||
|
||||
void Conv3DLayer::forward(PassType passType) {
|
||||
Layer::forward(passType);
|
||||
|
||||
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
|
||||
int outWidth = getSize();
|
||||
resetOutput(batchSize, outWidth);
|
||||
|
||||
for (size_t i = 0; i != inputLayers_.size(); ++i) {
|
||||
REGISTER_TIMER_INFO("FwdConv3D", getName().c_str());
|
||||
const MatrixPtr &inMat = getInputValue(i);
|
||||
const MatrixPtr &outMat = getOutputValue();
|
||||
int M = M_[i];
|
||||
int N = N_[i];
|
||||
int K = K_[i];
|
||||
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
|
||||
MatrixPtr wMat = weights_[i]->getW();
|
||||
for (int n = 0; n < batchSize; ++n) {
|
||||
colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
|
||||
channels_[i],
|
||||
imgSizeD_[i],
|
||||
imgSizeH_[i],
|
||||
imgSizeW_[i],
|
||||
filterSizeZ_[i],
|
||||
filterSizeY_[i],
|
||||
filterSize_[i],
|
||||
strideZ_[i],
|
||||
strideY_[i],
|
||||
stride_[i],
|
||||
paddingZ_[i],
|
||||
paddingY_[i],
|
||||
padding_[i]);
|
||||
|
||||
real *outData = outMat->getData() + n * outMat->getStride();
|
||||
MatrixPtr outMatSub =
|
||||
Matrix::create(outData, groups_[i] * M, N, false, useGpu_);
|
||||
for (int g = 0; g < groups_[i]; g++) {
|
||||
MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
|
||||
MatrixPtr in = colBuf_->subMatrix(g * K, K);
|
||||
MatrixPtr out = outMatSub->subMatrix(g * M, M);
|
||||
out->mul(*wMatSub, *in, 1.0, 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nullptr != this->biasParameter_) {
|
||||
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
|
||||
this->addBias();
|
||||
}
|
||||
forwardActivation();
|
||||
}
|
||||
|
||||
void Conv3DLayer::backward(const UpdateCallback &callback) {
|
||||
backwardActivation();
|
||||
|
||||
if (biases_ && biases_->getWGrad()) {
|
||||
bpropBiases();
|
||||
biases_->getParameterPtr()->incUpdate(callback);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i != inputLayers_.size(); ++i) {
|
||||
REGISTER_TIMER_INFO("BwdConv3D", getName().c_str());
|
||||
if (weights_[i]->getWGrad()) {
|
||||
bpropWeights(i);
|
||||
}
|
||||
if (getInputGrad(i)) {
|
||||
bpropData(i);
|
||||
}
|
||||
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
|
||||
weights_[i]->getParameterPtr()->incUpdate(callback);
|
||||
}
|
||||
}
|
||||
|
||||
void Conv3DLayer::bpropWeights(int i) {
|
||||
int M = M_[i];
|
||||
int N = N_[i];
|
||||
int K = K_[i];
|
||||
const MatrixPtr &inMat = getInputValue(i);
|
||||
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
|
||||
MatrixPtr wGradMat = weights_[i]->getWGrad();
|
||||
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
|
||||
for (int n = 0; n < batchSize; ++n) {
|
||||
colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
|
||||
channels_[i],
|
||||
imgSizeD_[i],
|
||||
imgSizeH_[i],
|
||||
imgSizeW_[i],
|
||||
filterSizeZ_[i],
|
||||
filterSizeY_[i],
|
||||
filterSize_[i],
|
||||
strideZ_[i],
|
||||
strideY_[i],
|
||||
stride_[i],
|
||||
paddingZ_[i],
|
||||
paddingY_[i],
|
||||
padding_[i]);
|
||||
|
||||
real *outGradData =
|
||||
getOutputGrad()->getData() + n * getOutputGrad()->getStride();
|
||||
MatrixPtr outGradSub =
|
||||
Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_);
|
||||
for (int g = 0; g < groups_[i]; ++g) {
|
||||
MatrixPtr inMatSub = colBuf_->subMatrix(g * K, K);
|
||||
MatrixPtr outG = outGradSub->subMatrix(g * M, M);
|
||||
MatrixPtr wGradSub = wGradMat->subMatrix(g * M, M);
|
||||
wGradSub->mul(*outG, *(inMatSub->getTranspose()), 1.0, 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Conv3DLayer::bpropData(int i) {
|
||||
int M = M_[i];
|
||||
int N = N_[i];
|
||||
int K = K_[i];
|
||||
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
|
||||
MatrixPtr wMat = weights_[i]->getW();
|
||||
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
|
||||
for (int n = 0; n < batchSize; ++n) {
|
||||
real *outGradData =
|
||||
getOutputGrad()->getData() + n * getOutputGrad()->getStride();
|
||||
real *preGradData =
|
||||
getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
|
||||
MatrixPtr outGradSub =
|
||||
Matrix::create(outGradData, M * groups_[i], N, false, useGpu_);
|
||||
for (int g = 0; g < groups_[i]; ++g) {
|
||||
MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
|
||||
MatrixPtr outG = outGradSub->subMatrix(g * M, M);
|
||||
MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K);
|
||||
inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0);
|
||||
}
|
||||
colBuf_->col2Vol(preGradData,
|
||||
channels_[i],
|
||||
imgSizeD_[i],
|
||||
imgSizeH_[i],
|
||||
imgSizeW_[i],
|
||||
filterSizeZ_[i],
|
||||
filterSizeY_[i],
|
||||
filterSize_[i],
|
||||
strideZ_[i],
|
||||
strideY_[i],
|
||||
stride_[i],
|
||||
paddingZ_[i],
|
||||
paddingY_[i],
|
||||
padding_[i],
|
||||
1.0,
|
||||
1.0);
|
||||
}
|
||||
}
|
||||
|
||||
void Conv3DLayer::bpropBiases() {
|
||||
MatrixPtr outGradMat = getOutputGrad();
|
||||
if (this->sharedBiases_) {
|
||||
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
|
||||
} else {
|
||||
biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
void Conv3DLayer::addBias() {
|
||||
MatrixPtr outMat = getOutputValue();
|
||||
if (this->sharedBiases_) {
|
||||
outMat->addSharedBias(*(biases_->getW()), 1.0f);
|
||||
} else {
|
||||
outMat->addBias(*(biases_->getW()), 1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include "ConvBaseLayer.h"
|
||||
#include "paddle/math/MathUtils.h"
|
||||
#include "paddle/math/Matrix.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* @brief A subclass of convolution layer.
|
||||
* This layer expands input and use matrix multiplication to
|
||||
* calculate convolution operation.
|
||||
*/
|
||||
class Conv3DLayer : public ConvBaseLayer {
|
||||
public:
|
||||
explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
|
||||
~Conv3DLayer() {}
|
||||
|
||||
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
|
||||
|
||||
void forward(PassType passType);
|
||||
void addBias();
|
||||
void backward(const UpdateCallback& callback);
|
||||
void bpropBiases();
|
||||
void bpropData(int i);
|
||||
void bpropWeights(int i);
|
||||
size_t getSize();
|
||||
|
||||
protected:
|
||||
// Figure out the dimensions for individual gemms.
|
||||
IntV M_;  /// numFilters_ / groups
IntV N_;  /// outputD_ * outputH_ * outputW_
IntV K_;  /// filterChannels_ * filterSizeZ_ * filterSizeY_ * filterSize_
|
||||
MatrixPtr colBuf_;
|
||||
};
|
||||
|
||||
} // namespace paddle
|