module Control.Monad.Trans.Fault
(
fault
, faulty
, MonadFault (..)
, MonadFaults
, FaultlessT (..)
, runFaultlessT
, FaultyT (..)
, runFaultyT
, newFaultController
, setFault
, resetFault
, printFaultController
, showFaultController
, askFaultController
, FaultController (..)
, FaultConfig (..)
, NewFault ()
, HasFault (..)
) where
import Control.Exception (Exception (toException), SomeException (..), throwIO)
import Control.Monad ((>=>))
import Control.Monad.Base (MonadBase (..))
import Control.Monad.Catch (MonadCatch, MonadThrow)
import Control.Monad.Except (MonadError)
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Logger (MonadLogger (..))
import Control.Monad.Reader (MonadReader (..), ReaderT (..))
import Control.Monad.State (MonadState)
import Control.Monad.Trans (MonadTrans (..))
import Control.Monad.Trans.Control
  ( ComposeSt
  , MonadBaseControl (..)
  , MonadTransControl (..)
  , defaultLiftBaseWith
  , defaultLiftWith
  , defaultRestoreM
  , defaultRestoreT
  )
import Control.Monad.Trans.Identity (IdentityT (..))
import Control.Monad.Trans.Resource (MonadResource (..))
import Data.IORef
import Data.Kind (Constraint)
import Data.Proxy
import GHC.TypeLits
-- | Monads in which the fault point named by the type-level string @fault@
-- can be triggered.
class Monad m => MonadFault (fault :: Symbol) m where
  -- | Trigger the fault point selected by the proxy's type-level name.
  -- What triggering actually does is decided by each instance.
  faultPrx :: Proxy fault -> m ()
-- | Trigger the fault point @fault@. The name is chosen with a visible type
-- application and handed to 'faultPrx' as a 'Proxy'.
fault :: forall (fault :: Symbol) m. MonadFault fault m => m ()
fault = faultPrx selected
  where
    -- Pins the otherwise ambiguous @fault@ for instance selection.
    selected :: Proxy fault
    selected = Proxy
