tonyyang-svail-feed-op-desgin
commit
874bcb3030
@ -0,0 +1,105 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/framework/op_info.h"
|
||||||
|
#include "paddle/framework/op_proto_maker.h"
|
||||||
|
#include "paddle/framework/operator.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
enum OpInfoFillType {
|
||||||
|
kOperator = 0,
|
||||||
|
kOpProtoAndCheckerMaker = 1,
|
||||||
|
kGradOpDescMaker = 2
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct OpInfoFillTypeID {
|
||||||
|
static constexpr OpInfoFillType ID() {
|
||||||
|
return std::is_base_of<OperatorBase, T>::value
|
||||||
|
? kOperator
|
||||||
|
: (std::is_base_of<OpProtoAndCheckerMaker, T>::value
|
||||||
|
? kOpProtoAndCheckerMaker
|
||||||
|
: (std::is_base_of<GradOpDescMakerBase, T>::value
|
||||||
|
? kGradOpDescMaker
|
||||||
|
: static_cast<OpInfoFillType>(-1)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
|
||||||
|
struct OpInfoFiller;
|
||||||
|
|
||||||
|
template <size_t I, bool at_end, typename... ARGS>
|
||||||
|
class OperatorRegistrarRecursive;
|
||||||
|
|
||||||
|
template <size_t I, typename... ARGS>
|
||||||
|
class OperatorRegistrarRecursive<I, false, ARGS...> {
|
||||||
|
public:
|
||||||
|
using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
|
||||||
|
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
|
||||||
|
OpInfoFiller<T> fill;
|
||||||
|
fill(op_type, info);
|
||||||
|
constexpr auto size = sizeof...(ARGS);
|
||||||
|
OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
|
||||||
|
info);
|
||||||
|
(void)(reg);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <size_t I, typename... ARGS>
|
||||||
|
class OperatorRegistrarRecursive<I, true, ARGS...> {
|
||||||
|
public:
|
||||||
|
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct OpInfoFiller<T, kOperator> {
|
||||||
|
void operator()(const char* op_type, OpInfo* info) const {
|
||||||
|
info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
|
||||||
|
const VariableNameMap& outputs,
|
||||||
|
const AttributeMap& attrs) {
|
||||||
|
return new T(type, inputs, outputs, attrs);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
|
||||||
|
void operator()(const char* op_type, OpInfo* info) const {
|
||||||
|
info->proto_ = new OpProto;
|
||||||
|
info->checker_ = new OpAttrChecker();
|
||||||
|
auto maker = T(info->proto_, info->checker_);
|
||||||
|
maker.Validate();
|
||||||
|
info->proto_->set_type(op_type);
|
||||||
|
PADDLE_ENFORCE(
|
||||||
|
info->proto_->IsInitialized(),
|
||||||
|
"Fail to initialize %s's OpProto, because %s is not initialized",
|
||||||
|
op_type, info->proto_->InitializationErrorString());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct OpInfoFiller<T, kGradOpDescMaker> {
|
||||||
|
void operator()(const char* op_type, OpInfo* info) const {
|
||||||
|
info->grad_op_maker_ = new T();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace details
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,27 @@
|
|||||||
|
type: "nn"
|
||||||
|
layers {
|
||||||
|
name: "input"
|
||||||
|
type: "data"
|
||||||
|
size: 300
|
||||||
|
active_type: ""
|
||||||
|
}
|
||||||
|
layers {
|
||||||
|
name: "__resize_0__"
|
||||||
|
type: "resize"
|
||||||
|
size: 150
|
||||||
|
active_type: ""
|
||||||
|
inputs {
|
||||||
|
input_layer_name: "input"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_layer_names: "input"
|
||||||
|
output_layer_names: "__resize_0__"
|
||||||
|
sub_models {
|
||||||
|
name: "root"
|
||||||
|
layer_names: "input"
|
||||||
|
layer_names: "__resize_0__"
|
||||||
|
input_layer_names: "input"
|
||||||
|
output_layer_names: "__resize_0__"
|
||||||
|
is_recurrent_layer_group: false
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,6 @@
|
|||||||
|
from paddle.trainer_config_helpers import *
|
||||||
|
|
||||||
|
data = data_layer(name='input', size=300)
|
||||||
|
resized = resize_layer(input=data, size=150)
|
||||||
|
|
||||||
|
outputs(resized)
|
Loading…
Reference in new issue