{-# LANGUAGE CPP #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}

module Temporal.Workflow.Saga (
  SagaT,
  runSaga,
  compensated,
) where

import Control.Monad


#if MIN_VERSION_mtl(2,3,0)
import Control.Monad.Accum
import Control.Monad.Select
#endif
import Control.Applicative
import Control.Monad.Catch (MonadCatch)
import qualified Control.Monad.Catch as Catch
import Control.Monad.Cont
import Control.Monad.Error.Class (MonadError)
import Control.Monad.Fix (MonadFix)
import Control.Monad.Logger (MonadLogger, logError)
import Control.Monad.RWS (MonadReader)
import Control.Monad.State
import Control.Monad.Writer
import Data.Functor.Contravariant
import qualified Data.Text as T


{- | The saga pattern is a failure management pattern that helps establish consistency in distributed applications,
and coordinates transactions between multiple services to attempt to maintain data consistency.

If you’re wondering if the saga pattern is right for your scenario, ask yourself:

Does your logic involve multiple steps, some of which span machines, services, shards, or databases, for which partial execution is undesirable?

If so, this is exactly where sagas are useful. Maybe you are checking inventory, charging a user’s credit card, and then fulfilling the order.
Maybe you are managing a supply chain.

The saga pattern is helpful because it basically functions as a state machine storing program progress,
preventing multiple credit card charges, reverting if necessary, and knowing exactly how to safely resume
in a consistent state in the event of power loss.

There are many “do it all, or don’t bother” software applications in the real world:

- If you successfully charge the user for an item but your fulfillment service reports that the item is out of stock, you’re going to have upset
  users if you don’t refund the charge. If you have the opposite problem and accidentally deliver items “for free,” you’ll be out of business.
- If the machine coordinating a machine learning data processing pipeline crashes but the follower machines carry on processing the data with
  nowhere to report their data to, you may have a very expensive compute resources bill on your hands.

In all of these cases, having some sort of “progress tracking” and compensation code to deal with these
“do-it-all-or-don’t-do-any-of-it” tasks is exactly what the saga pattern provides. In saga parlance, these sorts
of “all or nothing” tasks are called long-running transactions. This doesn’t necessarily mean such actions run
for a “long” time, just that they require more steps in logical time than something running locally interacting
with a single database.

A saga is composed of two parts:

1. Defined behavior for “going backwards” if you need to “undo” something (i.e., compensations)
2. Behavior for striving towards forward progress (i.e., saving state to know where to recover
   from in the face of failure). This second part is often called the “orchestration logic” of the saga.
   For Temporal, execution of the orchestration logic is handled by the Temporal server, and the saga
   monad transformer here simply needs to handle the compensation logic.

Lastly, note that compensation actions are still subject to the same restrictions as any other workflow code.
This means that compensation actions run within the Workflow monad can be timed out, and they can be retried.
Make sure in this case that whatever retry / timeout policies you have in place for your workflow are appropriate
for the compensation actions you are running.

For more information on the saga pattern, see the following resources:

- [Saga Pattern Made Easy](https://temporal.io/blog/saga-pattern-made-easy-with-temporal/)
- [The Saga Pattern in Distributed Systems](https://www.cs.cornell.edu/andru/cs711/2002fa/reading/sagas.pdf)
-}
newtype SagaT m a = SagaT
  { unSagaT :: StateT [m ()] m a
  -- ^ The saga's private state is the list of compensation actions
  -- registered so far, most recently registered first.
  }
  deriving newtype (Functor, Applicative, Monad)
  deriving newtype (MonadIO)


-- | Run a saga in the underlying monad.
--
-- The saga starts with an empty list of compensations. If it throws,
-- @compensate@ is invoked through 'Catch.onException' and the original
-- exception is then re-thrown. @compensationExceptionHandler@ receives any
-- exception thrown by an individual compensation action. When the saga
-- succeeds, the registered compensations are discarded without being run.
--
-- NOTE(review): the 'Catch.MonadCatch' instance of 'SagaT' is derived from
-- 'StateT', whose @catch@ (transformers' @liftCatch@) runs the handler from
-- the state the protected action started with, which here is always @[]@.
-- If so, @compensate@ sees no registered compensations and nothing is rolled
-- back when the saga fails. Fixing this needs state that survives exceptions
-- (e.g. a mutable reference), which would change this function's
-- constraints. Confirm with a test before relying on compensation.
runSaga :: (MonadCatch m, MonadLogger m) => (Catch.SomeException -> m ()) -> SagaT m a -> m a
runSaga compensationExceptionHandler saga =
  evalStateT (unSagaT guarded) []
  where
    -- On any exception: run the compensations, then re-throw.
    guarded = saga `Catch.onException` compensate compensationExceptionHandler


-- | Lifting runs the action in the underlying monad and leaves the saga's
-- compensation list untouched.
instance MonadTrans SagaT where
  lift action = SagaT (lift action)


-- | State operations are forwarded to the underlying monad @m@. The saga's
-- own compensation list is deliberately not exposed through 'MonadState'.
instance MonadState s m => MonadState s (SagaT m) where
  get = SagaT (lift get)
  put newState = SagaT (lift (put newState))


-- The instances below are all derived through the underlying StateT
-- representation, so each one lifts the matching capability of m.
--
-- NOTE(review): StateT's catch / catchError resume from the state at the
-- start of the protected action, so compensations registered inside a block
-- that is caught and recovered from are dropped. Confirm this is intended.

-- mtl effect classes
deriving newtype instance MonadError e m => MonadError e (SagaT m)
deriving newtype instance MonadReader r m => MonadReader r (SagaT m)
deriving newtype instance MonadWriter w m => MonadWriter w (SagaT m)
deriving newtype instance MonadCont m => MonadCont (SagaT m)
#if MIN_VERSION_mtl(2,3,0)
deriving newtype instance MonadAccum w m => MonadAccum w (SagaT m)
deriving newtype instance MonadSelect w m => MonadSelect w (SagaT m)
#endif

-- base classes
deriving newtype instance MonadFail m => MonadFail (SagaT m)
deriving newtype instance MonadFix m => MonadFix (SagaT m)
deriving newtype instance Contravariant m => Contravariant (SagaT m)
deriving newtype instance MonadPlus m => Alternative (SagaT m)
deriving newtype instance MonadPlus m => MonadPlus (SagaT m)

-- logging and exceptions
deriving newtype instance MonadLogger m => MonadLogger (SagaT m)
deriving newtype instance Catch.MonadThrow m => Catch.MonadThrow (SagaT m)
deriving newtype instance Catch.MonadCatch m => Catch.MonadCatch (SagaT m)


-- | Run every compensation registered so far and clear the registry.
--
-- Compensations are taken from the head of the stored list, so the most
-- recently registered one runs first. Each one is run under 'Catch.try'.
-- When one throws, the exception is handed to @compensationExceptionHandler@.
-- If that handler itself throws, the failure is logged at error level and
-- dropped. Either way, the remaining compensations still run.
compensate :: (MonadCatch m, MonadLogger m) => (Catch.SomeException -> m ()) -> SagaT m ()
compensate compensationExceptionHandler = SagaT $ do
  -- Take the registered compensations and leave an empty list behind.
  pending <- state (\registered -> (registered, []))
  lift (mapM_ runOne pending)
  where
    -- Run one compensation, routing any failure to onFailure.
    runOne compensation =
      Catch.try compensation >>= either onFailure (\() -> pure ())

    -- Give the failure to the user's handler; a failing handler is only logged.
    onFailure (failure :: Catch.SomeException) =
      compensationExceptionHandler failure `Catch.catchAll` \handlerFailure ->
        $(logError) $
          T.pack $
            "Saga compensation error handler threw exception: " <> show handlerFailure


-- | Push a compensation onto the front of the saga's list, so that it is run
-- before every compensation that was registered earlier.
addCompensation :: Monad m => m () -> SagaT m ()
addCompensation action = SagaT (modify (\registered -> action : registered))


-- | Run one step of a saga and, once it has completed successfully, register
-- the action that undoes it.
--
-- If the step itself throws, its compensation is never registered, so it is
-- not run. Only compensations of earlier, successful steps are candidates for
-- rollback.
--
-- NOTE(review): rollback relies on 'runSaga' observing the registered list
-- when the saga throws. Because 'SagaT''s exception handling is derived from
-- 'StateT', that list may be reset on failure. Verify that compensations
-- actually run.
compensated
  :: (Monad m)
  => m ()
  -- ^ The compensation action that undoes this step.
  --
  -- It is registered only after the step returns successfully, and is
  -- intended to run when a later part of the saga fails.
  --
  -- Saga compensation actions are run in reverse order of their
  -- registration.
  --
  -- Exceptions thrown by the compensation action are caught and passed to
  -- the handler given to 'runSaga'; the remaining compensations still run.
  -> m a
  -- ^ The saga step to run.
  -> SagaT m a
compensated compensation step =
  -- Run the step first, so a failing step registers nothing.
  lift step <* addCompensation compensation