From 9b1df01fd2cf3ecf41fc68b94051db665821c774 Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Wed, 4 Aug 2021 11:09:35 -0400 Subject: Reimplement Que with Servant Still todo: add authentication. But that can wait. In re-implementing this, I was able to figure out how to get the Go.mult working properly as well. The problem is that a tap from a mult channel does not remove the message from the original channel. I'm not sure if that should be a core feature or not; for now I'm just draining the channel when it's received in the Que HTTP handler. (Also, this would be a good place to put persistence: have a background job read from the original channel, and write the msg to disk via acid-state; this would obviate the need for a flush to nowhere.) Also, streaming is working now. The problem was that Scotty closes the connection after it sees a newline in the body, or something, so streaming over Scotty doesn't actually work. It's fine, Servant is better anyway. --- Biz/Bild/Deps/Haskell.nix | 1 - Biz/Devalloc.hs | 13 +- Biz/Log.hs | 14 ++ Biz/Que/Client.py | 93 +++++++----- Biz/Que/Host.hs | 311 +++++++++++++++++++--------------------- Biz/Que/Site.hs | 6 +- Biz/Test.hs | 3 + Control/Concurrent/Go.hs | 117 ++++++++++++++- Network/Wai/Middleware/Braid.hs | 239 ++++++++++++++++++++++++++++++ 9 files changed, 577 insertions(+), 220 deletions(-) create mode 100644 Network/Wai/Middleware/Braid.hs diff --git a/Biz/Bild/Deps/Haskell.nix b/Biz/Bild/Deps/Haskell.nix index d2e6557..3077182 100644 --- a/Biz/Bild/Deps/Haskell.nix +++ b/Biz/Bild/Deps/Haskell.nix @@ -45,7 +45,6 @@ with hpkgs; regex-applicative req safecopy - scotty servant servant-auth servant-auth-server diff --git a/Biz/Devalloc.hs b/Biz/Devalloc.hs index b30bac4..998260e 100644 --- a/Biz/Devalloc.hs +++ b/Biz/Devalloc.hs @@ -714,18 +714,7 @@ tidy :: Config -> IO () tidy Config {..} = Directory.removeDirectoryRecursive keep run :: (Config, Wai.Application, Acid.AcidState Keep) -> IO () -run (cfg, app, _) = Warp.run (port cfg) (logMiddleware app) - -logMiddleware :: Wai.Middleware -logMiddleware app req sendResponse = - app req <| \res -> - Log.info - [ str <| Wai.requestMethod req, - show <| Wai.remoteHost req, - str <| Wai.rawPathInfo req - ] - >> Log.br - >> sendResponse res +run (cfg, app, _) = Warp.run (port cfg) (Log.wai app) liveCookieSettings :: Auth.CookieSettings liveCookieSettings = diff --git a/Biz/Log.hs b/Biz/Log.hs index 747efed..9304cf7 100644 --- a/Biz/Log.hs +++ b/Biz/Log.hs @@ -13,6 +13,8 @@ module Biz.Log -- Operators (~&), (~?), + -- Wai Middleware + wai, -- | Low-level msg, br, @@ -21,6 +23,7 @@ where import Alpha hiding (pass) import qualified Data.Text as Text +import qualified Network.Wai as Wai import Rainbow (chunk, fore, green, magenta, red, white, yellow) import qualified Rainbow import qualified System.Environment as Env @@ -87,3 +90,14 @@ mark label val = -- | Conditional mark. (~?) :: Show a => a -> (a -> Bool) -> Text -> a (~?) val test label = if test val then mark label val else val + +wai :: Wai.Middleware +wai app req sendResponse = + app req <| \res -> + info + [ str <| Wai.requestMethod req, + show <| Wai.remoteHost req, + str <| Wai.rawPathInfo req + ] + >> br + >> sendResponse res diff --git a/Biz/Que/Client.py b/Biz/Que/Client.py index 1063eb8..58877bf 100755 --- a/Biz/Que/Client.py +++ b/Biz/Que/Client.py @@ -11,11 +11,15 @@ import logging import os import subprocess import sys +import textwrap import time import urllib.parse import urllib.request as request -MAX_TIMEOUT = 99999999 # basically never timeout +MAX_TIMEOUT = 9999999 +RETRIES = 10 +DELAY = 3 +BACKOFF = 1 def auth(args): @@ -33,8 +37,8 @@ def auth(args): def autodecode(bytestring): - """Attempt to decode bytes `bs` into common codecs, preferably utf-8. If - no decoding is available, just return the raw bytes. + """Attempt to decode bytes into common codecs, preferably utf-8. If no + decoding is available, just return the raw bytes. For all available codecs, see: @@ -50,7 +54,7 @@ def autodecode(bytestring): return bytestring -def retry(exception, tries=4, delay=3, backoff=2): +def retry(exception, tries=RETRIES, delay=DELAY, backoff=BACKOFF): "Decorator for retrying an action." def decorator(func): @@ -73,20 +77,23 @@ def retry(exception, tries=4, delay=3, backoff=2): return decorator +@retry(urllib.error.URLError) +@retry(http.client.IncompleteRead) +@retry(http.client.RemoteDisconnected) def send(args): "Send a message to the que." logging.debug("send") key = auth(args) data = args.infile req = request.Request(f"{args.host}/{args.target}") - req.add_header("User-AgenT", "Que/Client") + req.add_header("User-Agent", "Que/Client") + req.add_header("Content-Type", "text/plain;charset=utf-8") if key: req.add_header("Authorization", key) if args.serve: logging.debug("serve") while not time.sleep(1): request.urlopen(req, data=data, timeout=MAX_TIMEOUT) - else: request.urlopen(req, data=data, timeout=MAX_TIMEOUT) @@ -96,75 +103,89 @@ def then(args, msg): if args.then: logging.debug("then") subprocess.run( - args.then.format(msg=msg, que=args.target), check=False, shell=True, + args.then.format(msg=msg, que=args.target), + check=False, + shell=True, ) -@retry(http.client.IncompleteRead, tries=10, delay=5, backoff=1) -@retry(http.client.RemoteDisconnected, tries=10, delay=2, backoff=2) +@retry(urllib.error.URLError) +@retry(http.client.IncompleteRead) +@retry(http.client.RemoteDisconnected) def recv(args): "Receive a message from the que." logging.debug("recv on: %s", args.target) - params = urllib.parse.urlencode({"poll": args.poll}) - req = request.Request(f"{args.host}/{args.target}?{params}") + if args.poll: + req = request.Request(f"{args.host}/{args.target}/stream") + else: + req = request.Request(f"{args.host}/{args.target}") req.add_header("User-Agent", "Que/Client") key = auth(args) if key: req.add_header("Authorization", key) with request.urlopen(req) as _req: if args.poll: - logging.debug("poll") + logging.debug("polling") while not time.sleep(1): - logging.debug("reading") - msg = autodecode(_req.readline()) - logging.debug("read") - print(msg, end="") - then(args, msg) + reply =_req.readline() + if reply: + msg = autodecode(reply) + logging.debug("read") + print(msg, end="") + then(args, msg) + else: + continue else: - msg = autodecode(_req.read()) + msg = autodecode(_req.readline()) print(msg) then(args, msg) def get_args(): "Command line parser" - cli = argparse.ArgumentParser(description=__doc__) + cli = argparse.ArgumentParser( + description=__doc__, + epilog=textwrap.dedent( + f"""Requests will retry up to {RETRIES} times, with {DELAY} seconds + between attempts.""" + ), + ) cli.add_argument("--debug", action="store_true", help="log to stderr") cli.add_argument( "--host", default="http://que.run", help="where que-server is running" ) cli.add_argument( - "--poll", default=False, action="store_true", help="stream data from the que" + "--poll", + default=False, + action="store_true", + help=textwrap.dedent( + """keep the connection open to stream data from the que. without + this flag, the program will exit after receiving a message""" + ), ) cli.add_argument( "--then", - help=" ".join( - [ - "when polling, run this shell command after each response,", - "presumably for side effects," - r"replacing '{que}' with the target and '{msg}' with the body of the response", - ] + help=textwrap.dedent( + """when polling, run this shell command after each response, + presumably for side effects, replacing '{que}' with the target and + '{msg}' with the body of the response""" ), ) cli.add_argument( "--serve", default=False, action="store_true", - help=" ".join( - [ - "when posting to the que, do so continuously in a loop.", - "this can be used for serving a webpage or other file continuously", - ] + help=textwrap.dedent( + """when posting to the que, do so continuously in a loop. this can + be used for serving a webpage or other file continuously""" ), ) - cli.add_argument( - "target", help="namespace and path of the que, like 'ns/path/subpath'" - ) + cli.add_argument("target", help="namespace and path of the que, like 'ns/path'") cli.add_argument( "infile", nargs="?", type=argparse.FileType("rb"), - help="data to put on the que. Use '-' for stdin, otherwise should be a readable file", + help="data to put on the que. use '-' for stdin, otherwise should be a readable file", ) return cli.parse_args() @@ -173,7 +194,7 @@ if __name__ == "__main__": ARGV = get_args() if ARGV.debug: logging.basicConfig( - format="%(asctime)s %(message)s", + format="%(asctime)s: %(levelname)s: %(message)s", level=logging.DEBUG, datefmt="%Y.%m.%d..%H.%M.%S", ) diff --git a/Biz/Que/Host.hs b/Biz/Que/Host.hs index 40ee1a5..702827e 100644 --- a/Biz/Que/Host.hs +++ b/Biz/Que/Host.hs @@ -1,8 +1,15 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE NoImplicitPrelude #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} -- | Interprocess communication -- @@ -20,29 +27,22 @@ where import Alpha hiding (gets, modify, poll) import qualified Biz.Cli as Cli +import qualified Biz.Log as Log import Biz.Test ((@=?)) import qualified Biz.Test as Test import qualified Control.Concurrent.Go as Go import qualified Control.Concurrent.STM as STM import qualified Control.Exception as Exception -import Control.Monad.Reader (MonadTrans) -import qualified Data.ByteString.Builder.Extra as Builder -import qualified Data.ByteString.Lazy as BSL import Data.HashMap.Lazy (HashMap) import qualified Data.HashMap.Lazy as HashMap -import qualified Data.Text.Encoding as Encoding -import qualified Data.Text.Lazy as Text.Lazy -import qualified Network.HTTP.Types.Status as Http -import qualified Network.Wai as Wai +import Network.HTTP.Media ((//), (/:)) import qualified Network.Wai.Handler.Warp as Warp -import Network.Wai.Middleware.RequestLogger - ( logStdout, - ) +import Servant +import Servant.API.Generic ((:-)) +-- import qualified Servant.Auth.Server as Auth +import Servant.Server.Generic (AsServerT, genericServeT) +import qualified Servant.Types.SourceT as Source import qualified System.Envy as Envy -import qualified Web.Scotty.Trans as Scotty -import qualified Prelude - -{-# ANN module ("HLint: ignore Reduce duplication" :: Prelude.String) #-} main :: IO () main = Cli.main <| Cli.Plan help move test pure @@ -51,15 +51,17 @@ move :: Cli.Arguments -> IO () move _ = Exception.bracket startup shutdown <| uncurry Warp.run where startup = - Envy.decodeWithDefaults Envy.defConfig +> \c -> do - sync <- STM.newTVarIO initialAppState - let runActionToIO m = runReaderT (runApp m) sync - waiapp <- Scotty.scottyAppT runActionToIO <| routes c - putText "*" - putText "que" - putText <| "port: " <> (show <| quePort c) - putText <| "skey: " <> (Text.Lazy.toStrict <| queSkey c) - return (quePort c, waiapp) + Envy.decodeWithDefaults Envy.defConfig +> \cfg@Config {..} -> do + initialState <- atomically <| STM.newTVar mempty + -- natural transformation + let nt :: AppState -> App a -> Servant.Handler a + nt s x = runReaderT x s + let app :: AppState -> Application + app s = genericServeT (nt s) (paths cfg) + Log.info ["boot", "que"] >> Log.br + Log.info ["boot", "port", show <| quePort] >> Log.br + Log.info ["boot", "skey", queSkey] >> Log.br + pure (quePort, app initialState) shutdown :: a -> IO a shutdown = pure <. identity @@ -76,167 +78,149 @@ Usage: test :: Test.Tree test = Test.group "Biz.Que.Host" [Test.unit "id" <| 1 @=? (1 :: Integer)] -newtype App a = App - { runApp :: ReaderT (STM.TVar AppState) IO a - } - deriving - ( Applicative, - Functor, - Monad, - MonadIO, - MonadReader - (STM.TVar AppState) - ) - -newtype AppState = AppState - { ques :: HashMap Namespace Quebase - } +type App = ReaderT AppState Servant.Handler -initialAppState :: AppState -initialAppState = AppState {ques = mempty} +type Ques = HashMap Namespace Quebase + +type AppState = STM.TVar Ques data Config = Config { -- | QUE_PORT quePort :: Warp.Port, -- | QUE_SKEY - queSkey :: Text.Lazy.Text + queSkey :: Text } deriving (Generic, Show) instance Envy.DefConfig Config where - defConfig = Config 3000 "admin-key" + defConfig = Config 3001 "admin-key" instance Envy.FromEnv Config -routes :: Config -> Scotty.ScottyT Text.Lazy.Text App () -routes cfg = do - Scotty.middleware logStdout - let quepath = "^\\/([[:alnum:]_-]+)\\/([[:alnum:]._/-]*)$" - let namespace = "^\\/([[:alnum:]_-]+)\\/?$" -- matches '/ns' and '/ns/' but not '/ns/path' - - -- GET /index.html - Scotty.get (Scotty.literal "/index.html") <| Scotty.redirect "/_/index" - Scotty.get (Scotty.literal "/") <| Scotty.redirect "/_/index" - -- GET /_/dash - Scotty.get (Scotty.literal "/_/dash") <| do - authkey <- fromMaybe "" > Scotty.text "not allowed: _ is a reserved namespace" - >> Scotty.finish - ) - guardNs ns ["pub", "_"] - -- passed all auth checks - app <. modify <| upsertNamespace ns - q <- app <| que ns qp - qdata <- Scotty.body - _ <- liftIO <| Go.write q <| BSL.toStrict qdata - return () +-- | A simple HTML type. This recognizes "content-type: text/html" but doesn't +-- do any conversion, rendering, or sanitization like the +-- 'Servant.HTML.Lucid.HTML' type would do. +data HTML deriving (Typeable) + +instance Accept HTML where + contentTypes _ = "text" // "html" /: ("charset", "utf-8") :| ["text" // "html"] + +instance MimeRender HTML ByteString where + mimeRender _ x = str x + +instance MimeUnrender HTML Text where + mimeUnrender _ bs = Right <| str bs + +instance MimeUnrender OctetStream Text where + mimeUnrender _ bs = Right <| str bs + +instance MimeRender PlainText ByteString where + mimeRender _ bs = str bs + +instance MimeUnrender PlainText ByteString where + mimeUnrender _ bs = Right <| str bs + +data Paths path = Paths + { home :: + path + :- Get '[JSON] NoContent, + dash :: + path + :- "_" + :> "dash" + :> Get '[JSON] Ques, + getQue :: + path + :- Capture "ns" Text + :> Capture "quename" Text + :> Get '[PlainText, HTML, OctetStream] Message, + getStream :: + path + :- Capture "ns" Text + :> Capture "quename" Text + :> "stream" + :> StreamGet NoFraming OctetStream (SourceIO Message), + putQue :: + path + :- Capture "ns" Text + :> Capture "quepath" Text + :> ReqBody '[PlainText, HTML, OctetStream] Text + :> Post '[PlainText, HTML, OctetStream] NoContent + } + deriving (Generic) + +paths :: Config -> Paths (AsServerT App) +paths _ = + -- TODO revive authkey stuff + -- - read Authorization header, compare with queSkey + -- - Only allow my IP or localhost to publish to '_' namespace + Paths + { home = + throwError <| err301 {errHeaders = [("Location", "/_/index")]}, + dash = gets, + getQue = \ns qn -> do + guardNs ns ["pub", "_"] + modify <| upsertNamespace ns + q <- que ns qn + Go.mult q + |> liftIO + +> Go.tap + |> liftIO, + getStream = \ns qn -> do + guardNs ns ["pub", "_"] + modify <| upsertNamespace ns + q <- que ns qn + Go.mult q + |> liftIO + +> Go.tap + |> Source.fromAction (const False) -- peek chan instead of False? + |> pure, + putQue = \ns qp body -> do + guardNs ns ["pub", "_"] + modify <| upsertNamespace ns + q <- que ns qp + body + |> str + |> Go.write q + >> Go.read q -- flush the que, otherwise msgs never clear + |> liftIO + -- TODO: detect number of readers, respond with "sent to N readers" or + -- "no readers, msg lost" + >> pure NoContent + } -- | Given `guardNs ns whitelist`, if `ns` is not in the `whitelist` -- list, return a 405 error. -guardNs :: Text.Lazy.Text -> [Text.Lazy.Text] -> Scotty.ActionT Text.Lazy.Text App () +guardNs :: (Applicative a, MonadError ServerError a) => Text -> [Text] -> a () guardNs ns whitelist = when (not <| ns `elem` whitelist) <| do - Scotty.status Http.methodNotAllowed405 - Scotty.text - <| "not allowed: use 'pub' namespace or signup to protect '" - <> ns - <> "' at https://que.run" - Scotty.finish - --- | recover from a scotty-thrown exception. -(!:) :: - -- | action that might throw - Scotty.ActionT Text.Lazy.Text App a -> - -- | a function providing a default response instead - (Text.Lazy.Text -> Scotty.ActionT Text.Lazy.Text App a) -> - Scotty.ActionT Text.Lazy.Text App a -(!:) = Scotty.rescue - --- | Forever write the data from 'Que' to 'Wai.StreamingBody'. -streamQue :: Que -> Wai.StreamingBody -streamQue q write _ = loop q + throwError <| err405 {errBody = str msg} where - loop c = - Go.read c - +> (write <. Builder.byteStringInsert) - >> loop c + msg = + "not allowed: use 'pub' namespace or signup to protect '" + <> ns + <> "' at https://que.run" -- | Gets the thing from the Hashmap. Call's 'error' if key doesn't exist. grab :: (Eq k, Hashable k) => k -> HashMap k v -> v grab = flip (HashMap.!) --- | Inserts the namespace in 'AppState' if it doesn't exist. -upsertNamespace :: Namespace -> AppState -> AppState +-- | Inserts the namespace in 'Ques' if it doesn't exist. +upsertNamespace :: Namespace -> HashMap Namespace Quebase -> HashMap Namespace Quebase upsertNamespace ns as = - if HashMap.member ns (ques as) + if HashMap.member ns as then as - else as {ques = HashMap.insert ns mempty (ques as)} + else HashMap.insert ns mempty as -- | Inserts the que at the proper 'Namespace' and 'Quepath'. -insertQue :: Namespace -> Quepath -> Que -> AppState -> AppState -insertQue ns qp q as = as {ques = newQues} +insertQue :: Namespace -> Quepath -> Que -> HashMap Namespace Quebase -> HashMap Namespace Quebase +insertQue ns qp q hm = newQues where - newQues = HashMap.insert ns newQbase (ques as) - newQbase = HashMap.insert qp q <| grab ns <| ques as - -extract :: Scotty.ActionT Text.Lazy.Text App (Namespace, Quepath) -extract = do - ns <- Scotty.param "1" - path <- Scotty.param "2" - return (ns, path) - --- | A synonym for 'lift' in order to be explicit about when we are --- operating at the 'App' layer. -app :: MonadTrans t => App a -> t App a -app = lift - --- | Get something from the app state -gets :: (AppState -> b) -> App b -gets f = ask +> liftIO <. STM.readTVarIO +> return AppState) -> App () -modify f = ask +> liftIO <. atomically <. flip STM.modifyTVar' f + newQues = HashMap.insert ns newQbase hm + newQbase = HashMap.insert qp q <| grab ns hm -- | housing for a set of que paths -type Namespace = Text.Lazy.Text +type Namespace = Text -- | a que is just a channel of bytes type Que = Go.Channel Message @@ -250,15 +234,22 @@ type Message = ByteString -- | a collection of ques type Quebase = HashMap Quepath Que +-- | Get app state +gets :: App Ques +gets = ask +> STM.readTVarIO .> liftIO +> pure + +-- | Apply a function to the app state +modify :: (Ques -> Ques) -> App () +modify f = ask +> flip STM.modifyTVar' f .> atomically .> liftIO + -- | Lookup or create a que que :: Namespace -> Quepath -> App Que que ns qp = do - _ques <- gets ques - let qbase = grab ns _ques - queExists = HashMap.member qp qbase - if queExists - then return <| grab qp qbase + ques <- gets + let qbase = grab ns ques + if HashMap.member qp qbase + then pure <| grab qp qbase else do - c <- liftIO <| Go.chan 1 + c <- liftIO <| Go.chan 5 modify (insertQue ns qp c) - gets ques /> grab ns /> grab qp + gets /> grab ns /> grab qp diff --git a/Biz/Que/Site.hs b/Biz/Que/Site.hs index 06b86c8..e027717 100644 --- a/Biz/Que/Site.hs +++ b/Biz/Que/Site.hs @@ -20,12 +20,9 @@ import qualified Control.Concurrent.Async as Async import qualified Data.ByteString.Char8 as BS import qualified Data.Ini.Config as Config import qualified Data.Text as Text -import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.IO as Text import Network.HTTP.Req import qualified System.Directory as Directory -import System.Environment as Environment -import qualified System.Exit as Exit import System.FilePath (()) import qualified System.Process as Process @@ -138,7 +135,8 @@ serve Nothing p _ _ = panic <| "no auth key provided for ns: " <> p serve (Just key) ns path content = runReq defaultHttpConfig <| do let options = - header "Authorization" (encodeUtf8 key) <> responseTimeout maxBound + header "Content-type" "text/html;charset=utf-8" + -- header "Authorization" (encodeUtf8 key) <> responseTimeout maxBound _ <- req POST diff --git a/Biz/Test.hs b/Biz/Test.hs index bd1384e..31a8831 100644 --- a/Biz/Test.hs +++ b/Biz/Test.hs @@ -74,6 +74,7 @@ assertNotEqual preface notexpected actual = ++ "\n but got: " ++ show actual +-- | unexpectedValue @?!= actual (@?!=) :: (Eq a, Show a, HasCallStack) => -- | The not-expected value @@ -85,11 +86,13 @@ expected @?!= actual = assertNotEqual "" expected actual infixl 2 @?!= +-- | (@=?) :: (Eq a, Show a) => a -> a -> HUnit.Assertion a @=? b = a HUnit.@=? b infixl 2 @=? +-- | (@?=) :: (Eq a, Show a) => a -> a -> HUnit.Assertion a @?= b = a HUnit.@?= b diff --git a/Control/Concurrent/Go.hs b/Control/Concurrent/Go.hs index 5057bfe..a5eb2b7 100644 --- a/Control/Concurrent/Go.hs +++ b/Control/Concurrent/Go.hs @@ -1,4 +1,6 @@ +{-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE NoImplicitPrelude #-} -- | An EDSL to make working with concurrent in-process code a bit easier @@ -8,32 +10,59 @@ -- Golang and Clojure's core.async. -- -- \$example +-- : out go module Control.Concurrent.Go ( -- * Running and forking fork, -- * Channels Channel, + Mult, chan, read, write, mult, tap, + + -- * internal + sleep, + test, + main, ) where import Alpha +import qualified Biz.Cli as Cli +import Biz.Test ((@?=)) +import qualified Biz.Test as Test import qualified Control.Concurrent as Concurrent import qualified Control.Concurrent.Chan.Unagi.Bounded as Chan import qualified Data.Aeson as Aeson -import Data.Text (Text) import qualified System.IO.Unsafe as Unsafe +main :: IO () +main = + Cli.main + <| Cli.Plan + { Cli.help = help, + Cli.move = \_ -> pure (), + Cli.test = test, + Cli.tidy = pure + } + where + help = + [Cli.docopt| + go + +Usage: + go test + |] + -- | A standard channel. data Channel a = Channel - { _in :: Chan.InChan a, - _out :: Chan.OutChan a, - _size :: Int + { _in :: !(Chan.InChan a), + _out :: !(Chan.OutChan a), + _size :: !Int } instance Aeson.ToJSON (Channel a) where @@ -76,6 +105,10 @@ read = Chan.readChan <. _out write :: Channel a -> a -> IO Bool write = Chan.tryWriteChan <. _in +-- | Sleep for some number of milliseconds +sleep :: Int -> IO () +sleep n = threadDelay <| n * 1_000 + -- <|example -- -- A simple example from ghci: @@ -92,10 +125,10 @@ write = Chan.tryWriteChan <. _in -- >>> Go.write c "hi" -- >>> Go.read c -- "hi" --- >>> Go.fork --- >>> Go.fork <| forever <| Go.mult c +> Go.tap +> \t -> print ("one: " <> t) +-- m <- Go.mult +-- >>> Go.fork <| forever (Go.tap m +> \t -> print ("one: " <> t)) -- ThreadId 810 --- >>> Go.fork <| forever <| Go.mult c +> Go.tap +> \t -> print ("two: " <> t) +-- >>> Go.fork <| forever (Go.tap m +> \t -> print ("two: " <> t)) -- ThreadId 825 -- >>> Go.write c "test" -- "two: t"eosnte": @@ -103,3 +136,73 @@ write = Chan.tryWriteChan <. _in -- -- The text is garbled because the actions are happening concurrently and -- trying to serialize to write the output, but you get the idea. +-- +test :: Test.Tree +test = + Test.group + "Control.Concurrent.Go" + [ Test.unit "simple example" <| do + c <- chan 1 :: IO (Channel Text) + recv <- mult c + _ <- fork (forever (tap recv +> pure)) + ret <- write c "simple example" + True @?= ret, + Test.unit "simple MVar counter" <| do + counter <- newEmptyMVar + putMVar counter (0 :: Integer) + modifyMVar_ counter (pure <. (+ 1)) + modifyMVar_ counter (pure <. (+ 1)) + modifyMVar_ counter (pure <. (+ 1)) + r <- takeMVar counter + r @?= 3 + {- Why don't these work? + Test.unit "subscription counter" <| do + counter <- newEmptyMVar :: IO (MVar Integer) + putMVar counter 0 + let dec = modifyMVar_ counter (\x -> pure <| x -1) + let inc = modifyMVar_ counter (pure <. (+ 1)) + c <- chan 10 :: IO (Channel Bool) + c1 <- mult c + _ <- fork (forever (tap c1 +> bool dec inc)) + _ <- write c True + _ <- write c True + _ <- write c True + threadDelay 1 + r1 <- takeMVar counter + r1 @?= 3, + Test.unit "SPMC" <| do + out1 <- newEmptyMVar + out2 <- newEmptyMVar + putMVar out1 "init" + putMVar out2 "init" + c <- chan 10 :: IO (Channel Text) + c1 <- mult c + c2 <- mult c + _ <- fork <| forever (tap c1 +> swapMVar out1 >> pure ()) + _ <- fork <| forever (tap c2 +> swapMVar out2 >> pure ()) + _ <- write c "test1" + _ <- write c "test2" + threadDelay 1 + r1 <- takeMVar out1 + r2 <- takeMVar out2 + r1 @?= r2 + r1 @?= "test2", + Test.unit "Unagi SPMC" <| do + out1 <- newEmptyMVar + out2 <- newEmptyMVar + putMVar out1 "init" + putMVar out2 "init" + (i, _) <- Chan.newChan 10 :: IO (Chan.InChan Text, Chan.OutChan Text) + o1 <- Chan.dupChan i + o2 <- Chan.dupChan i + _ <- forkIO <| forever (Chan.readChan o1 +> swapMVar out1 >> pure ()) + _ <- forkIO <| forever (Chan.readChan o2 +> swapMVar out2 >> pure ()) + _ <- Chan.writeChan i "test1" + _ <- Chan.writeChan i "test2" + threadDelay 1 + r1 <- takeMVar out1 + r2 <- takeMVar out2 + r1 @?= r2 + r1 @?= "test2" + -} + ] diff --git a/Network/Wai/Middleware/Braid.hs b/Network/Wai/Middleware/Braid.hs new file mode 100644 index 0000000..f9832ac --- /dev/null +++ b/Network/Wai/Middleware/Braid.hs @@ -0,0 +1,239 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NoImplicitPrelude #-} + +module Network.Wai.Middleware.Braid + ( -- * Types + Update, + Topic, + + -- * Method helpers + isGetRequest, + isPutRequest, + isPatchRequest, + + -- * 209 Status variable + status209, + + -- * Header helpers & variables + hSub, + hVer, + hMerge, + hParents, + hPatch, + lookupHeader, + getSubscription, + hasSubscription, + getSubscriptionKeepAliveTime, + addSubscriptionHeader, + getVersion, + hasVersion, + addVersionHeader, + getMergeType, + hasMergeType, + addMergeTypeHeader, + getParents, + hasParents, + getPatches, + hasPatches, + + -- * Update helpers + requestToUpdate, + updateToBuilder, + + -- * Middleware + braidify, + subscriptionMiddleware, + versionMiddleware, + addPatchHeader, + + -- * Subscription helper + streamUpdates, + ) +where + +import Alpha +import qualified Data.ByteString as B +import Data.ByteString.Builder (Builder, byteString) +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as L +import qualified Data.CaseInsensitive as CI +import Network.HTTP.Types.Header (Header, HeaderName, RequestHeaders) +import Network.HTTP.Types.Method (methodGet, methodPatch, methodPut) +import Network.HTTP.Types.Status (Status, mkStatus) +import qualified Network.Wai as Wai +import Network.Wai.Middleware.AddHeaders (addHeaders) + +type Topic = [Text] + +data Update = -- | Updates are streamed from the server to subcribing client. + -- On a PUT request, the headers and request body are put into an Update and streamed to subscribing clients. + Update + { -- | The updateTopic is formed, from the request path + updateTopic :: [Text], + -- | The updateClient is an id generated by the client to prevent echo updates + -- https://github.com/braid-work/braid-spec/issues/72 + updateClient :: Maybe B.ByteString, + -- | The updateHeader are taken straight from the request headers + updateHeaders :: RequestHeaders, + -- | The updatePatches correspond to the request body + updatePatches :: L.ByteString + } + +isGetRequest, isPutRequest, isPatchRequest :: Wai.Request -> Bool +isGetRequest req = Wai.requestMethod req == methodGet +isPutRequest req = Wai.requestMethod req == methodPut +isPatchRequest req = Wai.requestMethod req == methodPatch + +-- | 209 Subscription is the new status code for subscriptions in braid +status209 :: Status +status209 = mkStatus 209 "Subscription" + +lookupHeader :: HeaderName -> [Header] -> Maybe B.ByteString +lookupHeader _ [] = Nothing +lookupHeader v ((n, s) : t) + | v == n = Just s + | otherwise = lookupHeader v t + +hSub :: HeaderName +hSub = "Subscribe" + +getSubscription :: Wai.Request -> Maybe B.ByteString +getSubscription req = lookupHeader hSub <| Wai.requestHeaders req + +getSubscriptionKeepAliveTime :: Wai.Request -> B.ByteString +getSubscriptionKeepAliveTime req = + let Just s = lookupHeader hSub <| Wai.requestHeaders req + in snd <| BC.breakSubstring "=" s + +hasSubscription :: Wai.Request -> Bool +hasSubscription req = isJust <| getSubscription req + +addSubscriptionHeader :: B.ByteString -> Wai.Response -> Wai.Response +addSubscriptionHeader s = + Wai.mapResponseHeaders + (\hs -> (hSub, s) : ("Cache-Control", "no-cache, no-transform") : hs) + +hVer :: HeaderName +hVer = "Version" + +getVersion :: Wai.Request -> Maybe B.ByteString +getVersion req = lookupHeader hVer <| Wai.requestHeaders req + +hasVersion :: Wai.Request -> Bool +hasVersion req = isJust <| getVersion req + +addVersionHeader :: B.ByteString -> Wai.Response -> Wai.Response +addVersionHeader s = Wai.mapResponseHeaders (\hs -> (hVer, s) : hs) + +hMerge :: HeaderName +hMerge = "Merge-Type" + +getMergeType :: Wai.Request -> Maybe B.ByteString +getMergeType req = lookupHeader hMerge <| Wai.requestHeaders req + +hasMergeType :: Wai.Request -> Bool +hasMergeType req = isJust <| getMergeType req + +addMergeTypeHeader :: B.ByteString -> Wai.Response -> Wai.Response +addMergeTypeHeader s = Wai.mapResponseHeaders (\hs -> (hMerge, s) : hs) + +hParents :: HeaderName +hParents = "Parents" + +getParents :: Wai.Request -> Maybe B.ByteString +getParents req = lookupHeader hParents <| Wai.requestHeaders req + +hasParents :: Wai.Request -> Bool +hasParents req = isJust <| getParents req + +hPatch :: HeaderName +hPatch = "Patches" + +getPatches :: Wai.Request -> Maybe B.ByteString +getPatches req = lookupHeader hPatch <| Wai.requestHeaders req + +hasPatches :: Wai.Request -> Bool +hasPatches req = isJust <| getPatches req + +-- | Forms an Update from a WAI Request +requestToUpdate :: Wai.Request -> L.ByteString -> Update +requestToUpdate req body = + Update + { updateTopic = Wai.pathInfo req, + updateClient = lookupHeader "Client" reqHeaders, + updateHeaders = + [ (x, y) + | (x, y) <- reqHeaders, + x `elem` [hSub, hVer, hMerge, hParents, hPatch, "Content-Type"] + ], + updatePatches = body + } + where + reqHeaders = Wai.requestHeaders req + +separator :: B.ByteString +separator = BC.pack ": " + +-- | Turns an Update (headers and patches) into a Builder to be streamed +-- Will return Nothing if the Topic we pass doesn't not match the updateTopic in the Update +-- Or returns Just builder, where builder has type Builder +updateToBuilder :: Topic -> Maybe B.ByteString -> Update -> Maybe Builder +updateToBuilder topic client (Update t c h p) + | t /= topic && c == client = Nothing + | otherwise = Just <| builder h p + where + builder :: RequestHeaders -> L.ByteString -> Builder + builder hs b = + hs + |> map (\(h_, v) -> CI.original h_ <> separator <> v) + |> B.intercalate "\n" + |> (\headers -> headers <> "\n\n" <> L.toStrict b) + |> byteString + +-- TODO: still needs mechanism to keep alive, i.e. keeping the response connection open +subscriptionMiddleware :: Chan Update -> Wai.Middleware +subscriptionMiddleware src = catchUpdate src <. modifyHeadersToSub <. modifyStatusTo209 + where + modifyHeadersToSub :: Wai.Middleware + modifyHeadersToSub app req respond = + case getSubscription req of + Just v -> app req <| respond <. addSubscriptionHeader v + Nothing -> app req respond + modifyStatusTo209 :: Wai.Middleware + modifyStatusTo209 = Wai.ifRequest hasSubscription <| Wai.modifyResponse <| Wai.mapResponseStatus (const status209) + -- NOTE: we're consuming the full request body, maybe there's a better way of doing this? idk + catchUpdate :: Chan Update -> Wai.Middleware + catchUpdate src_ = + Wai.ifRequest isPutRequest <| \app req res -> do + src' <- liftIO <| dupChan src_ + Wai.strictRequestBody req +> \b -> + writeChan src' <| requestToUpdate req b + app req res + +versionMiddleware :: Wai.Middleware +versionMiddleware app req respond = + case (getVersion req, isGetRequest req) of + (Just v, True) -> app req <| respond <. addVersionHeader v + _ -> app req respond + +addPatchHeader :: Wai.Middleware +addPatchHeader = Wai.ifRequest isPutRequest <| addHeaders [("Patches", "OK")] + +-- | +-- TODO: look into Chan vs BroadcastChan (https://github.com/merijn/broadcast-chan) +streamUpdates :: Chan Update -> Topic -> Maybe ByteString -> Wai.StreamingBody +streamUpdates chan topic client write flush = do + flush + src <- liftIO <| dupChan chan + fix <| \loop -> do + update <- readChan src + case updateToBuilder topic client update of + Just b -> write b >> flush >> loop + Nothing -> loop + +braidify :: Chan Update -> Wai.Middleware +braidify src = + subscriptionMiddleware src + <. versionMiddleware + <. addPatchHeader + <. addHeaders [("Range-Request-Allow-Methods", "PATCH, PUT"), ("Range-Request-Allow-Units", "json")] -- cgit v1.2.3