You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
2.9 KiB
104 lines
2.9 KiB
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
|
|
#pragma once
|
|
|
|
#include "paddle/parameter/Parameter.h"
|
|
#include "ModelConfig.pb.h"
|
|
#include "Layer.h"
|
|
|
|
namespace paddle {
|
|
|
|
// Macro for registering a projection type.
// Example: REGISTER_PROJECTION(fc, FullMatrixProjection);
//
// Expands to a `static InitFunction` object constructed with a lambda that
// registers `class_name` in Projection::registrar_ under the string
// "type_name" (the stringized first argument). The macro parameters and the
// generated variable deliberately avoid double-underscore names, which C++
// reserves for the implementation.
#define REGISTER_PROJECTION(type_name, class_name)                \
  static InitFunction reg_projection_##type_name([]() {           \
    Projection::registrar_.registerClass<class_name>(#type_name); \
  })
|
|
|
|
/**
|
|
* A projection takes one Argument as input, calculate the result and add it
|
|
* to output Argument.
|
|
*/
|
|
class Projection {
|
|
public:
|
|
static Projection* create(const ProjectionConfig& config,
|
|
ParameterPtr parameter, bool useGpu);
|
|
|
|
Projection(const ProjectionConfig& config, ParameterPtr parameter,
|
|
bool useGpu)
|
|
: config_(config), parameter_(parameter), useGpu_(useGpu) {}
|
|
|
|
virtual ~Projection() {}
|
|
|
|
const std::string& getName() const { return config_.name(); }
|
|
|
|
/// Register a projection
|
|
static ClassRegistrar<Projection, ProjectionConfig, ParameterPtr, bool>
|
|
registrar_;
|
|
|
|
/**
|
|
* Forward propagation. If backward() will be called, in and out must be kept valid until then.
|
|
* @param in input of projection
|
|
* @param out output of projection
|
|
* @param passType PASS_TRAIN of PASS_TEST
|
|
*/
|
|
void forward(const Argument* in, const Argument* out, PassType passType) {
|
|
in_ = in;
|
|
out_ = out;
|
|
passType_ = passType;
|
|
forward();
|
|
}
|
|
|
|
virtual void prefetch(const Argument* in) {}
|
|
virtual void forward() = 0;
|
|
virtual void backward(const UpdateCallback& callback) = 0;
|
|
|
|
/**
|
|
* See comment in Layer.h for the function with the same name.
|
|
*/
|
|
virtual void resetState() {}
|
|
|
|
/**
|
|
* Set layer state.
|
|
*/
|
|
virtual void setState(LayerStatePtr state) {}
|
|
|
|
/**
|
|
* Get layer state. A copy of internal state is returned.
|
|
*/
|
|
virtual LayerStatePtr getState() { return nullptr; }
|
|
|
|
/**
|
|
* Get output size of projection.
|
|
*/
|
|
size_t getOutputSize() const { return config_.output_size(); }
|
|
|
|
protected:
|
|
/// Config of projection
|
|
ProjectionConfig config_;
|
|
/// Parameter of projection
|
|
ParameterPtr parameter_;
|
|
bool useGpu_;
|
|
|
|
/// Store `in` passed to forward()
|
|
const Argument* in_;
|
|
/// Store `out` passed to forward()
|
|
const Argument* out_;
|
|
/// Store `passType` passed to forward()
|
|
PassType passType_;
|
|
};
|
|
} // namespace paddle
|