aboutsummaryrefslogtreecommitdiff
path: root/specs/Mock.hs
blob: 4d7f3f2583a86f197386f817209ca459858b22a6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
module Mock where

import Control.Monad (join)
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.State (MonadState (state), StateT)
import Data.Default (Default (def))
import Data.Typeable (Typeable)
import Language.Haskell.TH
import Test.Hspec

-- | Name of the GADT the generated code uses to represent one mocked method
-- invocation. It is built with 'mkName', so it is bound dynamically wherever
-- the generated declarations are spliced.
callTypeName :: Name
callTypeName = mkName typeName
  where
    typeName = "Call"

-- | Generate the mocking scaffolding for the given type classes.
--
-- The result contains, in order:
--
--   * a GADT named by 'callTypeName' with one @Mock_<method>@ constructor
--     per class method (see 'toGadtCtor'),
--   * the class-independent helpers from 'staticExtrasForMockCalls',
--   * one @TestM@ instance per class (see 'createInstance').
--
-- Each class is reified only once; its method list is reused both for the
-- GADT constructors and for the instance.
generateMock :: [Name] -> Q [Dec]
generateMock typeClassNames = do
  perClass <- traverse (\cls -> (,) cls <$> typeclassFunctions cls) typeClassNames
  let functions = concatMap snd perClass
      -- No deriving clause: Show and Eq for the GADT are derived standalone
      -- by the static extras.
      callDataDeclr =
        DataD
          []
          callTypeName
          [PlainTV (mkName "ret") ()]
          Nothing
          (toGadtCtor <$> functions)
          []
  instances <- concat <$> traverse (uncurry createInstance) perClass
  extras <- staticExtrasForMockCalls
  pure $ [callDataDeclr] ++ extras ++ instances

-- | Constructor name for a mocked method: the method's unqualified name with
-- a @Mock_@ prefix (e.g. @getUser@ becomes @Mock_getUser@).
toMockCtorName :: Name -> Name
toMockCtorName methodName = mkName ("Mock_" <> nameBase methodName)

-- | Reify a type class and return the name and type of each of its methods.
--
-- Only plain method signatures ('SigD') are kept; any other class member
-- returned by 'reify' (e.g. default signatures, associated types) is ignored.
-- Fails the splice with a descriptive message when the name does not refer
-- to a type class. Previously it silently produced no methods, which only
-- surfaced later as a confusing error on the generated instance.
typeclassFunctions :: Name -> Q [(Name, Type)]
typeclassFunctions cls = do
  info <- reify cls
  case info of
    ClassI (ClassD _ _ _ _ members) _ ->
      pure [(name, typ) | SigD name typ <- members]
    _ ->
      fail ("Mock.typeclassFunctions: " ++ show cls ++ " is not a type class")

-- | Number of leading @arg ->@ arrows in a type, e.g. @Int -> Bool -> m ()@
-- has 2. Nothing else is looked through: a type that starts with a forall or
-- a constraint counts as 0.
countArgsInType :: Type -> Int
countArgsInType = go 0
  where
    go :: Int -> Type -> Int
    go acc (AppT (AppT ArrowT _) rest) = let next = acc + 1 in next `seq` go next rest
    go acc _ = acc

-- | Build the GADT constructor recording one call of a class method.
--
-- The constructor is named by 'toMockCtorName'. Its fields and result type
-- both come from 'extractArgsAndReturnTypes' applied to the method's type.
toGadtCtor :: (Name, Type) -> Con
toGadtCtor (name, typ) =
  uncurry (GadtC [toMockCtorName name]) (extractArgsAndReturnTypes typ)

