|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/transfer_scope_cache.h"
|
|
|
|
|
#include "paddle/fluid/inference/tests/api/tester_helper.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -228,5 +229,44 @@ TEST(Analyzer_bert, compare_determine) {
|
|
|
|
|
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
|
|
|
|
|
inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Analyzer_bert, transfer_scope_cache) {
|
|
|
|
|
AnalysisConfig config;
|
|
|
|
|
SetConfig(&config);
|
|
|
|
|
|
|
|
|
|
std::vector<PaddleTensor> input, output;
|
|
|
|
|
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
|
|
|
|
|
|
|
|
|
|
int threads_num = 10;
|
|
|
|
|
std::vector<std::thread> threads;
|
|
|
|
|
std::unordered_set<std::unordered_set<paddle::framework::Scope *> *>
|
|
|
|
|
global_transfer_scope_cache;
|
|
|
|
|
std::unordered_set<std::unordered_map<size_t, paddle::framework::Scope *> *>
|
|
|
|
|
global_transfer_data_cache;
|
|
|
|
|
|
|
|
|
|
std::ifstream fin(FLAGS_infer_data);
|
|
|
|
|
std::string line;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < threads_num; i++) {
|
|
|
|
|
threads.emplace_back([&, i]() {
|
|
|
|
|
std::getline(fin, line);
|
|
|
|
|
ParseLine(line, &input);
|
|
|
|
|
predictor->Run(input, &output, FLAGS_batch_size);
|
|
|
|
|
global_transfer_scope_cache.insert(
|
|
|
|
|
&paddle::framework::global_transfer_scope_cache());
|
|
|
|
|
global_transfer_data_cache.insert(
|
|
|
|
|
&paddle::framework::global_transfer_data_cache());
|
|
|
|
|
});
|
|
|
|
|
threads[0].join();
|
|
|
|
|
threads.clear();
|
|
|
|
|
std::vector<PaddleTensor>().swap(input);
|
|
|
|
|
}
|
|
|
|
|
// Since paddle::framework::global_transfer_scope_cache() and
|
|
|
|
|
// paddle::framework::global_transfer_data_cache() are thread_local,
|
|
|
|
|
// their pointer should be different among different thread id.
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace inference
|
|
|
|
|
} // namespace paddle
|
|
|
|
|