Compare commits
125 Commits
Author | SHA1 | Date |
---|---|---|
|
bc7a3afa68 | 4 years ago |
|
a6343afc70 | 4 years ago |
|
3ab39705ea | 4 years ago |
|
ac89174e5a | 4 years ago |
|
3c66b8721a | 4 years ago |
|
4a823c5f63 | 4 years ago |
|
f354e1d6d5 | 4 years ago |
|
9754d0a7cb | 4 years ago |
|
48aa92234d | 4 years ago |
|
a93488839d | 4 years ago |
|
0279486b02 | 4 years ago |
|
b2407af6e3 | 4 years ago |
|
d1a4c53eee | 4 years ago |
|
fead563156 | 4 years ago |
|
149f76e636 | 4 years ago |
|
b3446670c1 | 4 years ago |
|
03803f20fd | 4 years ago |
|
6350528220 | 4 years ago |
|
c594f57685 | 4 years ago |
|
228bce12c8 | 4 years ago |
|
594bbcb189 | 4 years ago |
|
fba994c28b | 4 years ago |
|
02912ce7f2 | 4 years ago |
|
e6af7c0dd8 | 4 years ago |
|
17862b725f | 4 years ago |
|
342252c902 | 4 years ago |
|
7b450e7889 | 4 years ago |
|
50bc11621f | 4 years ago |
|
c8729f2aec | 4 years ago |
|
f8e1f452c4 | 4 years ago |
|
faf40da585 | 4 years ago |
|
d55120d77f | 4 years ago |
|
e424712073 | 4 years ago |
|
7ec8459c6c | 4 years ago |
|
7875bcb8f7 | 4 years ago |
|
125201ee56 | 4 years ago |
|
ef15544ee0 | 4 years ago |
|
743cc9b29b | 4 years ago |
|
1de6daff82 | 4 years ago |
|
3dd992e24f | 4 years ago |
|
444c285202 | 4 years ago |
|
8f08f160c6 | 4 years ago |
|
468ac6993b | 4 years ago |
|
5118968d80 | 4 years ago |
|
925432d85e | 4 years ago |
|
1e956001ec | 4 years ago |
|
795b0f92d3 | 4 years ago |
|
b541ca8795 | 4 years ago |
|
57220f594d | 4 years ago |
|
3ca4bc1004 | 4 years ago |
|
11f788771c | 4 years ago |
|
e3e15792a4 | 4 years ago |
|
a3cc4a4a69 | 4 years ago |
|
f250416029 | 4 years ago |
|
7241bc2210 | 4 years ago |
|
9606a86b18 | 4 years ago |
|
47860ce20d | 4 years ago |
|
de65486c19 | 4 years ago |
|
ec2160a622 | 4 years ago |
|
0234693040 | 4 years ago |
|
5e851bff42 | 4 years ago |
|
382fc31f89 | 4 years ago |
|
5d29a27c2e | 4 years ago |
|
09bf2cfc0e | 4 years ago |
|
f1fdddfdc8 | 4 years ago |
|
e1c33a6d69 | 4 years ago |
|
3bf8a34c69 | 4 years ago |
|
d746197398 | 4 years ago |
|
5d22e15b6e | 4 years ago |
|
581e5460a0 | 4 years ago |
|
cfeeb4bc95 | 4 years ago |
|
e15ccafb84 | 4 years ago |
|
29d50d2049 | 4 years ago |
|
f400ce9f51 | 4 years ago |
|
7524ac9345 | 4 years ago |
|
3f206e97c4 | 4 years ago |
|
9df84bd693 | 4 years ago |
|
e19195f795 | 4 years ago |
|
15823bb0df | 4 years ago |
|
388c69f27d | 4 years ago |
|
c956c035dc | 4 years ago |
|
83f81eb573 | 4 years ago |
|
5fe3d596e4 | 4 years ago |
|
ecc6e213d7 | 4 years ago |
|
b3c88e961c | 4 years ago |
|
ac3d821bc0 | 4 years ago |
|
0310945f5c | 4 years ago |
|
45765d6eb6 | 4 years ago |
|
8497e2aad3 | 4 years ago |
|
9fcdaeba5e | 4 years ago |
|
5618f14047 | 4 years ago |
|
a1ddff81e3 | 4 years ago |
|
d23bf89cf6 | 4 years ago |
|
77a0c41cb2 | 4 years ago |
|
187248f568 | 4 years ago |
|
821c2f4ef8 | 4 years ago |
|
d45f5d787e | 4 years ago |
|
387c1db4f1 | 4 years ago |
|
ff4654e216 | 4 years ago |
|
1435b4c096 | 4 years ago |
|
678a3e8fed | 4 years ago |
|
85cbd55648 | 4 years ago |
|
5cb20f30fc | 4 years ago |
|
c687edecd8 | 4 years ago |
|
a6edbc478b | 4 years ago |
|
1201cd2ef2 | 4 years ago |
|
7e049108c5 | 4 years ago |
|
81138239db | 4 years ago |
|
ebef6601d5 | 4 years ago |
|
500f28ec37 | 4 years ago |
|
de42d19336 | 4 years ago |
|
ebb5d181e8 | 4 years ago |
|
4a26729540 | 4 years ago |
|
636fefd9f8 | 4 years ago |
|
88dfd067bf | 4 years ago |
|
6eabbc8076 | 4 years ago |
|
904cc44349 | 4 years ago |
|
5b77b259d8 | 4 years ago |
|
7158061a29 | 4 years ago |
|
e4287ca60b | 4 years ago |
|
f5aca8fbb4 | 4 years ago |
|
d2404da768 | 4 years ago |
|
f9c97dd728 | 4 years ago |
|
1882f2ce2d | 4 years ago |
|
6dd52c5b25 | 4 years ago |
@ -0,0 +1,84 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
#NOTE: Logic is from
# https://github.com/mindspore-ai/graphengine/blob/master/CMakeLists.txt
#
# Resolve the Ascend toolkit root: honour the ASCEND_CUSTOM_PATH environment
# variable when present, otherwise fall back to the stock install prefix.
if(NOT DEFINED ENV{ASCEND_CUSTOM_PATH})
  set(ASCEND_DIR /usr/local/Ascend)
