Ascend Framework Part1: OP & Wrapper (#30281)
parent
34bf8dfc40
commit
40ede12631
@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
INCLUDE(ExternalProject)
|
||||
|
||||
SET(ASCEND_PROJECT "extern_ascend")
|
||||
IF((NOT DEFINED ASCEND_VER) OR (NOT DEFINED ASCEND_URL))
|
||||
MESSAGE(STATUS "use pre defined download url")
|
||||
SET(ASCEND_VER "0.1.1" CACHE STRING "" FORCE)
|
||||
SET(ASCEND_NAME "ascend" CACHE STRING "" FORCE)
|
||||
SET(ASCEND_URL "http://paddle-ascend.bj.bcebos.com/ascend.tar.gz" CACHE STRING "" FORCE)
|
||||
ENDIF()
|
||||
MESSAGE(STATUS "ASCEND_NAME: ${ASCEND_NAME}, ASCEND_URL: ${ASCEND_URL}")
|
||||
SET(ASCEND_SOURCE_DIR "${THIRD_PARTY_PATH}/ascend")
|
||||
SET(ASCEND_DOWNLOAD_DIR "${ASCEND_SOURCE_DIR}/src/${ASCEND_PROJECT}")
|
||||
SET(ASCEND_DST_DIR "ascend")
|
||||
SET(ASCEND_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
|
||||
SET(ASCEND_INSTALL_DIR ${ASCEND_INSTALL_ROOT}/${ASCEND_DST_DIR})
|
||||
SET(ASCEND_ROOT ${ASCEND_INSTALL_DIR})
|
||||
SET(ASCEND_INC_DIR ${ASCEND_ROOT}/include)
|
||||
SET(ASCEND_LIB_DIR ${ASCEND_ROOT}/lib)
|
||||
SET(ASCEND_LIB ${ASCEND_LIB_DIR}/libge_runner.so)
|
||||
SET(ASCEND_GRAPH_LIB ${ASCEND_LIB_DIR}/libgraph.so)
|
||||
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${ASCEND_ROOT}/lib")
|
||||
|
||||
INCLUDE_DIRECTORIES(${ASCEND_INC_DIR})
|
||||
FILE(WRITE ${ASCEND_DOWNLOAD_DIR}/CMakeLists.txt
|
||||
"PROJECT(ASCEND)\n"
|
||||
"cmake_minimum_required(VERSION 3.0)\n"
|
||||
"install(DIRECTORY ${ASCEND_NAME}/include ${ASCEND_NAME}/lib \n"
|
||||
" DESTINATION ${ASCEND_DST_DIR})\n")
|
||||
ExternalProject_Add(
|
||||
${ASCEND_PROJECT}
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
PREFIX ${ASCEND_SOURCE_DIR}
|
||||
DOWNLOAD_DIR ${ASCEND_DOWNLOAD_DIR}
|
||||
DOWNLOAD_COMMAND wget --no-check-certificate ${ASCEND_URL} -c -q -O ${ASCEND_NAME}.tar.gz
|
||||
&& tar zxvf ${ASCEND_NAME}.tar.gz
|
||||
DOWNLOAD_NO_PROGRESS 1
|
||||
UPDATE_COMMAND ""
|
||||
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${ASCEND_INSTALL_ROOT}
|
||||
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ASCEND_INSTALL_ROOT}
|
||||
)
|
||||
ADD_LIBRARY(ascend SHARED IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET ascend PROPERTY IMPORTED_LOCATION ${ASCEND_LIB})
|
||||
|
||||
ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${ASCEND_GRAPH_LIB})
|
||||
ADD_DEPENDENCIES(ascend ascend_graph ${ASCEND_PROJECT})
|
||||
|
@ -0,0 +1,22 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifdef PADDLE_WITH_ASCEND
|
||||
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
std::shared_ptr<AscendInstance> AscendInstance::ascend_instance_ = nullptr;
|
||||
} // end namespace framework
|
||||
} // end namespace paddle
|
||||
#endif
|
@ -0,0 +1,183 @@
|
||||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef PADDLE_WITH_ASCEND
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/platform/gpu_info.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "paddle/fluid/platform/timer.h"
|
||||
|
||||
#include "ge/ge_api.h"
|
||||
#include "ge/ge_api_types.h"
|
||||
#include "graph/attr_value.h"
|
||||
#include "graph/tensor.h"
|
||||
#include "graph/types.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// typedef std::vector<std::string> AscendGraphDesc;
|
||||
typedef ge::Graph AscendGraphDesc;
|
||||
|
||||
class AscendInstance {
|
||||
public:
|
||||
virtual ~AscendInstance() {}
|
||||
AscendInstance() {}
|
||||
|
||||
std::map<std::string, std::string> GetDefaultInitSessionOptions() {
|
||||
std::map<std::string, std::string> init_options;
|
||||
init_options["a"] = "b";
|
||||
init_options["ge.trainFlag"] = "1";
|
||||
return init_options;
|
||||
}
|
||||
|
||||
// add other parameters here to init
|
||||
void InitGlobalResouces() {
|
||||
session_.reset(new ge::Session(GetDefaultInitSessionOptions()));
|
||||
VLOG(1) << "InitGlobalResouces Done";
|
||||
}
|
||||
|
||||
static std::shared_ptr<AscendInstance> GetInstance() {
|
||||
if (nullptr == ascend_instance_) {
|
||||
ascend_instance_.reset(new paddle::framework::AscendInstance());
|
||||
VLOG(1) << "Initialize AscendInstance Done";
|
||||
}
|
||||
return ascend_instance_;
|
||||
}
|
||||
|
||||
void AddAscendSubgraph(int graph_idx, const AscendGraphDesc &graph) {
|
||||
ge::Status status = session_->AddGraph(graph_idx, graph);
|
||||
PADDLE_ENFORCE_EQ(status, ge::SUCCESS,
|
||||
paddle::platform::errors::PreconditionNotMet(
|
||||
"Calling addGraph of graph engine failed, please "
|
||||
"check Ascend Log."));
|
||||
VLOG(1) << "AddAscendSubgraph " << graph_idx << " Done";
|
||||
}
|
||||
|
||||
ge::DataType VarTypeToGeType(proto::VarType::Type type) {
|
||||
if (type == proto::VarType::FP16) {
|
||||
return ge::DataType::DT_FLOAT16;
|
||||
} else if (type == proto::VarType::FP32) {
|
||||
return ge::DataType::DT_FLOAT;
|
||||
} else if (type == proto::VarType::FP64) {
|
||||
return ge::DataType::DT_DOUBLE;
|
||||
} else if (type == proto::VarType::INT32) {
|
||||
return ge::DataType::DT_INT32;
|
||||
} else if (type == proto::VarType::INT64) {
|
||||
return ge::DataType::DT_INT64;
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::Unimplemented(
|
||||
"Not support %s as tensor type.", DataTypeToString(type)));
|
||||
}
|
||||
}
|
||||
int GeTypeSize(proto::VarType::Type type) {
|
||||
if (type == proto::VarType::FP16) {
|
||||
return 2;
|
||||
} else if (type == proto::VarType::FP32) {
|
||||
return 4;
|
||||
} else if (type == proto::VarType::FP64) {
|
||||
return 8;
|
||||
} else if (type == proto::VarType::INT32) {
|
||||
return 4;
|
||||
} else if (type == proto::VarType::INT64) {
|
||||
return 8;
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::Unimplemented(
|
||||
"Not support %s as tensor type.", DataTypeToString(type)));
|
||||
}
|
||||
}
|
||||
ge::Tensor ConvertToGeTensor(const Tensor *tensor) {
|
||||
auto numel = tensor->numel();
|
||||
std::vector<int64_t> vec_dim;
|
||||
auto dimen = arity(tensor->dims());
|
||||
for (auto i = 0; i < dimen; ++i) {
|
||||
vec_dim.push_back(tensor->dims()[i]);
|
||||
}
|
||||
// For Debug
|
||||
// VLOG(1) << "input numel: " << numel << ", dimen is " << vec_dim.size() <<
|
||||
// ", and shape is";
|
||||
// for (const auto e : vec_dim) {
|
||||
// VLOG(0) << e;
|
||||
// }
|
||||
|
||||
ge::Shape shape(vec_dim);
|
||||
ge::TensorDesc tensor_desc(shape, ge::Format::FORMAT_ND,
|
||||
VarTypeToGeType(tensor->type()));
|
||||
tensor_desc.SetRealDimCnt(vec_dim.size());
|
||||
|
||||
const uint8_t *data =
|
||||
reinterpret_cast<const uint8_t *>(tensor->data<void>());
|
||||
std::vector<uint8_t> dst(numel * GeTypeSize(tensor->type()));
|
||||
memcpy(dst.data(), data, GeTypeSize(tensor->type()) * numel);
|
||||
ge::Tensor ge_tensor(tensor_desc, dst);
|
||||
return ge_tensor;
|
||||
}
|
||||
|
||||
void RunAscendSubgraph(int graph_idx,
|
||||
const std::vector<const Tensor *> &inputs,
|
||||
std::vector<Tensor *> *outputs) {
|
||||
VLOG(1) << "Ascend Graph[" << graph_idx << "] is about to run.";
|
||||
// Convert paddle Tensor to GE Tensor
|
||||
std::vector<ge::Tensor> ge_inputs;
|
||||
for (const auto &e : inputs) {
|
||||
ge_inputs.push_back(ConvertToGeTensor(e));
|
||||
}
|
||||
|
||||
// Run Graph
|
||||
std::vector<ge::Tensor> ge_outputs;
|
||||
ge::Status status = session_->RunGraph(graph_idx, ge_inputs, ge_outputs);
|
||||
PADDLE_ENFORCE_EQ(status, ge::SUCCESS,
|
||||
paddle::platform::errors::PreconditionNotMet(
|
||||
"Calling RunGraph of graph engine failed, please "
|
||||
"check Ascend Log."));
|
||||
VLOG(1) << "Run Ascend Graph[" << graph_idx << "] Done";
|
||||
|
||||
// change tensor back, note all tensor's type computed in GE is uint8
|
||||
for (size_t i = 0; i < ge_outputs.size(); ++i) {
|
||||
const uint8_t *ret_data = ge_outputs[i].GetData();
|
||||
size_t size = ge_outputs[i].GetSize();
|
||||
VLOG(1) << "GE Tensor size of the " << i << "th output var is " << size;
|
||||
auto *dst = (*outputs)[i]->mutable_data<uint8_t>({(int64_t)size},
|
||||
platform::CPUPlace());
|
||||
memcpy(dst, ret_data, size);
|
||||
|
||||
// Following for debug:
|
||||
// VLOG(0) << "output for " << i << " var: ";
|
||||
// float *tmp = reinterpret_cast<float*>(dst);
|
||||
// for (size_t j = 0; j < size / 4; ++j) {
|
||||
// printf("%f ", tmp[j]);
|
||||
// }
|
||||
// printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ge::Session> session_;
|
||||
|
||||
private:
|
||||
static std::shared_ptr<AscendInstance> ascend_instance_;
|
||||
};
|
||||
} // end namespace framework
|
||||
} // end namespace paddle
|
||||
#endif
|
@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/ascend_trigger_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class AscendTriggerOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(framework::proto::VarType::FP32,
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class AscendTriggerOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("FeedList", "FeedList of Ascend SubGraph").AsDuplicable();
|
||||
AddOutput("FetchList", "FetchList of Ascend SubGraph").AsDuplicable();
|
||||
AddAttr<int>("graph_idx", "(int, the graph index").SetDefault(-1);
|
||||
AddComment(R"DOC(
|
||||
Trigger Ascend SubGraph
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(ascend_trigger, ops::AscendTriggerOp,
|
||||
ops::AscendTriggerOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(ascend_trigger, ops::AscendTriggerCPUKernel<float>)
|
@ -0,0 +1,46 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#ifdef PADDLE_WITH_ASCEND
|
||||
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
class AscendTriggerCPUKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
#ifdef PADDLE_WITH_ASCEND
|
||||
auto ascend_ptr = paddle::framework::AscendInstance::GetInstance();
|
||||
auto graph_idx = ctx.Attr<int>("graph_idx");
|
||||
VLOG(4) << "AscendTrigger Kernel, begin to run graph: " << graph_idx;
|
||||
auto inputs = ctx.MultiInput<framework::Tensor>("FeedList");
|
||||
auto outputs = ctx.MultiOutput<framework::Tensor>("FetchList");
|
||||
ascend_ptr->RunAscendSubgraph(graph_idx, inputs, &outputs);
|
||||
#else
|
||||
PADDLE_THROW(platform::errors::PreconditionNotMet(
|
||||
"Please compile WITH_ASCEND option to enable ascend_trigger op"));
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.layers as layers
|
||||
import unittest
|
||||
|
||||
|
||||
class TestAscendTriggerOP(unittest.TestCase):
|
||||
""" TestCases for ascend_trigger op"""
|
||||
|
||||
def test_ascend_trigger_op(self):
|
||||
paddle.enable_static()
|
||||
program = fluid.Program()
|
||||
block = program.global_block()
|
||||
with fluid.program_guard(program):
|
||||
x = fluid.data(name='x', shape=[1], dtype='int64', lod_level=0)
|
||||
y = fluid.data(name='y', shape=[1], dtype='int64', lod_level=0)
|
||||
block.append_op(
|
||||
type="ascend_trigger",
|
||||
inputs={"FeedList": [x]},
|
||||
outputs={"FetchList": [y]},
|
||||
attrs={'graph_idx': 0})
|
||||
|
||||
exe = paddle.static.Executor(paddle.CPUPlace())
|
||||
try:
|
||||
exe.run(program)
|
||||
except RuntimeError as e:
|
||||
pass
|
||||
except:
|
||||
self.assertTrue(False)
|
||||
|
||||
paddle.disable_static()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue