anakin subgraph engine (#15774)
* add anakin subgraph engine * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * add initial op converter * update * update * fix op register compile error * update test=develop * updaterevert-16045-imperative_remove_desc
parent
212242c4e4
commit
afc3fcd509
@ -0,0 +1,4 @@
|
||||
cc_library(anakin_engine SRCS engine.cc)
|
||||
target_link_libraries(anakin_engine anakin anakin_saber_common)
|
||||
cc_test(test_anakin_engine SRCS test_anakin_engine.cc DEPS anakin_engine)
|
||||
add_subdirectory(convert)
|
@ -0,0 +1,2 @@
|
||||
cc_library(anakin_op_converter SRCS fc.cc registrar.cc DEPS anakin_engine framework_proto scope)
|
||||
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op)
|
@ -0,0 +1,39 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/anakin/convert/fc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
void FcOpConverter::operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope, bool test_mode) {
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
|
||||
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);
|
||||
PADDLE_ENFORCE_EQ(op_desc.Input("Out").size(), 1);
|
||||
|
||||
auto x_name = op_desc.Input("X").front();
|
||||
PADDLE_ENFORCE(x_name.size() > 0);
|
||||
auto *y_v = scope.FindVar(op_desc.Input("Y").front());
|
||||
PADDLE_ENFORCE_NOT_NULL(y_v);
|
||||
auto *y_t = y_v->GetMutable<framework::LoDTensor>();
|
||||
|
||||
auto shape = framework::vectorize2int(y_t->dims());
|
||||
}
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
class FcOpConverter : public AnakinOpConverter {
|
||||
public:
|
||||
FcOpConverter() = default;
|
||||
|
||||
virtual void operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope,
|
||||
bool test_mode) override;
|
||||
virtual ~FcOpConverter() {}
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
static Registrar<FcOpConverter> register_fc_op_converter("fc");
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,112 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "framework/core/types.h"
|
||||
#include "paddle/fluid/framework/block_desc.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/inference/anakin/convert/registrar.h"
|
||||
#include "paddle/fluid/inference/anakin/engine.h"
|
||||
#include "paddle/fluid/inference/utils/singleton.h"
|
||||
#include "saber/saber_types.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
using AnakinNvEngine =
|
||||
AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>;
|
||||
|
||||
class AnakinOpConverter {
|
||||
public:
|
||||
AnakinOpConverter() = default;
|
||||
|
||||
virtual void operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope, bool test_mode) {}
|
||||
void ConvertOp(const framework::proto::OpDesc &op,
|
||||
const std::unordered_set<std::string> ¶meters,
|
||||
const framework::Scope &scope, AnakinNvEngine *engine,
|
||||
bool test_mode = false) {
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
std::string op_type = op_desc.Type();
|
||||
std::shared_ptr<AnakinOpConverter> it{nullptr};
|
||||
|
||||
if (op_type == "mul") {
|
||||
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
|
||||
std::string Y = op_desc.Input("Y")[0];
|
||||
std::cout << Y << parameters.count(Y) << std::endl;
|
||||
if (parameters.count(Y)) {
|
||||
it = OpRegister::instance()->Get("fc");
|
||||
}
|
||||
}
|
||||
|
||||
if (!it) {
|
||||
it = OpRegister::instance()->Get(op_type);
|
||||
}
|
||||
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_type);
|
||||
it->SetEngine(engine);
|
||||
(*it)(op, scope, test_mode);
|
||||
}
|
||||
|
||||
void ConvertBlock(const framework::proto::BlockDesc &block,
|
||||
const std::unordered_set<std::string> ¶meters,
|
||||
const framework::Scope &scope, AnakinNvEngine *engine) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
for (auto i = 0; i < block.ops_size(); i++) {
|
||||
auto &op = block.ops(i);
|
||||
ConvertOp(op, parameters, scope, engine);
|
||||
}
|
||||
}
|
||||
void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
|
||||
virtual ~AnakinOpConverter() {}
|
||||
|
||||
protected:
|
||||
bool test_mode_;
|
||||
AnakinNvEngine *engine_{nullptr};
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, AnakinOpConverter *> converters_;
|
||||
framework::Scope *scope_{nullptr};
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__) \
|
||||
struct anakin_##op_type__##_converter \
|
||||
: public ::paddle::framework::Registrar { \
|
||||
anakin_##op_type__##_converter() { \
|
||||
::paddle::inference:: \
|
||||
Registry<paddle::inference::anakin::AnakinOpConverter>::Register< \
|
||||
::paddle::inference::anakin::Converter__>(#op_type__); \
|
||||
} \
|
||||
}; \
|
||||
anakin_##op_type__##_converter anakin_##op_type__##_converter__; \
|
||||
int TouchConverterRegister_anakin_##op_type__() { \
|
||||
anakin_##op_type__##_converter__.Touch(); \
|
||||
return 0; \
|
||||
}
|
||||
|
||||
#define USE_ANAKIN_CONVERTER(op_type__) \
|
||||
extern int TouchConverterRegister_anakin_##op_type__(); \
|
||||
static int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \
|
||||
TouchConverterRegister_anakin_##op_type__();
|
@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/anakin/convert/registrar.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
std::shared_ptr<AnakinOpConverter> OpRegister::Get(const std::string &name) {
|
||||
auto it = registry_.find(name);
|
||||
if (it == registry_.end()) return nullptr;
|
||||
return it->second();
|
||||
}
|
||||
|
||||
OpRegister *OpRegister::instance() {
|
||||
static OpRegister factory;
|
||||
return &factory;
|
||||
}
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
class AnakinOpConverter;
|
||||
|
||||
class OpRegister {
|
||||
public:
|
||||
OpRegister() = default;
|
||||
std::shared_ptr<AnakinOpConverter> Get(const std::string &name);
|
||||
static OpRegister *instance();
|
||||
void OpRegisterFn(const std::string &name,
|
||||
std::function<std::shared_ptr<AnakinOpConverter>()> fn) {
|
||||
registry_[name] = fn;
|
||||
}
|
||||
|
||||
private:
|
||||
using RegisterFnType = std::function<std::shared_ptr<AnakinOpConverter>()>;
|
||||
std::map<std::string, std::function<std::shared_ptr<AnakinOpConverter>()>>
|
||||
registry_;
|
||||
};
|
||||
|
||||
template <typename T, typename... Args>
|
||||
class Registrar {
|
||||
public:
|
||||
Registrar(const std::string &name, Args... args) {
|
||||
std::shared_ptr<AnakinOpConverter> converter =
|
||||
std::make_shared<T>(std::move(args)...);
|
||||
OpRegister::instance()->OpRegisterFn(name,
|
||||
[converter]() { return converter; });
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,52 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/inference/anakin/convert/fc.h"
|
||||
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
|
||||
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
TEST(fc_op, test) {
|
||||
auto it = OpRegister::instance()->Get("fc");
|
||||
ASSERT_TRUE(it != nullptr);
|
||||
|
||||
std::unordered_set<std::string> parameters({"mul_y"});
|
||||
framework::Scope scope;
|
||||
AnakinConvertValidation validator(parameters, scope);
|
||||
validator.DeclInputVar("mul_x", {1, 1, 1, 1});
|
||||
validator.DeclParamVar("mul_y", {1, 1, 1, 2});
|
||||
validator.DeclOutputVar("mul_out", {1, 1, 1, 2});
|
||||
|
||||
// Prepare Op description
|
||||
framework::OpDesc desc;
|
||||
desc.SetType("mul");
|
||||
desc.SetInput("X", {"mul_x"});
|
||||
desc.SetInput("Y", {"mul_y"});
|
||||
desc.SetOutput("Out", {"mul_out"});
|
||||
int num_flatten_dims = 3;
|
||||
desc.SetAttr("x_num_col_dims", num_flatten_dims);
|
||||
validator.SetOp(*desc.Proto());
|
||||
|
||||
validator.Execute(10);
|
||||
}
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
USE_OP(mul);
|
@ -0,0 +1,169 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/inference/anakin/engine.h"
|
||||
#include "paddle/fluid/inference/analysis/helper.h"
|
||||
#include "paddle/fluid/inference/utils/singleton.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
using anakin::graph::GraphGlobalMem;
|
||||
using anakin::AK_FLOAT;
|
||||
using anakin::Precision;
|
||||
using anakin::saber::NV;
|
||||
using anakin::saber::X86;
|
||||
using anakin::saber::Shape;
|
||||
using anakin::PBlock;
|
||||
using anakin::PTuple;
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
/*
|
||||
* Get a random float value between [low, high]
|
||||
*/
|
||||
float random(float low, float high) {
|
||||
static std::random_device rd;
|
||||
static std::mt19937 mt(rd());
|
||||
std::uniform_real_distribution<double> dist(low, high);
|
||||
return dist(mt);
|
||||
}
|
||||
|
||||
void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
|
||||
const platform::DeviceContext& ctx) {
|
||||
auto dims = tensor->dims();
|
||||
size_t num_elements = analysis::AccuDims(dims, dims.size());
|
||||
PADDLE_ENFORCE_GT(num_elements, 0);
|
||||
|
||||
platform::CPUPlace cpu_place;
|
||||
framework::LoDTensor temp_tensor;
|
||||
temp_tensor.Resize(dims);
|
||||
auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
|
||||
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
*(temp_data + i) = random(0., 1.);
|
||||
}
|
||||
|
||||
TensorCopySync(temp_tensor, place, tensor);
|
||||
}
|
||||
|
||||
/*
|
||||
* Help to validate the correctness between Fluid Op and the corresponding
|
||||
* anakin
|
||||
* layer.
|
||||
*/
|
||||
class AnakinConvertValidation {
|
||||
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
|
||||
|
||||
public:
|
||||
AnakinConvertValidation() = delete;
|
||||
|
||||
AnakinConvertValidation(const std::unordered_set<std::string>& parameters,
|
||||
const framework::Scope& scope)
|
||||
: parameters_(parameters), scope_(scope), place_(0) {
|
||||
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
|
||||
engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
|
||||
}
|
||||
|
||||
// Declare a Variable as input with random initialization.
|
||||
void DeclInputVar(const std::string& name,
|
||||
const std::vector<int> tensor_dims) {
|
||||
DeclVar(name, tensor_dims);
|
||||
// should decalre anakin input here.
|
||||
}
|
||||
|
||||
void DeclParamVar(const std::string& name, const std::vector<int> dim_vec) {
|
||||
DeclVar(name, dim_vec);
|
||||
}
|
||||
|
||||
void DeclOutputVar(const std::string& name, const std::vector<int> dim_vec) {
|
||||
DeclVar(name, dim_vec);
|
||||
// should declare anakin output here.
|
||||
}
|
||||
|
||||
void DeclVar(const std::string& name, const std::vector<int> dim_vec) {
|
||||
platform::CUDADeviceContext ctx(place_);
|
||||
auto* x = scope_.Var(name);
|
||||
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
|
||||
x_tensor->Resize(framework::make_ddim(dim_vec));
|
||||
RandomizeTensor(x_tensor, place_, ctx);
|
||||
}
|
||||
|
||||
void SetOp(const framework::proto::OpDesc& desc) {
|
||||
op_ = framework::OpRegistry::CreateOp(desc);
|
||||
op_desc_.reset(new framework::OpDesc(desc, nullptr));
|
||||
// should init anakin engine here.
|
||||
|
||||
Singleton<AnakinOpConverter>::Global().ConvertOp(
|
||||
desc, parameters_, scope_, engine_.get(), true /*test_mode*/);
|
||||
engine_->Freeze();
|
||||
for (const auto& input : op_desc_->InputArgumentNames()) {
|
||||
if (parameters_.count(input)) continue;
|
||||
auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(scope_,
|
||||
input);
|
||||
auto t_shape = framework::vectorize2int(t.dims());
|
||||
engine_->SetInputShape(input, t_shape);
|
||||
}
|
||||
engine_->Optimize();
|
||||
}
|
||||
|
||||
// We use the set 'neglected_output' here, because some Ops like batch norm,
|
||||
// the outputs specified in the op des are only used during training,
|
||||
// so we should neglect those output during inference.
|
||||
void Execute(int batch_size,
|
||||
std::unordered_set<std::string> neglected_output = {}) {
|
||||
// Execute Fluid Op
|
||||
platform::CUDADeviceContext ctx(place_);
|
||||
op_->Run(scope_, place_);
|
||||
|
||||
for (const auto& output : op_desc_->OutputArgumentNames()) {
|
||||
if (neglected_output.count(output)) continue;
|
||||
std::vector<float> fluid_out;
|
||||
auto* var = scope_.FindVar(output);
|
||||
auto* tensor = var->GetMutable<framework::LoDTensor>();
|
||||
framework::TensorToVector(*tensor, ctx, &fluid_out);
|
||||
|
||||
size_t fluid_out_size = fluid_out.size();
|
||||
for (size_t i = 0; i < fluid_out_size; i++) {
|
||||
std::cout << fluid_out[i] << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
framework::Scope& scope() { return scope_; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<AnakinNvEngineT> engine_{nullptr};
|
||||
cudaStream_t stream_;
|
||||
std::unique_ptr<framework::OperatorBase> op_;
|
||||
std::unique_ptr<framework::OpDesc> op_desc_;
|
||||
const std::unordered_set<std::string>& parameters_;
|
||||
framework::Scope& scope_;
|
||||
platform::CUDAPlace place_;
|
||||
};
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,112 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/anakin/engine.h"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
|
||||
using anakin::Precision;
|
||||
using anakin::OpRunType;
|
||||
using paddle::framework::LoDTensor;
|
||||
template <typename T, Precision P, OpRunType O>
|
||||
using AnakinNetT = anakin::Net<T, P, O>;
|
||||
|
||||
template <typename T, Precision P>
|
||||
using AnakinGraphT = anakin::graph::Graph<T, P>;
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(bool need_summary)
|
||||
: graph_(new AnakinGraphT<TargetT, PrecisionType>()),
|
||||
net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
AnakinEngine<TargetT, PrecisionType, RunType>::~AnakinEngine() {}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
void AnakinEngine<TargetT, PrecisionType, RunType>::SetInputShape(
|
||||
const std::string &name, std::vector<int> shape) {
|
||||
graph_->AddOpAttr<::anakin::PTuple<int>>(name, "input_shape",
|
||||
std::move(shape));
|
||||
}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
void AnakinEngine<TargetT, PrecisionType, RunType>::InitGraph() {
|
||||
net_->init(*graph_);
|
||||
}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
void AnakinEngine<TargetT, PrecisionType, RunType>::AddOp(
|
||||
const std::string &name, const std::string &type,
|
||||
const std::vector<std::string> &inputs,
|
||||
const std::vector<std::string> &outputs) {
|
||||
PADDLE_ENFORCE(graph_->AddOp(name, type, inputs, outputs), "Add operation.");
|
||||
}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
|
||||
const std::map<std::string, framework::LoDTensor *> &inputs,
|
||||
const std::map<std::string, framework::LoDTensor *> &outputs) {
|
||||
for (const auto &input : inputs) {
|
||||
auto *tensor = input.second;
|
||||
auto *data = tensor->data<float>();
|
||||
auto shape = framework::vectorize2int(tensor->dims());
|
||||
::anakin::saber::Shape anakin_shape(shape);
|
||||
auto *anakin_input = net_->get_in(input.first);
|
||||
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
|
||||
anakin_shape);
|
||||
anakin_input->share_from(tmp_anakin_tensor);
|
||||
}
|
||||
|
||||
for (const auto &output : outputs) {
|
||||
auto *tensor = output.second;
|
||||
auto *data = tensor->data<float>();
|
||||
auto shape = framework::vectorize2int(tensor->dims());
|
||||
::anakin::saber::Shape anakin_shape(shape);
|
||||
auto *anakin_output = net_->get_out(output.first);
|
||||
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
|
||||
anakin_shape);
|
||||
anakin_output->share_from(tmp_anakin_tensor);
|
||||
}
|
||||
net_->prediction();
|
||||
}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
void AnakinEngine<TargetT, PrecisionType, RunType>::Freeze() {
|
||||
PADDLE_ENFORCE(graph_->Freeze(), "Freeze anakin subgraph.");
|
||||
}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
void AnakinEngine<TargetT, PrecisionType, RunType>::Optimize() {
|
||||
PADDLE_ENFORCE(graph_->Optimize(), "Graph optimization.");
|
||||
}
|
||||
|
||||
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
|
||||
std::unique_ptr<AnakinEngine<TargetT, PrecisionType, RunType>>
|
||||
AnakinEngine<TargetT, PrecisionType, RunType>::Clone() {
|
||||
auto *engine = new AnakinEngine();
|
||||
engine->net_ = std::move(net_->Clone());
|
||||
return std::unique_ptr<AnakinEngine>(engine);
|
||||
}
|
||||
|
||||
template class AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>;
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,80 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/inference/engine.h"
|
||||
#include "paddle/fluid/inference/utils/singleton.h"
|
||||
|
||||
#include "framework/core/net/net.h"
|
||||
#include "framework/core/types.h"
|
||||
#include "framework/graph/graph.h"
|
||||
#include "saber/saber_types.h"
|
||||
|
||||
namespace anakin {
|
||||
|
||||
template <typename, Precision, OpRunType>
|
||||
class Net;
|
||||
|
||||
namespace graph {
|
||||
template <typename, Precision>
|
||||
class Graph;
|
||||
} // namespace graph
|
||||
} // namespace anakin
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
template <typename TargetT, ::anakin::Precision PrecisionType,
|
||||
::anakin::OpRunType RunType = ::anakin::OpRunType::ASYNC>
|
||||
class AnakinEngine {
|
||||
public:
|
||||
explicit AnakinEngine(bool need_summary = false);
|
||||
~AnakinEngine();
|
||||
void InitGraph();
|
||||
void SetInputShape(const std::string &name, std::vector<int> shape);
|
||||
void AddOp(const std::string &name, const std::string &type,
|
||||
const std::vector<std::string> &inputs,
|
||||
const std::vector<std::string> &outputs);
|
||||
|
||||
template <typename T>
|
||||
void AddOpAttr(const std::string &op_name, const std::string &attr_name,
|
||||
const T &attr_value) {
|
||||
PADDLE_ENFORCE(graph_->AddOpAttr(op_name, attr_name, attr_value),
|
||||
"Add operation's attribution.");
|
||||
}
|
||||
|
||||
std::unique_ptr<AnakinEngine> Clone();
|
||||
void Freeze();
|
||||
void Optimize();
|
||||
void Execute(const std::map<std::string, framework::LoDTensor *> &inputs,
|
||||
const std::map<std::string, framework::LoDTensor *> &outputs);
|
||||
|
||||
private:
|
||||
using NetT = ::anakin::Net<TargetT, PrecisionType, RunType>;
|
||||
using GraphT = ::anakin::graph::Graph<TargetT, PrecisionType>;
|
||||
std::unique_ptr<GraphT> graph_;
|
||||
std::unique_ptr<NetT> net_;
|
||||
};
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,92 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "framework/core/net/net.h"
|
||||
#include "framework/graph/graph.h"
|
||||
#include "framework/graph/graph_global_mem.h"
|
||||
#include "paddle/fluid/inference/anakin/engine.h"
|
||||
|
||||
using anakin::graph::GraphGlobalMem;
|
||||
using anakin::AK_FLOAT;
|
||||
using anakin::Precision;
|
||||
using anakin::saber::NV;
|
||||
using anakin::saber::X86;
|
||||
using anakin::saber::Shape;
|
||||
using anakin::PBlock;
|
||||
using anakin::PTuple;
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
class TestAnakinEngine : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override;
|
||||
void TearDown() override {}
|
||||
|
||||
protected:
|
||||
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
|
||||
std::unique_ptr<AnakinNvEngineT> engine_{nullptr};
|
||||
};
|
||||
|
||||
void TestAnakinEngine::SetUp() {
|
||||
engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
|
||||
|
||||
TEST_F(TestAnakinEngine, Execute) {
|
||||
engine_->AddOp("op1", "Dense", {"x"}, {"y"});
|
||||
engine_->AddOpAttr("op1", "out_dim", 2);
|
||||
engine_->AddOpAttr("op1", "bias_term", false);
|
||||
engine_->AddOpAttr("op1", "axis", 1);
|
||||
std::vector<int> shape = {1, 1, 1, 2};
|
||||
Shape tmp_shape(shape);
|
||||
auto *weight1 =
|
||||
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(tmp_shape);
|
||||
|
||||
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
|
||||
cpu_data[0] = 2.;
|
||||
weight1->d_tensor().set_shape(tmp_shape);
|
||||
weight1->d_tensor().copy_from(weight1->h_tensor());
|
||||
engine_->AddOpAttr("op1", "weight_1", *weight1);
|
||||
|
||||
engine_->Freeze();
|
||||
engine_->SetInputShape("x", {1, 1, 1, 1});
|
||||
engine_->Optimize();
|
||||
engine_->InitGraph();
|
||||
framework::LoDTensor x;
|
||||
framework::LoDTensor y;
|
||||
x.Resize({1, 1, 1, 1});
|
||||
y.Resize({1, 1, 1, 2});
|
||||
auto *x_data = x.mutable_data<float>(platform::CUDAPlace());
|
||||
float x_data_cpu[] = {1.};
|
||||
cudaMemcpy(x_data, x_data_cpu, sizeof(float), cudaMemcpyHostToDevice);
|
||||
|
||||
std::map<std::string, framework::LoDTensor *> inputs = {{"x", &x}};
|
||||
auto *y_data = y.mutable_data<float>(platform::CUDAPlace());
|
||||
std::map<std::string, framework::LoDTensor *> outputs = {{"y", &y}};
|
||||
|
||||
engine_->Execute(inputs, outputs);
|
||||
auto *y_data_gpu = y_data;
|
||||
float y_data_cpu[2];
|
||||
cudaMemcpy(y_data_cpu, y_data_gpu, sizeof(float) * 2,
|
||||
cudaMemcpyDeviceToHost);
|
||||
LOG(INFO) << "output value: " << y_data_cpu[0] << ", " << y_data_cpu[1];
|
||||
}
|
||||
}
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Loading…
Reference in new issue