-- | Split a method type into GADT constructor fields and a result type.
--
-- Every leading @arg ->@ becomes a field with no strictness or unpacking
-- annotation. The remaining type then determines the result:
--
--   * for a type application @f x@ (typically @m x@ for a monadic method),
--     @f@ is dropped and the result is @Call x@;
--   * for any other type @t@, the result is @Call t@.
--
-- NOTE(review): the @f x@ rule also applies to non-monadic results, so
-- @Maybe Int@ yields @Call Int@.
extractArgsAndReturnTypes :: Type -> ([BangType], Type)
extractArgsAndReturnTypes typ =
  case typ of
    AppT (AppT ArrowT argType) rest ->
      let (fields, resultType) = extractArgsAndReturnTypes rest
       in ((lazyField, argType) : fields, resultType)
    AppT _ inner -> ([], wrapInCall inner)
    other -> ([], wrapInCall other)
  where
    lazyField = Bang NoSourceUnpackedness NoSourceStrictness
    wrapInCall = AppT (ConT callTypeName)

-- | Reify the named class and build its @TestM@ instance via
-- 'createInstance'.
createInstanceOnType :: Name -> Q [Dec]
createInstanceOnType className = do
  methods <- typeclassFunctions className
  createInstance className methods

-- | Build @instance MonadIO m => Cls (TestM a m)@, where @Cls@ is the class
-- named by the first argument.
--
-- The instance gets one method definition per (name, type) entry, with the
-- arity taken from 'countArgsInType'. The instance variable @a@ fills the
-- first (@ret@) parameter of @TestM@; @m@ is the underlying monad.
createInstance :: Name -> [(Name, Type)] -> Q [Dec]
createInstance name funcs = do
  methodDecs <- traverse defineMethod funcs
  pure [InstanceD Nothing [monadIOCtx] instanceHead methodDecs]
  where
    retVar = VarT (mkName "a")
    monadVar = VarT (mkName "m")
    monadIOCtx = AppT (ConT ''MonadIO) monadVar
    testM = ConT (mkName "TestM")
    instanceHead = AppT (ConT name) (AppT (AppT testM retVar) monadVar)
    defineMethod (methodName, methodType) =
      toInstanceMethodDef methodName (countArgsInType methodType)

-- | Define one instance method of the given arity.
--
-- The generated method takes positional arguments @a1 .. aN@ and packs them
-- into the matching @Mock_<method>@ constructor. That expression is handed to
-- 'instanceMethod', which produces the method body.
toInstanceMethodDef :: Name -> Int -> Q Dec
toInstanceMethodDef name argCount = do
  body <- instanceMethod packedCall
  pure (FunD name [Clause (map VarP params) (NormalB body) []])
  where
    params = [mkName ("a" ++ show i) | i <- [1 .. argCount]]
    packedCall =
      foldl (\acc param -> AppE acc (VarE param)) (ConE (toMockCtorName name)) params

-- | Body shared by every generated instance method.
--
-- The argument is the expression that builds this invocation's @Call@ value.
-- The generated @do@ block:
--
--   1. reads the registered mock declarations from the @TestM@ state,
--   2. records the call with @registerMockCall@, and
--   3. returns @getMockValue declarations call@.
--
-- @gets@, @mockDeclarations@, @unsafeUnpackDeclr@, @registerMockCall@ and
-- @getMockValue@ are not in scope in this module. They are resolved where
-- the generated code is spliced.
instanceMethod :: Exp -> Q Exp
instanceMethod mockExp =
  [e|
    do
      let call = $(pure mockExp)
      declarations <- gets (fmap unsafeUnpackDeclr . mockDeclarations)
      registerMockCall call
      pure (getMockValue declarations call)
    |]

-- | Class-independent declarations spliced once alongside the generated
-- @Call@ GADT:
--
--   * standalone 'Show' / 'Eq' instances for @Call a@,
--   * @CallWrapper@ / @CallMockDeclaration@, existential boxes that let calls
--     (and declared return values) with different result types share a list,
--   * @MockCalls@, the state recorded while a test runs, and the @TestM@
--     monad over it,
--   * helpers to run @TestM@, declare return values and assert on calls.
--
-- NOTE: @cast@, @runStateT@ and @unsafeCoerce@ are not imported by this
-- module. They are resolved where these declarations are spliced, so the
-- splicing module must have them in scope.
staticExtrasForMockCalls :: Q [Dec]
staticExtrasForMockCalls =
  [d|
    deriving instance Show ($(conT callTypeName) a)

    deriving instance Eq ($(conT callTypeName) a)

    -- A recorded call with its result type hidden, so calls to methods with
    -- different return types can be kept in one list.
    data CallWrapper where
      CallWrapper :: (Typeable a, Show a, Eq a, Eq ($(conT callTypeName) a)) => $(conT callTypeName) a -> CallWrapper

    -- A call paired with the value the mock should return for it.
    data CallMockDeclaration where
      CallMockDeclaration :: (Typeable a, Show a, Eq a, Eq ($(conT callTypeName) a)) => $(conT callTypeName) a -> a -> CallMockDeclaration

    deriving instance Show CallWrapper

    -- Written by hand rather than derived: the two wrapped calls may have
    -- different result types. One side is cast to the other's type, and
    -- wrappers whose result types differ compare unequal.
    instance Eq CallWrapper where
      (CallWrapper a) == (CallWrapper b) =
        case cast a of
          Just a' -> a' == b
          Nothing -> False

    data MockCalls = MockCalls {calls :: [CallWrapper], mockDeclarations :: [CallMockDeclaration]}

    newtype TestM ret m a = TestM {runTestM :: StateT MockCalls m a}
      deriving (Functor, Applicative, Monad, MonadIO, MonadState MockCalls)

    -- Run a test action from an empty mock state, returning its result
    -- together with the recorded calls and declarations.
    runTestMWithMocks :: (MonadIO m) => TestM x m a -> m (a, MockCalls)
    runTestMWithMocks action = runStateT (runTestM action) (MockCalls [] [])

    -- Append a call to the recorded calls. 'state' already yields m (), so
    -- no 'void' is needed (and 'void' is not imported by this module).
    registerMockCall :: (MonadState MockCalls m, Typeable a, Show a, Eq a) => $(conT callTypeName) a -> m ()
    registerMockCall call =
      state (\mock -> ((), mock {calls = calls mock ++ [CallWrapper call]}))

    -- Find the return value for a call. The first declaration whose call is
    -- equal wins; with no match, the type's 'def' value is returned.
    getMockValue :: (Typeable a, Default a, Show a, Eq a) => [($(conT callTypeName) a, a)] -> $(conT callTypeName) a -> a
    getMockValue [] _ = def
    getMockValue ((fn, ret) : _) call | call == fn = ret
    getMockValue (_ : rest) call = getMockValue rest call

    -- Unpacks a declaration at whatever result type the caller requests,
    -- using unsafeCoerce. NOTE(review): this looks sound only because
    -- getMockValue uses the value just after its call compares equal to a
    -- call of the requested type; confirm before using it anywhere else.
    unsafeUnpackDeclr :: CallMockDeclaration -> ($(conT callTypeName) a, a)
    unsafeUnpackDeclr (CallMockDeclaration f r) = unsafeCoerce (f, r)

    -- Declare the value a mocked call should return. Declarations are
    -- appended, so an earlier declaration for an equal call takes
    -- precedence over a later one.
    mockReturns :: (MonadState MockCalls m, Typeable a, Show a, Eq a) => $(conT callTypeName) a -> a -> m ()
    mockReturns call ret =
      state (\mock -> ((), mock {mockDeclarations = mockDeclarations mock ++ [CallMockDeclaration call ret]}))

    -- Expect the given call to have been recorded at least once.
    shouldHaveCalled :: (HasCallStack, Typeable a, Show a, Eq a) => MockCalls -> $(conT callTypeName) a -> Expectation
    shouldHaveCalled mock call = calls mock `shouldContain` [CallWrapper call]

    -- Expect the given calls to have been recorded consecutively and in this
    -- order ('shouldContain' checks for a contiguous sub-list).
    shouldContainCalls :: (HasCallStack, Typeable a, Show a, Eq a) => MockCalls -> [$(conT callTypeName) a] -> Expectation
    shouldContainCalls mock ls = calls mock `shouldContain` (CallWrapper <$> ls)
    |]