{-# LANGUAGE DerivingStrategies #-}
module Temporal.Codec.Encryption (
Key,
keyFromBytes,
Cipher,
keyToBase64,
keyFromBase64,
genSecretKey,
initCipher,
Encrypted,
mkEncryptionCodec,
) where
import Control.Error
import Control.Exception (displayException)
import Control.Monad.IO.Class
import Crypto.Cipher.AES (AES256)
import qualified Crypto.Cipher.AESGCMSIV as AESGCMSIV
import Crypto.Cipher.Types (AuthTag (..), KeySizeSpecifier (KeySizeFixed), cipherInit, cipherKeySize)
import Crypto.Error (CryptoFailable (..))
import qualified Crypto.Random.Types as CRT
import Data.ByteArray (ScrubbedBytes, convert)
import Data.ByteArray.Encoding
import Data.ByteString (ByteString)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.ProtoLens.Encoding (decodeMessage, encodeMessage)
import Temporal.Payload
-- | Secret key material for AES-256, kept in 'ScrubbedBytes'.
--
-- The constructor is not exported, so a 'Key' can only be obtained through
-- 'keyFromBytes', 'keyFromBase64' or 'genSecretKey'. No length check happens
-- when a 'Key' is built; the bytes are only validated when 'initCipher' hands
-- them to AES-256.
--
-- 'Show' delegates to the 'ScrubbedBytes' instance instead of being
-- stock-derived. NOTE(review): that instance is expected not to print the
-- key bytes; confirm against the @memory@ package before relying on it for
-- logging.
newtype Key = Key ScrubbedBytes
  deriving newtype Eq
  deriving newtype Ord
  deriving newtype Show
-- | An initialised AES-256 cipher context, produced by 'initCipher'.
--
-- The constructor is not exported, so callers can only obtain a 'Cipher'
-- from a 'Key' that AES-256 accepted.
newtype Cipher = Cipher AES256
-- | Wrap raw bytes as a 'Key'.
--
-- No validation is performed here. A key of the wrong length is presumably
-- rejected later, when 'initCipher' passes it to 'cipherInit'.
keyFromBytes :: ScrubbedBytes -> Key
keyFromBytes rawBytes = Key rawBytes
-- | Encode a key's raw bytes using the 'Base64' alphabet from
-- "Data.ByteArray.Encoding".
--
-- The result is an ordinary 'ByteString' and is no longer scrubbed, so treat
-- it as secret.
keyToBase64 :: Key -> ByteString
keyToBase64 key =
  case key of
    Key secret -> convertToBase Base64 secret
-- | Decode a Base64 string (the format produced by 'keyToBase64') into a
-- 'Key'.
--
-- The bytes are decoded straight into 'ScrubbedBytes'. Returns 'Left' with
-- the decoder's message when the input is not valid Base64. The decoded
-- length is not checked here.
keyFromBase64 :: ByteString -> Either String Key
keyFromBase64 = fmap Key . convertFromBase Base64
-- | Generate a fresh random 'Key' sized for AES-256.
--
-- The length is taken from 'cipherKeySize' for 'AES256' rather than being
-- hard-coded. The 'error' branch can only fire if AES-256 stopped reporting
-- a fixed key size.
genSecretKey :: forall m. (CRT.MonadRandom m) => m Key
genSecretKey = Key <$> CRT.getRandomBytes keyLength
  where
    -- Relies on 'cipherKeySize' not forcing its argument, so the
    -- 'undefined' only serves to select the AES256 instance.
    keyLength :: Int
    keyLength =
      case cipherKeySize (undefined :: AES256) of
        KeySizeFixed n -> n
        _ -> error "AES256 key size should fixed"
-- | Initialise AES-256 from a 'Key'.
--
-- Returns 'Left' containing the 'show'n 'CryptoError' when 'cipherInit'
-- rejects the key material (presumably, for example, a key that is not
-- 32 bytes long).
initCipher :: Key -> Either String Cipher
initCipher (Key secret) =
  case cipherInit secret of
    CryptoPassed aes -> Right (Cipher aes)
    CryptoFailed err -> Left (show err)
-- | State of the encryption codec, built by 'mkEncryptionCodec'.
data Encrypted = Encrypted
  { encryptionKeys :: Map ByteString Cipher
  -- ^ Ciphers available for decryption, keyed by the id found in a payload's
  -- @encryption-key-id@ metadata. 'mkEncryptionCodec' always inserts the
  -- default key here.
  , defaultKeyName :: ByteString
  -- ^ Id written to @encryption-key-id@ when encoding.
  , defaultKey :: Cipher
  -- ^ Cipher used to encrypt every newly encoded payload.
  }
-- | Build an encryption codec from a named default cipher and any number of
-- additional named ciphers.
--
-- The default cipher encrypts every new payload, and its name is recorded in
-- the payload metadata. All ciphers, including the default, are placed in
-- 'encryptionKeys' for lookup during decoding. If @otherKeys@ already has an
-- entry under the default key's name, the default cipher replaces it.
--
-- The current implementation never returns 'Left'.
mkEncryptionCodec
  :: MonadIO m
  => (ByteString, Cipher)
  -> Map ByteString Cipher
  -> m (Either String Encrypted)
mkEncryptionCodec (keyName, cipher) otherKeys =
  liftIO . pure . Right $
    Encrypted
      { encryptionKeys = Map.insert keyName cipher otherKeys
      , defaultKeyName = keyName
      , defaultKey = cipher
      }
-- | Payload codec that encrypts with AES-256-GCM-SIV.
--
-- Encoding: the whole original payload (data and metadata) is serialised to
-- its protobuf form. It is then encrypted under 'defaultKey' with a freshly
-- generated nonce and empty associated data. The encrypted payload carries
-- only the metadata needed to decrypt it: @encryption-key-id@, @encoding@
-- (@binary/encrypted@), @cipher@ (@AESGCMSIV@), @nonce@ and @auth-tag@.
-- Payloads with empty data are returned unchanged.
--
-- Decoding: payloads whose @encoding@ is not @binary/encrypted@ pass through
-- untouched. A payload marked @binary/encrypted@ must name the @AESGCMSIV@
-- cipher. A missing or unknown cipher is reported as 'Left' rather than
-- returning the ciphertext as though it had been decoded.
instance Codec Encrypted Payload where
  encoding _ _ = "binary/encrypted"

  encode Encrypted {..} x
    | x.payloadData == "" = pure x
    | otherwise = do
        n <- AESGCMSIV.generateNonce
        let Cipher k = defaultKey
            -- Encrypting the full proto payload keeps the caller's metadata
            -- confidential along with the data.
            plaintext = encodeMessage (convertToProtoPayload x)
            (authTag, ciphertext) =
              AESGCMSIV.encrypt k n (mempty :: ByteString) plaintext
        pure $
          Payload
            ciphertext
            ( Map.fromList
                [ ("encryption-key-id", defaultKeyName)
                , ("encoding", "binary/encrypted")
                , ("cipher", "AESGCMSIV")
                , ("nonce", convert n)
                , ("auth-tag", convert authTag)
                ]
            )

  decode Encrypted {..} payload =
    case meta Map.!? "encoding" of
      Just "binary/encrypted" ->
        case meta Map.!? "cipher" of
          Just "AESGCMSIV" -> decryptPayload
          -- Previously these two cases fell through and the still-encrypted
          -- payload was returned as a successful decode.
          Just other -> pure $ Left $ "Unsupported cipher: " <> show other
          Nothing -> pure $ Left "Unable to decrypt Payload without cipher"
      _ -> pure $ Right payload
    where
      meta = payload.payloadMetadata

      -- Decrypt an AESGCMSIV payload and restore the original proto payload.
      -- 'encode' never encrypts empty data, so an encrypted payload with
      -- empty data is malformed.
      decryptPayload
        | payload.payloadData == "" = pure $ Left "Payload data is missing"
        | otherwise = runExceptT $ do
            keyName <-
              tryJust "Unable to decrypt Payload without encryption key id" $
                meta Map.!? "encryption-key-id"
            Cipher k <-
              tryJust ("Could not find encryption key: " <> show keyName) $
                encryptionKeys Map.!? keyName
            rawNonce <-
              tryJust "Unable to decrypt Payload without nonce" $
                meta Map.!? "nonce"
            n <- case AESGCMSIV.nonce rawNonce of
              CryptoPassed v -> pure v
              CryptoFailed e -> throwE $ displayException e
            authTag <-
              tryJust "Unable to decrypt Payload without auth tag" $
                meta Map.!? "auth-tag"
            -- 'decrypt' yields Nothing when the auth tag does not verify
            -- (wrong key, tampered data or tampered tag).
            decrypted <-
              tryJust "Unable to decrypt Payload" $
                AESGCMSIV.decrypt
                  k
                  n
                  (mempty :: ByteString)
                  payload.payloadData
                  (AuthTag $ convert authTag)
            p <- tryRight $ decodeMessage decrypted
            pure $! convertFromProtoPayload p