Initialize the benchmark tester for operator. (#15772)
* Initialize the benchmark tester for operator. test=develop * Rearrange the codes. test=developrevert-15774-anakin_subgraph_engine
parent
676995c86c
commit
7d96c74ab2
@ -0,0 +1,3 @@
|
||||
cc_test(op_tester SRCS op_tester.cc op_tester_config.cc
|
||||
DEPS memory timer framework_proto proto_desc lod_tensor op_registry
|
||||
device_context scope ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,69 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/op_desc.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/operators/benchmark/op_tester_config.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace benchmark {
|
||||
|
||||
class OpTester {
|
||||
public:
|
||||
OpTester() {}
|
||||
|
||||
void Init(const std::string &filename);
|
||||
void Init(const OpTesterConfig &config);
|
||||
|
||||
void Run();
|
||||
|
||||
std::string DebugString();
|
||||
|
||||
private:
|
||||
std::vector<std::string> GetOpProtoInputNames();
|
||||
std::vector<std::string> GetOpProtoOutputNames();
|
||||
|
||||
void CreateInputVarDesc();
|
||||
void CreateOutputVarDesc();
|
||||
|
||||
framework::VarDesc *Var(const std::string &name);
|
||||
void CreateVariables(framework::Scope *scope);
|
||||
|
||||
template <typename T>
|
||||
void SetupTensor(framework::LoDTensor *input,
|
||||
const std::vector<int64_t> &shape, T lower, T upper);
|
||||
|
||||
void RunImpl();
|
||||
|
||||
private:
|
||||
OpTesterConfig config_;
|
||||
std::string type_;
|
||||
framework::OpDesc op_desc_;
|
||||
std::unordered_map<std::string, std::unique_ptr<framework::VarDesc>> vars_;
|
||||
std::vector<std::string> inputs_;
|
||||
std::vector<std::string> outputs_;
|
||||
std::unique_ptr<framework::OperatorBase> op_;
|
||||
platform::Place place_;
|
||||
std::unique_ptr<framework::Scope> scope_;
|
||||
};
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,114 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/benchmark/op_tester_config.h"
|
||||
#include <fstream>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace benchmark {
|
||||
|
||||
static const char kStartSeparator[] = "{";
|
||||
static const char kEndSeparator[] = "}";
|
||||
static const char kSepBetweenItems[] = ";";
|
||||
|
||||
static bool StartWith(const std::string& str, const std::string& substr) {
|
||||
return str.find(substr) == 0;
|
||||
}
|
||||
|
||||
static bool EndWith(const std::string& str, const std::string& substr) {
|
||||
return str.rfind(substr) == (str.length() - substr.length());
|
||||
}
|
||||
|
||||
static void EraseEndSep(std::string* str) {
|
||||
std::string substr = kSepBetweenItems;
|
||||
if (EndWith(*str, substr)) {
|
||||
str->erase(str->length() - substr.length(), str->length());
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int64_t> ParseDims(std::string dims_str) {
|
||||
std::vector<int64_t> dims;
|
||||
std::string token;
|
||||
std::istringstream token_stream(dims_str);
|
||||
while (std::getline(token_stream, token, 'x')) {
|
||||
dims.push_back(std::stoi(token));
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
|
||||
OpInputConfig::OpInputConfig(std::istream& is) {
|
||||
std::string sep;
|
||||
is >> sep;
|
||||
if (sep == kStartSeparator) {
|
||||
while (sep != kEndSeparator) {
|
||||
is >> sep;
|
||||
if (sep == "name" || sep == "name:") {
|
||||
is >> name;
|
||||
EraseEndSep(&name);
|
||||
} else if (sep == "dims" || sep == "dims:") {
|
||||
std::string dims_str;
|
||||
is >> dims_str;
|
||||
dims = ParseDims(dims_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OpTesterConfig::OpTesterConfig(const std::string& filename) {
|
||||
std::ifstream fin(filename, std::ios::in | std::ios::binary);
|
||||
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s",
|
||||
filename.c_str());
|
||||
|
||||
Init(fin);
|
||||
}
|
||||
|
||||
void OpTesterConfig::Init(std::istream& is) {
|
||||
std::string sep;
|
||||
is >> sep;
|
||||
if (sep == kStartSeparator) {
|
||||
while (sep != kEndSeparator) {
|
||||
is >> sep;
|
||||
if (sep == "op_type" || sep == "op_type:") {
|
||||
is >> op_type;
|
||||
} else if (sep == "device_id" || sep == "device_id:") {
|
||||
is >> device_id;
|
||||
} else if (sep == "repeat" || sep == "repeat:") {
|
||||
is >> repeat;
|
||||
} else if (sep == "profile" || sep == "profile:") {
|
||||
is >> profile;
|
||||
} else if (sep == "print_debug_string" || sep == "print_debug_string:") {
|
||||
is >> print_debug_string;
|
||||
} else if (sep == "input" || sep == "input:") {
|
||||
OpInputConfig input_config(is);
|
||||
inputs.push_back(input_config);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const OpInputConfig* OpTesterConfig::GetInput(const std::string& name) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (inputs[i].name == name) {
|
||||
return &inputs[i];
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <istream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace benchmark {
|
||||
|
||||
struct OpInputConfig {
|
||||
OpInputConfig() {}
|
||||
explicit OpInputConfig(std::istream& is);
|
||||
|
||||
std::string name;
|
||||
std::vector<int64_t> dims;
|
||||
};
|
||||
|
||||
struct OpTesterConfig {
|
||||
OpTesterConfig() {}
|
||||
explicit OpTesterConfig(const std::string& filename);
|
||||
void Init(std::istream& is);
|
||||
|
||||
const OpInputConfig* GetInput(const std::string& name);
|
||||
|
||||
std::string op_type;
|
||||
std::vector<OpInputConfig> inputs;
|
||||
int device_id{-1}; // CPU: -1
|
||||
int repeat{1};
|
||||
int profile{0};
|
||||
int print_debug_string{0};
|
||||
double runtime{0.0};
|
||||
};
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue