|
|
|
@ -159,6 +159,18 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
|
|
|
|
|
size_t size_backbone, const char *model_buf_head,
|
|
|
|
|
size_t size_head, lite::Context *context,
|
|
|
|
|
bool train_mode) {
|
|
|
|
|
auto ValidModelSize = [](size_t size) -> bool {
|
|
|
|
|
constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL; // 1G B
|
|
|
|
|
return size < MaxModelSize && size > 0;
|
|
|
|
|
};
|
|
|
|
|
if (!ValidModelSize(size_backbone)) {
|
|
|
|
|
MS_LOG(ERROR) << "size_backbone too large: " << size_backbone;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (!ValidModelSize(size_head)) {
|
|
|
|
|
MS_LOG(ERROR) << "size_head too large: " << size_head;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto session = new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context);
|
|
|
|
|
if (session == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "create transfer session failed";
|
|
|
|
|