{-# LANGUAGE DerivingStrategies #-}

module Temporal.Codec.Encryption (
  Key,
  keyFromBytes,
  Cipher,
  keyToBase64,
  keyFromBase64,
  genSecretKey,
  initCipher,
  Encrypted,
  mkEncryptionCodec,
) where

import Control.Error
import Control.Exception (displayException)
import Control.Monad.IO.Class
import Crypto.Cipher.AES (AES256)
import qualified Crypto.Cipher.AESGCMSIV as AESGCMSIV
import Crypto.Cipher.Types (AuthTag (..), KeySizeSpecifier (KeySizeFixed), cipherInit, cipherKeySize)
import Crypto.Error (CryptoFailable (..))
import qualified Crypto.Random.Types as CRT
import Data.ByteArray (ScrubbedBytes, convert)
import Data.ByteArray.Encoding
import Data.ByteString (ByteString)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.ProtoLens.Encoding (decodeMessage, encodeMessage)
import Temporal.Payload


-- | A raw (not Base64-encoded) AES-256 encryption key.
--
-- The bytes are held in a 'ScrubbedBytes' buffer. The constructor is not
-- exported from this module; build values with 'keyFromBytes',
-- 'keyFromBase64' or 'genSecretKey'.
--
-- 'Show' is derived through the underlying 'ScrubbedBytes' instance (not the
-- stock one), which per the @memory@ package documentation does not print the
-- buffer's contents.
newtype Key = Key ScrubbedBytes
  deriving newtype (Eq, Ord)
  deriving newtype (Show)


-- | An initialised AES-256 cipher context, produced from a 'Key' by
-- 'initCipher'. Only the type is exported, so callers cannot wrap an
-- arbitrary 'AES256' value themselves.
newtype Cipher = Cipher AES256


-- | Wrap raw key bytes as a 'Key'.
--
-- The length is not checked here; 'initCipher' reports a failure if
-- cryptonite rejects the bytes.
keyFromBytes :: ScrubbedBytes -> Key
keyFromBytes raw = Key raw


-- | Render a 'Key' as Base64; the counterpart of 'keyFromBase64'.
keyToBase64 :: Key -> ByteString
keyToBase64 (Key raw) = convertToBase Base64 raw


-- | Decode a Base64 'ByteString' (as produced by 'keyToBase64') into a 'Key'.
--
-- Returns the decoder's error message on malformed input. The decoded key
-- size is not validated here; that happens when the key reaches 'initCipher'.
keyFromBase64 :: ByteString -> Either String Key
keyFromBase64 encoded = Key <$> convertFromBase Base64 encoded


-- | Generate a fresh random AES-256 'Key'.
--
-- The number of bytes drawn equals the fixed key size that cryptonite's
-- 'cipherKeySize' reports for 'AES256'; randomness comes from the caller's
-- 'CRT.MonadRandom' instance.
genSecretKey :: forall m. (CRT.MonadRandom m) => m Key
genSecretKey = Key <$> CRT.getRandomBytes keyLength
  where
    -- 'undefined' only selects the 'AES256' instance; 'cipherKeySize' does
    -- not inspect its argument. The non-fixed branch is not expected to be
    -- reachable for AES256.
    keyLength :: Int
    keyLength =
      case cipherKeySize (undefined :: AES256) of
        KeySizeFixed n -> n
        _ -> error "AES256 key size should fixed"


-- | Build an AES-256 cipher context from a raw 'Key'.
--
-- Returns @Left@ with the 'show'n 'CryptoError' when cryptonite's
-- 'cipherInit' rejects the key bytes.
initCipher :: Key -> Either String Cipher
initCipher (Key raw) =
  case cipherInit raw of
    CryptoPassed ctx -> Right (Cipher ctx)
    CryptoFailed err -> Left (show err)


-- | Configuration of the AES-GCM-SIV payload encryption codec.
--
-- Only the type is exported; build values with 'mkEncryptionCodec'.
-- Fields are strict so the codec value never holds unevaluated thunks.
data Encrypted = Encrypted
  { encryptionKeys :: !(Map ByteString Cipher)
  -- ^ All ciphers available for decryption, indexed by key id.
  , defaultKeyName :: !ByteString
  -- ^ Key id recorded in the metadata of newly encrypted payloads.
  , defaultKey :: !Cipher
  -- ^ Cipher used to encrypt new payloads.

  -- TODO, fetchKey operation to support a KMS (key management system)
  }


-- | Build an 'Encrypted' codec.
--
-- The first argument is the @(key id, cipher)@ pair used to encrypt new
-- payloads. The second maps further key ids to ciphers that stay available
-- for decryption. The default pair is inserted into that map, replacing any
-- existing entry with the same id.
--
-- This currently performs no IO and always returns @Right@. The 'MonadIO'
-- context and 'Either' result are kept for API compatibility (presumably to
-- allow fetching keys from a KMS later -- confirm with the TODO on
-- 'Encrypted').
mkEncryptionCodec :: MonadIO m => (ByteString, Cipher) -> Map ByteString Cipher -> m (Either String Encrypted)
mkEncryptionCodec (keyId, cipher) otherKeys =
  pure . Right $
    Encrypted
      { encryptionKeys = Map.insert keyId cipher otherKeys
      , defaultKeyName = keyId
      , defaultKey = cipher
      }


-- | Payload codec that encrypts with AES-GCM-SIV.
--
-- 'encode' converts the incoming 'Payload' with 'convertToProtoPayload',
-- serialises it with 'encodeMessage', and encrypts it under the default key.
-- It uses a freshly generated nonce and empty associated data. The result is
-- a new payload whose data is the ciphertext and whose metadata holds:
--
--   * @encryption-key-id@ -- id of the key that was used
--   * @encoding@          -- @binary/encrypted@
--   * @cipher@            -- @AESGCMSIV@
--   * @nonce@             -- the nonce bytes
--   * @auth-tag@          -- the authentication tag bytes
--
-- Payloads with empty data are returned unencrypted.
--
-- 'decode' reverses this for payloads whose metadata says
-- @binary/encrypted@ + @AESGCMSIV@. It looks the cipher up by key id, and
-- returns @Left@ with a description if anything is missing or fails to
-- verify. Every other payload is returned unchanged.
instance Codec Encrypted Payload where
  encoding _ _ = "binary/encrypted"

  encode Encrypted {defaultKeyName = keyId, defaultKey = Cipher k} x
    -- Nothing to protect in an empty payload; leave it untouched.
    | x.payloadData == "" = pure x
    | otherwise = do
        n <- AESGCMSIV.generateNonce
        let (authTag, ciphertext) =
              AESGCMSIV.encrypt k n (mempty :: ByteString) $
                encodeMessage (convertToProtoPayload x)
        pure $
          Payload
            ciphertext
            ( Map.fromList
                [ ("encryption-key-id", keyId)
                , ("encoding", "binary/encrypted")
                , ("cipher", "AESGCMSIV")
                , ("nonce", convert n)
                , ("auth-tag", convert authTag)
                ]
            )

  decode Encrypted {encryptionKeys = keys} payload =
    case (meta Map.!? "encoding", meta Map.!? "cipher") of
      (Just "binary/encrypted", Just "AESGCMSIV")
        | payload.payloadData == "" -> pure $ Left "Payload data is missing"
        | otherwise -> runExceptT decryptPayload
      -- Not tagged as produced by this codec: pass it through unchanged.
      _ -> pure $ Right payload
    where
      meta = payload.payloadMetadata

      -- Short-circuits with the first missing field or failed step.
      decryptPayload = do
        keyName <-
          tryJust "Unable to decrypt Payload without encryption key id" $
            meta Map.!? "encryption-key-id"
        Cipher k <-
          tryJust ("Could not find encryption key: " <> show keyName) $
            keys Map.!? keyName
        rawNonce <-
          tryJust "Unable to decrypt Payload without nonce" $
            meta Map.!? "nonce"
        n <- case AESGCMSIV.nonce rawNonce of
          CryptoPassed parsed -> pure parsed
          CryptoFailed e -> throwE (displayException e)
        authTag <-
          tryJust "Unable to decrypt Payload without auth tag" $
            meta Map.!? "auth-tag"
        -- 'AESGCMSIV.decrypt' yields Nothing when authentication fails.
        decrypted <-
          tryJust "Unable to decrypt Payload" $
            AESGCMSIV.decrypt k n (mempty :: ByteString) payload.payloadData (AuthTag (convert authTag))
        p <- tryRight (decodeMessage decrypted)
        pure $! convertFromProtoPayload p