You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/io/crypto/aes_cipher.cc

282 lines
11 KiB

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/io/crypto/aes_cipher.h"

#include <fstream>
#include <iterator>
#include <set>
#include <string>

#include <cryptopp/aes.h>
#include <cryptopp/ccm.h>
#include <cryptopp/cryptlib.h>
#include <cryptopp/filters.h>
#include <cryptopp/gcm.h>
#include <cryptopp/modes.h>
#include <cryptopp/smartptr.h>

#include "paddle/fluid/framework/io/crypto/cipher_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
void AESCipher::Init(const std::string& cipher_name, const int& iv_size,
const int& tag_size) {
aes_cipher_name_ = cipher_name;
iv_size_ = iv_size;
tag_size_ = tag_size;
std::set<std::string> authented_cipher_set{"AES_GCM_NoPadding"};
if (authented_cipher_set.find(cipher_name) != authented_cipher_set.end()) {
is_authenticated_cipher_ = true;
}
}
std::string AESCipher::EncryptInternal(const std::string& plaintext,
const std::string& key) {
CryptoPP::member_ptr<CryptoPP::SymmetricCipher> m_cipher;
CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter> m_filter;
bool need_iv = false;
const unsigned char* key_char =
reinterpret_cast<const unsigned char*>(&(key.at(0)));
BuildCipher(true, &need_iv, &m_cipher, &m_filter);
if (need_iv) {
iv_ = CipherUtils::GenKey(iv_size_);
m_cipher.get()->SetKeyWithIV(
key_char, key.size(),
reinterpret_cast<const unsigned char*>(&(iv_.at(0))), iv_.size());
} else {
m_cipher.get()->SetKey(key_char, key.size());
}
std::string ciphertext;
m_filter->Attach(new CryptoPP::StringSink(ciphertext));
CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter));
if (need_iv) {
return iv_ + ciphertext;
}
return ciphertext;
}
std::string AESCipher::DecryptInternal(const std::string& ciphertext,
const std::string& key) {
CryptoPP::member_ptr<CryptoPP::SymmetricCipher> m_cipher;
CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter> m_filter;
bool need_iv = false;
const unsigned char* key_char =
reinterpret_cast<const unsigned char*>(&(key.at(0)));
BuildCipher(false, &need_iv, &m_cipher, &m_filter);
int ciphertext_beg = 0;
if (need_iv) {
iv_ = ciphertext.substr(0, iv_size_ / 8);
ciphertext_beg = iv_size_ / 8;
m_cipher.get()->SetKeyWithIV(
key_char, key.size(),
reinterpret_cast<const unsigned char*>(&(iv_.at(0))), iv_.size());
} else {
m_cipher.get()->SetKey(key_char, key.size());
}
std::string plaintext;
m_filter->Attach(new CryptoPP::StringSink(plaintext));
CryptoPP::StringSource(ciphertext.substr(ciphertext_beg), true,
new CryptoPP::Redirector(*m_filter));
return plaintext;
}
// Encrypts `plaintext` with the authenticated mode selected by
// BuildAuthEncCipher (currently GCM). When the mode needs an IV, a fresh IV
// is generated, stored in iv_, and prepended to the returned bytes, giving
// [IV][ciphertext + tag].
std::string AESCipher::AuthenticatedEncryptInternal(
    const std::string& plaintext, const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::AuthenticatedEncryptionFilter> m_filter;
  bool need_iv = false;
  const unsigned char* key_char =
      reinterpret_cast<const unsigned char*>(&(key.at(0)));
  BuildAuthEncCipher(&need_iv, &m_cipher, &m_filter);
  if (need_iv) {
    iv_ = CipherUtils::GenKey(iv_size_);
    m_cipher->SetKeyWithIV(
        key_char, key.size(),
        reinterpret_cast<const unsigned char*>(&(iv_.at(0))), iv_.size());
  } else {
    m_cipher->SetKey(key_char, key.size());
  }
  std::string ciphertext;
  m_filter->Attach(new CryptoPP::StringSink(ciphertext));
  CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter));
  if (need_iv) {
    // Build a new string rather than iv_.append(...). append() would leave
    // the member iv_ holding IV + ciphertext instead of just the IV,
    // unlike EncryptInternal.
    return iv_ + ciphertext;
  }
  return ciphertext;
}
std::string AESCipher::AuthenticatedDecryptInternal(
const std::string& ciphertext, const std::string& key) {
CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher> m_cipher;
CryptoPP::member_ptr<CryptoPP::AuthenticatedDecryptionFilter> m_filter;
bool need_iv = false;
const unsigned char* key_char =
reinterpret_cast<const unsigned char*>(&(key.at(0)));
BuildAuthDecCipher(&need_iv, &m_cipher, &m_filter);
int ciphertext_beg = 0;
if (need_iv) {
iv_ = ciphertext.substr(0, iv_size_ / 8);
ciphertext_beg = iv_size_ / 8;
m_cipher.get()->SetKeyWithIV(
key_char, key.size(),
reinterpret_cast<const unsigned char*>(&(iv_.at(0))), iv_.size());
} else {
m_cipher.get()->SetKey(key_char, key.size());
}
std::string plaintext;
m_filter->Attach(new CryptoPP::StringSink(plaintext));
CryptoPP::StringSource(ciphertext.substr(ciphertext_beg), true,
new CryptoPP::Redirector(*m_filter));
PADDLE_ENFORCE_EQ(
m_filter->GetLastResult(), true,
paddle::platform::errors::InvalidArgument("Integrity check failed. "
"Invalid ciphertext input."));
return plaintext;
}
// Creates the non-authenticated AES mode object named by aes_cipher_name_ and
// the StreamTransformationFilter that drives it.
//   for_encrypt: build the encryption (true) or decryption (false) direction.
//   need_iv:     set to true for modes that require an IV (CBC, CTR); left
//                untouched for ECB.
// Throws paddle Unimplemented for any other cipher name. No output argument
// is modified in that case.
void AESCipher::BuildCipher(
    bool for_encrypt, bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::SymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter>* m_filter) {
  const bool is_ecb = (aes_cipher_name_ == "AES_ECB_PKCSPadding");
  const bool is_cbc = (aes_cipher_name_ == "AES_CBC_PKCSPadding");
  const bool is_ctr = (aes_cipher_name_ == "AES_CTR_NoPadding");
  if (!(is_ecb || is_cbc || is_ctr)) {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }

  // Pick the direction-specific mode object.
  if (is_ecb) {
    if (for_encrypt) {
      m_cipher->reset(new CryptoPP::ECB_Mode<CryptoPP::AES>::Encryption);
    } else {
      m_cipher->reset(new CryptoPP::ECB_Mode<CryptoPP::AES>::Decryption);
    }
  } else if (is_cbc) {
    if (for_encrypt) {
      m_cipher->reset(new CryptoPP::CBC_Mode<CryptoPP::AES>::Encryption);
    } else {
      m_cipher->reset(new CryptoPP::CBC_Mode<CryptoPP::AES>::Decryption);
    }
  } else {
    if (for_encrypt) {
      m_cipher->reset(new CryptoPP::CTR_Mode<CryptoPP::AES>::Encryption);
    } else {
      m_cipher->reset(new CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption);
    }
  }

  // ECB runs without an IV; the caller's flag is left as-is for it.
  if (!is_ecb) {
    *need_iv = true;
  }

  // CTR is configured without padding; ECB and CBC use PKCS padding.
  const CryptoPP::BlockPaddingSchemeDef::BlockPaddingScheme padding =
      is_ctr ? CryptoPP::BlockPaddingSchemeDef::NO_PADDING
             : CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING;
  m_filter->reset(new CryptoPP::StreamTransformationFilter(
      *m_cipher->get(), nullptr, padding));
}
// Creates the authenticated encryption object for aes_cipher_name_ together
// with its AuthenticatedEncryptionFilter. Only "AES_GCM_NoPadding" is
// supported. It needs an IV, so *need_iv is set to true, and its tag length
// is tag_size_ / 8 bytes. Any other name throws paddle Unimplemented.
void AESCipher::BuildAuthEncCipher(
    bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedEncryptionFilter>* m_filter) {
  if (aes_cipher_name_ != "AES_GCM_NoPadding") {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }

  m_cipher->reset(new CryptoPP::GCM<CryptoPP::AES>::Encryption);
  *need_iv = true;
  m_filter->reset(new CryptoPP::AuthenticatedEncryptionFilter(
      **m_cipher, nullptr, /*putAAD=*/false, tag_size_ / 8,
      CryptoPP::DEFAULT_CHANNEL, CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
}
// Creates the authenticated decryption object for aes_cipher_name_ together
// with its AuthenticatedDecryptionFilter (DEFAULT_FLAGS, tag length
// tag_size_ / 8 bytes). Only "AES_GCM_NoPadding" is supported, and it sets
// *need_iv to true. Any other name throws paddle Unimplemented.
void AESCipher::BuildAuthDecCipher(
    bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedDecryptionFilter>* m_filter) {
  if (aes_cipher_name_ != "AES_GCM_NoPadding") {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }

  m_cipher->reset(new CryptoPP::GCM<CryptoPP::AES>::Decryption);
  *need_iv = true;
  m_filter->reset(new CryptoPP::AuthenticatedDecryptionFilter(
      **m_cipher, nullptr,
      CryptoPP::AuthenticatedDecryptionFilter::DEFAULT_FLAGS, tag_size_ / 8,
      CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
}
// Encrypts `plaintext` with `key`, using the authenticated path when Init
// selected an authenticated mode and the plain block/stream path otherwise.
std::string AESCipher::Encrypt(const std::string& plaintext,
                               const std::string& key) {
  if (is_authenticated_cipher_) {
    return AuthenticatedEncryptInternal(plaintext, key);
  }
  return EncryptInternal(plaintext, key);
}
// Decrypts `ciphertext` with `key`, mirroring Encrypt: the authenticated
// path is used when Init selected an authenticated mode.
std::string AESCipher::Decrypt(const std::string& ciphertext,
                               const std::string& key) {
  if (!is_authenticated_cipher_) {
    return DecryptInternal(ciphertext, key);
  }
  return AuthenticatedDecryptInternal(ciphertext, key);
}
void AESCipher::EncryptToFile(const std::string& plaintext,
const std::string& key,
const std::string& filename) {
std::ofstream fout(filename, std::ios::binary);
std::string ciphertext = this->Encrypt(plaintext, key);
fout.write(ciphertext.data(), ciphertext.size());
fout.close();
}
std::string AESCipher::DecryptFromFile(const std::string& key,
const std::string& filename) {
std::ifstream fin(filename, std::ios::binary);
std::string ciphertext{std::istreambuf_iterator<char>(fin),
std::istreambuf_iterator<char>()};
fin.close();
return Decrypt(ciphertext, key);
}
} // namespace framework
} // namespace paddle