else()
  set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH})
endif()
|
||||
|
||||
if(WITH_ASCEND)
  # Library locations of a bare (non-toolkit) Ascend installation.
  set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
  set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
  set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
  set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
  set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
  set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
  set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})

  set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
  set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})

  # Corresponding paths inside the "ascend-toolkit" (Atlas) layout.
  set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
  set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
  set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
  set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
  set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})

  set(atlas_graph_lib ${ATLAS_RUNTIME_DIR}/libgraph.so)
  set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
  set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
  include_directories(${ATLAS_RUNTIME_INC_DIR})

  # Newer GE headers ship ge::AscendString; the C++ side switches on this
  # define (see ascend_wrapper.h).
  if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h)
    add_definitions(-DPADDLE_WITH_ASCEND_STRING)
  endif()

  # Expose the prebuilt GE / graph / ACL shared objects as imported targets
  # so the rest of the build can link them by name.
  add_library(ascend_ge SHARED IMPORTED GLOBAL)
  set_property(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib})

  add_library(ascend_graph SHARED IMPORTED GLOBAL)
  set_property(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph_lib})

  add_library(atlas_acl SHARED IMPORTED GLOBAL)
  set_property(TARGET atlas_acl PROPERTY IMPORTED_LOCATION ${atlas_acl_lib})

  add_custom_target(extern_ascend DEPENDS ascend_ge ascend_graph atlas_acl)
endif()
|
||||
|
||||
if(WITH_ASCEND_CL)
  # ACL (Ascend Computing Language) libraries live in the toolkit's
  # fwkacllib directory.
  set(ASCEND_CL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)

  set(ascend_hccl_lib ${ASCEND_CL_DIR}/libhccl.so)
  set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so)
  set(acl_op_compiler_lib ${ASCEND_CL_DIR}/libacl_op_compiler.so)
  set(ASCEND_CL_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)

  message(STATUS "ASCEND_CL_INC_DIR ${ASCEND_CL_INC_DIR}")
  message(STATUS "ASCEND_CL_DIR ${ASCEND_CL_DIR}")
  include_directories(${ASCEND_CL_INC_DIR})

  # Imported targets for the ACL runtime, HCCL collective library and the
  # on-the-fly op compiler.
  add_library(ascendcl SHARED IMPORTED GLOBAL)
  set_property(TARGET ascendcl PROPERTY IMPORTED_LOCATION ${ascendcl_lib})

  add_library(ascend_hccl SHARED IMPORTED GLOBAL)
  set_property(TARGET ascend_hccl PROPERTY IMPORTED_LOCATION ${ascend_hccl_lib})

  add_library(acl_op_compiler SHARED IMPORTED GLOBAL)
  set_property(TARGET acl_op_compiler PROPERTY IMPORTED_LOCATION ${acl_op_compiler_lib})
  add_custom_target(extern_ascend_cl DEPENDS ascendcl acl_op_compiler)
