module Temporal.Coroutine where

import Control.Monad
import qualified Control.Monad.Catch as Catch
import Control.Monad.Trans


-- | Suspension functor for a coroutine that is waiting for a value of type
-- @x@. The wrapped function receives that value and produces @y@ (in 'await'
-- this is the coroutine to continue with).
newtype Await x y = Await (x -> y)


-- | Post-composes the mapped function onto the wrapped function.
instance Functor (Await x) where
  fmap f (Await k) = Await (\x -> f (k x))


-- | Suspending, resumable monadic computations.
newtype Coroutine s m r = Coroutine
  { forall (s :: * -> *) (m :: * -> *) r.
Coroutine s m r -> m (Either (s (Coroutine s m r)) r)
resume :: m (Either (s (Coroutine s m r)) r)
  -- ^ Run the next step of a `Coroutine` computation. The result of the step execution will be either a suspension or
  -- the final coroutine result.
  }


-- | Maps over the final result of a coroutine.
--
-- A finished step has the function applied directly. A suspended step has the
-- mapping pushed into every continuation held by the suspension functor, so it
-- applies whenever the coroutine eventually finishes.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
  fmap f = Coroutine . fmap step . resume
    where
      step = either (Left . fmap (fmap f)) (Right . f)


-- | 'pure' finishes immediately without suspending. '<*>' is derived from the
-- 'Monad' instance. '*>' runs the first coroutine (suspending whenever it
-- suspends), discards its result, and then continues with the second one.
instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
  pure = Coroutine . return . Right
  (<*>) = ap
  t *> next = Coroutine (resume t >>= either onSuspend (const (resume next)))
    where
      -- Still suspended: attach the follow-up coroutine to every continuation.
      onSuspend suspended = return (Left (fmap (>> next) suspended))


-- | Sequencing of coroutine steps.
--
-- '>>=' resumes the first coroutine. If it finished, the continuation's first
-- step runs right away. If it suspended, the bind is re-attached to every
-- continuation inside the suspension. 'return' is left to its class default
-- ('pure').
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  t >>= k = Coroutine (resume t >>= either onSuspend (resume . k))
    where
      onSuspend suspended = return (Left (fmap (>>= k) suspended))
  -- Kept explicit: the class default for '>>' is defined via '>>=', not '*>'.
  (>>) = (*>)


-- | Fails in the base monad during the coroutine's first step.
instance (Functor s, MonadFail m) => MonadFail (Coroutine s m) where
  fail = Coroutine . fmap Right . fail


-- | Lifts a base-monad action into a single, non-suspending step whose result
-- is the action's result.
instance Functor s => MonadTrans (Coroutine s) where
  lift action = Coroutine (Right <$> action)


-- | Runs the 'IO' action through the base monad's 'liftIO' as one
-- non-suspending step.
instance (Functor s, MonadIO m) => MonadIO (Coroutine s m) where
  liftIO io = Coroutine (Right <$> liftIO io)


-- | Throws in the base monad during the coroutine's first step.
instance (Functor s, Catch.MonadThrow m) => Catch.MonadThrow (Coroutine s m) where
  throwM e = Coroutine (Right <$> Catch.throwM e)


-- | Catches exceptions thrown by the protected coroutine in any of its steps.
--
-- The first step runs under the base monad's 'Catch.catch'. Every continuation
-- of the protected computation is re-wrapped so the handler stays installed
-- across suspensions. The handler's own coroutine is /not/ re-wrapped, so
-- exceptions it throws propagate, matching 'Catch.catch' in the base monad.
-- (Previously a suspended handler was re-wrapped with itself, so it could
-- catch its own exceptions after resuming.)
instance (Functor s, Catch.MonadCatch m) => Catch.MonadCatch (Coroutine s m) where
  catch (Coroutine t) h = Coroutine (Catch.catch (fmap protect t) (resume . h))
    where
      -- Applied only to the protected step's result, never to the handler's.
      protect (Left suspended) = Left (fmap (`Catch.catch` h) suspended)
      protect (Right x) = Right x


-- | A coroutine whose first step immediately suspends with the given
-- suspension value (which holds the continuation).
suspend :: (Monad m) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Left
{-# INLINE suspend #-}


-- | Suspend the current coroutine until a value is provided. The supplied
-- value becomes the result of 'await'.
await :: Monad m => Coroutine (Await x) m x
await = suspend (Await pure)


-- | Runs a 'Coroutine' to completion in the base monad.
--
-- Whenever the coroutine suspends, @runStep@ turns the suspension into the
-- next coroutine to run. When a step finishes, its result is returned.
supply :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
supply runStep = go
  where
    go c = do
      step <- resume c
      case step of
        Left suspended -> go (runStep suspended)
        Right result -> pure result


-- | Runs a suspendable 'Coroutine' to its completion with a monadic action.
--
-- Whenever the coroutine suspends, @runStep@ is run in the base monad to
-- produce the next coroutine. When a step finishes, its result is returned.
supplyM :: Monad m => (s (Coroutine s m x) -> m (Coroutine s m x)) -> Coroutine s m x -> m x
supplyM runStep = go
  where
    go c = do
      step <- resume c
      case step of
        Left suspended -> runStep suspended >>= go
        Right result -> pure result


-- | Changes the base monad of a coroutine with a natural transformation.
--
-- The transformation is applied to the current step. Every continuation held
-- by a suspension is hoisted recursively, so all later steps run in @n@ too.
-- Only @Functor f@ (to reach the continuations) and @Functor n@ (to rewrap the
-- step) are required. The former @Monad m@ / @Monad n@ constraints were unused.
coroutineHoist
  :: forall f m n a
   . (Functor f, Functor n)
  => (forall b. m b -> n b)
  -> Coroutine f m a
  -> Coroutine f n a
coroutineHoist nat routine = Coroutine (fmap go (nat (resume routine)))
  where
    -- Finished steps pass through; suspended continuations are hoisted lazily.
    go (Right r) = Right r
    go (Left s) = Left (coroutineHoist nat <$> s)