tonyyang-svail-feed-op-desgin
commit
874bcb3030
@ -0,0 +1,105 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/op_info.h"
|
||||
#include "paddle/framework/op_proto_maker.h"
|
||||
#include "paddle/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
enum OpInfoFillType {
|
||||
kOperator = 0,
|
||||
kOpProtoAndCheckerMaker = 1,
|
||||
kGradOpDescMaker = 2
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OpInfoFillTypeID {
|
||||
static constexpr OpInfoFillType ID() {
|
||||
return std::is_base_of<OperatorBase, T>::value
|
||||
? kOperator
|
||||
: (std::is_base_of<OpProtoAndCheckerMaker, T>::value
|
||||
? kOpProtoAndCheckerMaker
|
||||
: (std::is_base_of<GradOpDescMakerBase, T>::value
|
||||
? kGradOpDescMaker
|
||||
: static_cast<OpInfoFillType>(-1)));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
|
||||
struct OpInfoFiller;
|
||||
|
||||
template <size_t I, bool at_end, typename... ARGS>
|
||||
class OperatorRegistrarRecursive;
|
||||
|
||||
template <size_t I, typename... ARGS>
|
||||
class OperatorRegistrarRecursive<I, false, ARGS...> {
|
||||
public:
|
||||
using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
|
||||
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
|
||||
OpInfoFiller<T> fill;
|
||||
fill(op_type, info);
|
||||
constexpr auto size = sizeof...(ARGS);
|
||||
OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
|
||||
info);
|
||||
(void)(reg);
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t I, typename... ARGS>
|
||||
class OperatorRegistrarRecursive<I, true, ARGS...> {
|
||||
public:
|
||||
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OpInfoFiller<T, kOperator> {
|
||||
void operator()(const char* op_type, OpInfo* info) const {
|
||||
info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
|
||||
const VariableNameMap& outputs,
|
||||
const AttributeMap& attrs) {
|
||||
return new T(type, inputs, outputs, attrs);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
|
||||
void operator()(const char* op_type, OpInfo* info) const {
|
||||
info->proto_ = new OpProto;
|
||||
info->checker_ = new OpAttrChecker();
|
||||
auto maker = T(info->proto_, info->checker_);
|
||||
maker.Validate();
|
||||
info->proto_->set_type(op_type);
|
||||
PADDLE_ENFORCE(
|
||||
info->proto_->IsInitialized(),
|
||||
"Fail to initialize %s's OpProto, because %s is not initialized",
|
||||
op_type, info->proto_->InitializationErrorString());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OpInfoFiller<T, kGradOpDescMaker> {
|
||||
void operator()(const char* op_type, OpInfo* info) const {
|
||||
info->grad_op_maker_ = new T();
|
||||
}
|
||||
};
|
||||
} // namespace details
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,27 @@
|
||||
type: "nn"
|
||||
layers {
|
||||
name: "input"
|
||||
type: "data"
|
||||
size: 300
|
||||
active_type: ""
|
||||
}
|
||||
layers {
|
||||
name: "__resize_0__"
|
||||
type: "resize"
|
||||
size: 150
|
||||
active_type: ""
|
||||
inputs {
|
||||
input_layer_name: "input"
|
||||
}
|
||||
}
|
||||
input_layer_names: "input"
|
||||
output_layer_names: "__resize_0__"
|
||||
sub_models {
|
||||
name: "root"
|
||||
layer_names: "input"
|
||||
layer_names: "__resize_0__"
|
||||
input_layer_names: "input"
|
||||
output_layer_names: "__resize_0__"
|
||||
is_recurrent_layer_group: false
|
||||
}
|
||||
|
@ -0,0 +1,6 @@
|
||||
from paddle.trainer_config_helpers import *
|
||||
|
||||
data = data_layer(name='input', size=300)
|
||||
resized = resize_layer(input=data, size=150)
|
||||
|
||||
outputs(resized)
|
Loading…
Reference in new issue