revert-4814-Add_sequence_project_op
Yu Yang 7 years ago
parent 578a357b61
commit ff8766e910

@ -378,6 +378,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/);
std::cerr << grad_fc.DebugString() << std::endl;
EXPECT_EQ(grad_fc.Outputs(all).size(),
2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add

@ -85,6 +85,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info->proto_ = new OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
std::cerr << "Assign Maker " << op_type << std::endl;
maker.Validate();
info->proto_->set_type(op_type);
PADDLE_ENFORCE(

@ -98,7 +98,7 @@ class OpDescBind {
std::vector<typename MapType::key_type> ret_val;
ret_val.reserve(map.size());
std::transform(
map.begin(), map.end(), ret_val.begin(),
map.begin(), map.end(), std::back_inserter(ret_val),
[](const typename MapType::value_type &pair) { return pair.first; });
return ret_val;
}

@ -42,19 +42,11 @@ struct OpInfo {
return *proto_;
}
const OpAttrChecker& Checker() const {
PADDLE_ENFORCE_NOT_NULL(checker_,
"Operator Checker has not been registered");
return *checker_;
}
const OpCreator& Creator() const {
PADDLE_ENFORCE_NOT_NULL(creator_,
"Operator Creator has not been registered");
return creator_;
}
bool HasGradientOp() const { return !grad_op_type_.empty(); }
};
class OpInfoMap {

@ -44,11 +44,6 @@ class OpProtoAndCheckerMaker {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);

@ -23,7 +23,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs) {
auto& info = OpInfoMap::Instance().Get(type);
info.Checker().Check(attrs);
if (info.checker_ != nullptr) {
info.checker_->Check(attrs);
}
auto op = info.Creator()(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op);
}

Loading…
Cancel
Save