endif()
|
@ -0,0 +1,22 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifdef PADDLE_WITH_ASCEND
|
||||
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
std::shared_ptr<AscendInstance> AscendInstance::ascend_instance_ = nullptr;
|
||||
} // end namespace framework
|
||||
} // end namespace paddle
|
||||
#endif
|
@ -0,0 +1,208 @@
|
||||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef PADDLE_WITH_ASCEND
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/platform/gpu_info.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "paddle/fluid/platform/timer.h"
|
||||
|
||||
#include "ge/ge_api.h"
|
||||
#include "ge/ge_api_types.h"
|
||||
#include "graph/attr_value.h"
|
||||
#include "graph/tensor.h"
|
||||
#include "graph/types.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// A graph description handed to / run by the GE session.
typedef ge::Graph AscendGraphDesc;

// PADDLE_WITH_ASCEND_STRING is defined by the build when the GE headers
// provide graph/ascend_string.h; fall back to std::string on older GE so
// option maps keep the same shape either way.
#ifdef PADDLE_WITH_ASCEND_STRING
using AscendString = ge::AscendString;
#else
using AscendString = std::string;
#endif
||||
|
||||
class AscendInstance {
|
||||
public:
|
||||
virtual ~AscendInstance() {}
|
||||
AscendInstance() {}
|
||||
|
||||
std::map<AscendString, AscendString> _GetDefaultInitOptions() {
|
||||
std::map<AscendString, AscendString> init_options;
|
||||
init_options["ge.exec.deviceId"] = "0";
|
||||
init_options["ge.graphRunMode"] = "1";
|
||||
return init_options;
|
||||
}
|
||||
|
||||
std::map<AscendString, AscendString> _GetDefaultInitSessionOptions() {
|
||||
std::map<AscendString, AscendString> init_options;
|
||||
// init_options["a"] = "b";
|
||||
// init_options["ge.trainFlag"] = "1";
|
||||
return init_options;
|
||||
}
|
||||
|
||||
ge::Status InitGEForUT() {
|
||||
return ge::GEInitialize(_GetDefaultInitOptions());
|
||||
}
|
||||
|
||||
void InitGlobalResouces() {
|
||||
LOG(INFO) << "Begin ascend InitGlobalResouces";
|
||||
session_.reset(new ge::Session(_GetDefaultInitSessionOptions()));
|
||||
if (session_ == nullptr) {
|
||||
LOG(FATAL) << "new session error:" << session_;
|
||||
}
|
||||
LOG(INFO) << "End ascend InitGlobalResouces";
|
||||
}
|
||||
|
||||
void DestroyGlobalResouces() {
|
||||
LOG(INFO) << "Begin ascend DestroyGlobalResouces";
|
||||
session_ = nullptr;
|
||||
LOG(INFO) << "Begin ascend DestroyGlobalResouces";
|
||||
}
|
||||
|
||||
static std::shared_ptr<AscendInstance> GetInstance() {
|
||||
if (nullptr == ascend_instance_) {
|
||||
ascend_instance_.reset(new paddle::framework::AscendInstance());
|
||||
VLOG(1) << "Initialize AscendInstance Done";
|
||||
}
|
||||
return ascend_instance_;
|
||||
}
|
||||
|
||||
void AddAscendSubgraph(int graph_idx, const AscendGraphDesc &graph) {
|
||||
ge::Status status = session_->AddGraph(graph_idx, graph);
|
||||
PADDLE_ENFORCE_EQ(status, ge::SUCCESS,
|
||||
paddle::platform::errors::PreconditionNotMet(
|
||||
"Calling addGraph of graph engine failed, please "
|
||||
"check Ascend Log."));
|
||||
VLOG(1) << "AddAscendSubgraph " << graph_idx << " Done";
|
||||
}
|
||||
|
||||
ge::DataType VarTypeToGeType(proto::VarType::Type type) {
|
||||
if (type == proto::VarType::FP16) {
|
||||
return ge::DataType::DT_FLOAT16;
|
||||
} else if (type == proto::VarType::FP32) {
|
||||
return ge::DataType::DT_FLOAT;
|
||||
} else if (type == proto::VarType::FP64) {
|
||||
return ge::DataType::DT_DOUBLE;
|
||||
} else if (type == proto::VarType::INT32) {
|
||||
return ge::DataType::DT_INT32;
|
||||
} else if (type == proto::VarType::INT64) {
|
||||
return ge::DataType::DT_INT64;
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::Unimplemented(
|
||||
"Not support %s as tensor type.", DataTypeToString(type)));
|
||||
}
|
||||
}
|
||||
int GeTypeSize(proto::VarType::Type type) {
|
||||
if (type == proto::VarType::FP16) {
|
||||
return 2;
|
||||
} else if (type == proto::VarType::FP32) {
|
||||
return 4;
|
||||
} else if (type == proto::VarType::FP64) {
|
||||
return 8;
|
||||
} else if (type == proto::VarType::INT32) {
|
||||
return 4;
|
||||
} else if (type == proto::VarType::INT64) {
|
||||
return 8;
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::Unimplemented(
|
||||
"Not support %s as tensor type.", DataTypeToString(type)));
|
||||
}
|
||||
}
|
||||
ge::Tensor ConvertToGeTensor(const Tensor *tensor) {
|
||||
auto numel = tensor->numel();
|
||||
std::vector<int64_t> vec_dim;
|
||||
auto dimen = arity(tensor->dims());
|
||||
for (auto i = 0; i < dimen; ++i) {
|
||||
vec_dim.push_back(tensor->dims()[i]);
|
||||
}
|
||||
// For Debug
|
||||
// VLOG(1) << "input numel: " << numel << ", dimen is " << vec_dim.size() <<
|
||||
// ", and shape is";
|
||||
// for (const auto e : vec_dim) {
|
||||
// VLOG(0) << e;
|
||||
// }
|
||||
|
||||
ge::Shape shape(vec_dim);
|
||||
ge::TensorDesc tensor_desc(shape, ge::Format::FORMAT_ND,
|
||||
VarTypeToGeType(tensor->type()));
|
||||
tensor_desc.SetRealDimCnt(vec_dim.size());
|
||||
|
||||
const uint8_t *data =
|
||||
reinterpret_cast<const uint8_t *>(tensor->data<void>());
|
||||
std::vector<uint8_t> dst(numel * GeTypeSize(tensor->type()));
|
||||
memcpy(dst.data(), data, GeTypeSize(tensor->type()) * numel);
|
||||
ge::Tensor ge_tensor(tensor_desc, dst);
|
||||
return ge_tensor;
|
||||
}
|
||||
|
||||
void RunAscendSubgraph(int graph_idx,
|
||||
const std::vector<const Tensor *> &inputs,
|
||||
std::vector<Tensor *> *outputs) {
|
||||
VLOG(1) << "Ascend Graph[" << graph_idx << "] is about to run.";
|
||||
// Convert paddle Tensor to GE Tensor
|
||||
std::vector<ge::Tensor> ge_inputs;
|
||||
for (const auto &e : inputs) {
|
||||
ge_inputs.push_back(ConvertToGeTensor(e));
|
||||
}
|
||||
|
||||
// Run Graph
|
||||
std::vector<ge::Tensor> ge_outputs;
|
||||
ge::Status status = session_->RunGraph(graph_idx, ge_inputs, ge_outputs);
|
||||
PADDLE_ENFORCE_EQ(status, ge::SUCCESS,
|
||||
paddle::platform::errors::PreconditionNotMet(
|
||||
"Calling RunGraph of graph engine failed, please "
|
||||
"check Ascend Log."));
|
||||
VLOG(1) << "Run Ascend Graph[" << graph_idx << "] Done";
|
||||
|
||||
// change tensor back, note all tensor's type computed in GE is uint8
|
||||
for (size_t i = 0; i < ge_outputs.size(); ++i) {
|
||||
const uint8_t *ret_data = ge_outputs[i].GetData();
|
||||
size_t size = ge_outputs[i].GetSize();
|
||||
VLOG(1) << "GE Tensor size of the " << i << "th output var is " << size;
|
||||
auto *dst = (*outputs)[i]->mutable_data<uint8_t>({(int64_t)size},
|
||||
platform::CPUPlace());
|
||||
memcpy(dst, ret_data, size);
|
||||
|
||||
// Following for debug:
|
||||
// VLOG(0) << "output for " << i << " var: ";
|
||||
// float *tmp = reinterpret_cast<float*>(dst);
|
||||
// for (size_t j = 0; j < size / 4; ++j) {
|
||||
// printf("%f ", tmp[j]);
|
||||
// }
|
||||
// printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ge::Session> session_;
|
||||
|
||||
private:
|
||||
static std::shared_ptr<AscendInstance> ascend_instance_;
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
#endif
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue