diff options
Diffstat (limited to '')
| -rw-r--r-- | specs/Mock.hs | 126 |
1 files changed, 113 insertions, 13 deletions
diff --git a/specs/Mock.hs b/specs/Mock.hs index 00cb64a..4d7f3f2 100644 --- a/specs/Mock.hs +++ b/specs/Mock.hs @@ -1,7 +1,28 @@ module Mock where import Control.Monad (join) +import Control.Monad.IO.Class (MonadIO) +import Control.Monad.State (MonadState (state), StateT) +import Data.Default (Default (def)) +import Data.Typeable (Typeable) import Language.Haskell.TH +import Test.Hspec + +callTypeName :: Name +callTypeName = mkName "Call" + +generateMock :: [Name] -> Q [Dec] +generateMock typeClassNames = do + functions <- join <$> mapM typeclassFunctions typeClassNames + let deriveClause = DerivClause Nothing [] + let callDataDeclr = + DataD [] callTypeName [PlainTV (mkName "ret") ()] Nothing (toGadtCtor <$> functions) [deriveClause] + instances <- join <$> mapM createInstanceOnType typeClassNames + extras <- staticExtrasForMockCalls + pure $ [callDataDeclr] ++ extras ++ instances + +toMockCtorName :: Name -> Name +toMockCtorName = mkName . ("Mock_" ++) . nameBase typeclassFunctions :: Name -> Q [(Name, Type)] typeclassFunctions cls = do @@ -13,21 +34,100 @@ typeclassFunctions cls = do functionName (SigD name typ) = [(name, typ)] functionName _ = [] -generateMock :: [Name] -> Q [Dec] -generateMock clss = do - functions <- join <$> mapM typeclassFunctions clss - pure [DataD [] (mkName "Call") [] Nothing (toCtor <$> functions) [deriveClause]] +countArgsInType :: Type -> Int +countArgsInType (AppT (AppT ArrowT _) rest) = 1 + countArgsInType rest +countArgsInType _ = 0 + +toGadtCtor :: (Name, Type) -> Con +toGadtCtor (name, typ) = GadtC [toMockCtorName name] args retType where - deriveClause = DerivClause Nothing (ConT <$> [''Show, ''Eq]) + (args, retType) = extractArgsAndReturnTypes typ - toCtor :: (Name, Type) -> Con - toCtor (name, typ) = NormalC (toCtorName name) ((noBang,) <$> getArgs typ) +extractArgsAndReturnTypes :: Type -> ([BangType], Type) +extractArgsAndReturnTypes (AppT (AppT ArrowT typ) rest) = ((noBang, typ) : args, ret) + where + (args, ret) = extractArgsAndReturnTypes rest + noBang = Bang NoSourceUnpackedness NoSourceStrictness +extractArgsAndReturnTypes (AppT _ ret) = ([], AppT (ConT callTypeName) ret) +extractArgsAndReturnTypes ret = ([], AppT (ConT callTypeName) ret) - getArgs :: Type -> [Type] - getArgs (AppT (AppT ArrowT typ) rest) = typ : getArgs rest - getArgs _ = [] +createInstanceOnType :: Name -> Q [Dec] +createInstanceOnType name = + typeclassFunctions name >>= createInstance name - noBang = Bang NoSourceUnpackedness NoSourceStrictness +createInstance :: Name -> [(Name, Type)] -> Q [Dec] +createInstance name funcs = do + let ctx = AppT (ConT ''MonadIO) (VarT $ mkName "m") + let typ = AppT (ConT name) (ConT (mkName "TestM") `AppT` VarT (mkName "a") `AppT` VarT (mkName "m")) + funcDeclrs <- mapM (\(n, t) -> toInstanceMethodDef n $ countArgsInType t) funcs + let inst = InstanceD Nothing [ctx] typ funcDeclrs + pure [inst] + +toInstanceMethodDef :: Name -> Int -> Q Dec +toInstanceMethodDef name argCount = do + let argNames = mkName . ("a" ++) . show <$> [1 .. argCount] + let callExp = foldl AppE (ConE $ toMockCtorName name) (VarE <$> argNames) + bodyExp <- instanceMethod callExp + pure $ FunD name [Clause (VarP <$> argNames) (NormalB bodyExp) []] + +instanceMethod :: Exp -> Q Exp +instanceMethod mockExp = do + [e| + do + declarations <- gets (fmap unsafeUnpackDeclr . mockDeclarations) + let call = $(pure mockExp) + getMockValue declarations call <$ registerMockCall call + |] + +staticExtrasForMockCalls :: Q [Dec] +staticExtrasForMockCalls = + [d| + deriving instance Show ($(conT callTypeName) a) + + deriving instance Eq ($(conT callTypeName) a) + + data CallWrapper where + CallWrapper :: (Typeable a, Show a, Eq a, Eq ($(conT callTypeName) a)) => $(conT callTypeName) a -> CallWrapper + + data CallMockDeclaration where + CallMockDeclaration :: (Typeable a, Show a, Eq a, Eq ($(conT callTypeName) a)) => $(conT callTypeName) a -> a -> CallMockDeclaration + + deriving instance Show CallWrapper + + -- deriving instance Eq CallWrapper + instance Eq CallWrapper where + (CallWrapper a) == (CallWrapper b) = + case cast a of + Just a' -> a' == b + Nothing -> False + + data MockCalls = MockCalls {calls :: [CallWrapper], mockDeclarations :: [CallMockDeclaration]} + + newtype TestM ret m a = TestM {runTestM :: StateT MockCalls m a} + deriving (Functor, Applicative, Monad, MonadIO, MonadState MockCalls) + + runTestMWithMocks :: (MonadIO m) => TestM x m a -> m (a, MockCalls) + runTestMWithMocks action = runStateT (runTestM action) (MockCalls [] []) + + registerMockCall :: (MonadState MockCalls m, Typeable a, Show a, Eq a) => $(conT callTypeName) a -> m () + registerMockCall call = + void $ state (\mock -> ((), mock {calls = calls mock ++ [CallWrapper call]})) + + getMockValue :: (Typeable a, Default a, Show a, Eq a) => [($(conT callTypeName) a, a)] -> $(conT callTypeName) a -> a + getMockValue [] _ = def + getMockValue ((fn, ret) : _) call | call == fn = ret + getMockValue (_ : rest) call = getMockValue rest call + + unsafeUnpackDeclr :: CallMockDeclaration -> ($(conT callTypeName) a, a) + unsafeUnpackDeclr (CallMockDeclaration f r) = unsafeCoerce (f, r) + + mockReturns :: (MonadState MockCalls m, Typeable a, Show a, Eq a) => $(conT callTypeName) a -> a -> m () + mockReturns call ret = + state (\mock -> ((), mock {mockDeclarations = mockDeclarations mock ++ [CallMockDeclaration call ret]})) + + shouldHaveCalled :: (HasCallStack, Typeable a, Show a, Eq a) => MockCalls -> $(conT callTypeName) a -> Expectation + shouldHaveCalled mock call = calls mock `shouldContain` [CallWrapper call] - toCtorName :: Name -> Name - toCtorName = mkName . ("Mock_" ++) . nameBase + shouldContainCalls :: (HasCallStack, Typeable a, Show a, Eq a) => MockCalls -> [$(conT callTypeName) a] -> Expectation + shouldContainCalls mock ls = calls mock `shouldContain` (CallWrapper <$> ls) + |] |