-- | Constraint-level fold over a list of fault names: yields one
-- 'MonadFault' constraint on @m@ per listed name, and the unit constraint
-- for the empty list.
type family MonadFaults (faults :: [Symbol]) (m :: * -> *) :: Constraint where
  MonadFaults '[] m = ()
  MonadFaults (name ': names) m = (MonadFault name m, MonadFaults names m)
-- | Lift 'MonadFault' through an arbitrary monad transformer by running the
-- underlying monad's fault point with 'lift'.
--
-- The head @t m@ also matches the dedicated 'FaultlessT' and 'FaultyT'
-- instances in this module, so this catch-all is marked @OVERLAPPABLE@:
-- the more specific instances win wherever they apply, and this one is
-- used for every other transformer.
instance {-# OVERLAPPABLE #-} (Monad (t m), MonadTrans t, MonadFault fault m) => MonadFault fault (t m) where
  faultPrx _ = lift (fault @fault)
-- | Run the fault point @fault@ first and then the given action, returning
-- the action's result.
faulty :: forall fault m a. MonadFault fault m => m a -> m a
faulty action = fault @fault *> action
-- | Transformer that wraps 'IdentityT'. Its 'MonadFault' instance in this
-- module makes every fault point a no-op.
newtype FaultlessT m a = FaultlessT
  { unFaultlessT :: IdentityT m a
    -- ^ Unwrap to the underlying 'IdentityT' computation.
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadLogger
    , MonadError e
    , MonadState s
    , MonadReader r
    , MonadCatch
    , MonadThrow
    )
-- | Strip the 'FaultlessT' and 'IdentityT' wrappers, returning the
-- computation in the base monad.
runFaultlessT :: FaultlessT m a -> m a
runFaultlessT (FaultlessT inner) = runIdentityT inner
-- | In 'FaultlessT' every fault point, whatever its name, does nothing.
instance Monad m => MonadFault fault (FaultlessT m) where
  faultPrx = const (pure ())
-- | State of a single fault point: 'Nothing' when no exception is stored,
-- or 'Just' the exception held for that point.
data FaultConfig
  = FaultConfig (Maybe SomeException)
  deriving Show
-- | Heterogeneous list with one mutable 'FaultConfig' cell per fault name;
-- the type-level list @faults@ records the names in order.
data FaultController (faults :: [Symbol]) where
  -- | Controller for the empty list of faults.
  FCNil :: FaultController '[]
  -- | Prepend the cell for fault @f@ (its 'IORef' field is strict) to the
  -- controller for the remaining faults.
  FCCons
    :: forall f rest
     . KnownSymbol f
    => Proxy f
    -> !(IORef FaultConfig)
    -> FaultController rest
    -> FaultController (f ': rest)
-- | Fault lists for which a fresh 'FaultController' can be allocated.
class NewFault faults where
  -- | Allocate a controller covering exactly the names in @faults@.
  newFaultController :: IO (FaultController faults)
-- | Base case: the empty fault list needs no cells.
instance NewFault '[] where
  newFaultController = return FCNil
-- | Inductive case: allocate a cell holding @FaultConfig Nothing@ for the
-- head name, then build the controller for the tail.
instance (KnownSymbol f, NewFault rest) => NewFault (f ': rest) where
  newFaultController =
    -- Applicative sequencing keeps the original effect order: head cell
    -- first, tail controller second.
    FCCons Proxy
      <$> newIORef (FaultConfig Nothing)
      <*> newFaultController @rest
-- | Evidence that the fault named @f@ has a cell in a controller for the
-- list @faults@. @f@ does not occur in the method types, so it is chosen
-- with a type application at each use.
class HasFault (f :: Symbol) faults where
  -- | Read the current configuration of fault @f@.
  getFaultConfig :: FaultController faults -> IO FaultConfig
  -- | Overwrite the configuration of fault @f@.
  setFaultConfig :: FaultConfig -> FaultController faults -> IO ()
-- | Arm the fault point @fault@: store the given exception in its cell of
-- the controller, replacing whatever configuration was there.
--
-- The exception is converted with 'toException' rather than wrapped with
-- the raw 'SomeException' constructor. That way a value that already is a
-- 'SomeException' is not wrapped a second time, and exception types with a
-- custom 'toException' (hierarchies) keep their intended wrapper, so
-- handlers matching on the concrete type still see the injected exception.
setFault :: forall fault faults e
          . (HasFault fault faults, Exception e)
         => e                       -- ^ exception to store for this fault
         -> FaultController faults  -- ^ controller holding the fault's cell
         -> IO ()
setFault e fc = setFaultConfig @fault (FaultConfig (Just (toException e))) fc
-- | Disarm the fault point @fault@ by storing @FaultConfig Nothing@ in its
-- cell of the controller.
--
-- The type variable @e@ is unused. It previously carried an @Exception e@
-- constraint that no argument could ever determine, which forced every
-- caller to supply a dummy third type application. The constraint is
-- dropped so @resetFault \@name controller@ type-checks on its own; the
-- variable itself is kept so existing call sites that pass three type
-- arguments continue to compile.
resetFault :: forall fault faults e
            . HasFault fault faults
           => FaultController faults  -- ^ controller holding the fault's cell
           -> IO ()
resetFault fc = setFaultConfig @fault (FaultConfig Nothing) fc
-- | The requested fault is at the head of the list: operate on this cell.
--
-- When the goal equals the head name, the recursive instance below matches
-- as well (with @f ~ goal@). The explicit @OVERLAPPING@ pragma makes GHC
-- pick this more specific instance instead of reporting an overlap.
instance {-# OVERLAPPING #-} HasFault goal (goal ': rest) where
  getFaultConfig (FCCons _ ioref _) = readIORef ioref
  setFaultConfig new (FCCons _ ioref _) = atomicWriteIORef ioref new

-- | The requested fault is not at the head: skip this cell and look it up
-- in the rest of the controller.
instance {-# OVERLAPPABLE #-} HasFault goal rest => HasFault goal (f ': rest) where
  getFaultConfig (FCCons _ _ rest) = getFaultConfig @goal rest
  setFaultConfig new (FCCons _ _ rest) = setFaultConfig @goal new rest
-- | Transformer that carries a 'FaultController' for the fault names in
-- @faults@, implemented as a 'ReaderT' over that controller.
newtype FaultyT (faults :: [Symbol]) m a = FaultyT
  { unFaultyT :: ReaderT (FaultController faults) m a
    -- ^ Unwrap to the underlying 'ReaderT' computation.
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadLogger
    , MonadError e
    , MonadState s
    , MonadCatch
    , MonadThrow
    )
-- | Run a 'FaultyT' computation against the supplied controller.
runFaultyT :: FaultController faults -> FaultyT faults m a -> m a
runFaultyT controller (FaultyT action) = runReaderT action controller
-- | Return the 'FaultController' the current 'FaultyT' computation is
-- running with.
askFaultController :: Monad m => FaultyT faults m (FaultController faults)
askFaultController = FaultyT (ReaderT pure)
-- | In 'FaultyT', triggering fault @f@ reads that fault's configuration
-- from the controller; if it holds an exception, the exception is thrown
-- with 'throwIO', otherwise nothing happens.
instance forall f faults m. (MonadIO m, HasFault f faults) => MonadFault f (FaultyT faults m) where
  faultPrx _ =
    askFaultController >>= \controller ->
      liftIO (getFaultConfig @f controller) >>= \case
        FaultConfig Nothing -> pure ()
        FaultConfig (Just stored) -> liftIO (throwIO stored)
-- | Render a controller and the current contents of its cells as a string,
-- e.g. @(FCCons \@"db" FaultConfig Nothing FCNil)@. Each cell is read with
-- 'readIORef' at the time of the call.
showFaultController :: FaultController faults -> IO String
showFaultController FCNil = pure "FCNil"
showFaultController (FCCons label cell remaining) = do
  config <- readIORef cell
  renderedTail <- showFaultController remaining
  pure $
    concat
      [ "(FCCons @"
      , show (symbolVal label)
      , " "
      , show config
      , " "
      , renderedTail
      , ")"
      ]
-- | Print the 'showFaultController' rendering of a controller to standard
-- output, followed by a newline.
printFaultController :: FaultController faults -> IO ()
printFaultController controller = showFaultController controller >>= putStrLn
-- | 'MonadBaseControl' for 'FaultlessT', built from the transformer-level
-- instance via the @default*@ helpers; the saved state is 'ComposeSt'.
instance MonadBaseControl b m => MonadBaseControl b (FaultlessT m) where
  type StM (FaultlessT m) a = ComposeSt FaultlessT m a
  liftBaseWith withRunInBase = defaultLiftBaseWith withRunInBase
  restoreM saved = defaultRestoreM saved
-- | 'MonadTransControl' for 'FaultlessT', delegating to the instance of the
-- wrapped 'IdentityT' through the newtype's constructor and field.
instance MonadTransControl FaultlessT where
  type StT FaultlessT a = StT IdentityT a
  liftWith withRun = defaultLiftWith FaultlessT unFaultlessT withRun
  restoreT saved = defaultRestoreT FaultlessT saved
-- | Lift a base-monad action into 'FaultlessT' by lifting it into the
-- wrapped 'IdentityT' and re-wrapping.
instance MonadBase b m => MonadBase b (FaultlessT m) where
  liftBase baseAction = FaultlessT (liftBase baseAction)
-- | Lifting into 'FaultlessT' just wraps the action in the two newtypes.
instance MonadTrans FaultlessT where
  lift inner = FaultlessT (IdentityT inner)
-- | Run resource actions in the underlying monad and 'lift' the result.
instance MonadResource m => MonadResource (FaultlessT m) where
  liftResourceT resourceAction = lift (liftResourceT resourceAction)
-- | 'MonadBaseControl' for 'FaultyT', built from the transformer-level
-- instance via the @default*@ helpers; the saved state is 'ComposeSt'.
instance MonadBaseControl b m => MonadBaseControl b (FaultyT faults m) where
  type StM (FaultyT faults m) a = ComposeSt (FaultyT faults) m a
  liftBaseWith withRunInBase = defaultLiftBaseWith withRunInBase
  restoreM saved = defaultRestoreM saved
-- | 'MonadTransControl' for 'FaultyT', delegating to the instance of the
-- wrapped 'ReaderT' through the newtype's constructor and field.
instance MonadTransControl (FaultyT faults) where
  type StT (FaultyT faults) a = StT (ReaderT (FaultController faults)) a
  liftWith withRun = defaultLiftWith FaultyT unFaultyT withRun
  restoreT saved = defaultRestoreT FaultyT saved
-- | Lift a base-monad action into 'FaultyT' by lifting it into the wrapped
-- 'ReaderT' and re-wrapping.
instance MonadBase b m => MonadBase b (FaultyT faults m) where
  liftBase baseAction = FaultyT (liftBase baseAction)
-- | A lifted action ignores the fault controller entirely.
instance MonadTrans (FaultyT faults) where
  lift inner = FaultyT (ReaderT (const inner))
-- | Expose the underlying monad's reader environment @r@. The internal
-- 'FaultController' environment is passed through untouched: 'ask' reads
-- from the base monad, and 'local' modifies only the base monad's
-- environment while the controller is supplied unchanged.
instance MonadReader r m => MonadReader r (FaultyT faults m) where
  ask = FaultyT (ReaderT (const ask))
  local adjust action =
    FaultyT . ReaderT $ local adjust . runReaderT (unFaultyT action)
-- | Run resource actions in the underlying monad and 'lift' the result.
instance MonadResource m => MonadResource (FaultyT faults m) where
  liftResourceT resourceAction = lift (liftResourceT resourceAction)