{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RecordWildCards #-} module PostgREST.AppState ( AppState , AuthResult(..) , destroy , getConfig , getSchemaCache , getMainThreadId , getPgVersion , getNextDelay , getNextListenerDelay , getTime , getJwtCache , getSocketREST , getSocketAdmin , init , initSockets , initWithPool , putNextListenerDelay , putSchemaCache , putPgVersion , putIsListenerOn , usePool , readInDbConfig , schemaCacheLoader , getObserver , isLoaded , isPending ) where import qualified Data.Aeson as JSON import qualified Data.Aeson.KeyMap as KM import qualified Data.ByteString.Char8 as BS import qualified Data.Cache as C import Data.Either.Combinators (whenLeft) import qualified Data.Text as T (unpack) import qualified Hasql.Pool as SQL import qualified Hasql.Pool.Config as SQL import qualified Hasql.Session as SQL import qualified Hasql.Transaction.Sessions as SQL import qualified Network.HTTP.Types.Status as HTTP import qualified Network.Socket as NS import qualified PostgREST.Error as Error import qualified PostgREST.Logger as Logger import qualified PostgREST.Metrics as Metrics import PostgREST.Observation import PostgREST.Version (prettyVersion) import System.TimeIt (timeItT) import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate, updateAction) import Control.Debounce import Control.Retry (RetryPolicy, RetryStatus (..), capDelay, exponentialBackoff, retrying, rsPreviousDelay) import Data.IORef (IORef, atomicWriteIORef, newIORef, readIORef) import Data.Time.Clock (UTCTime, getCurrentTime) import PostgREST.Config (AppConfig (..), addFallbackAppName, readAppConfig) import PostgREST.Config.Database (queryDbSettings, queryPgVersion, queryRoleSettings) import PostgREST.Config.PgVersion (PgVersion (..), minimumPgVersion) import PostgREST.SchemaCache (SchemaCache (..), querySchemaCache, showSummary) import PostgREST.SchemaCache.Identifiers (dumpQi) import PostgREST.Unix (createAndBindDomainSocket) import Data.Streaming.Network (bindPortTCP, bindRandomPortTCP) import Data.String (IsString (..)) import Protolude data AuthResult = AuthResult { authClaims :: KM.KeyMap JSON.Value , authRole :: BS.ByteString } data AppState = AppState -- | Database connection pool { statePool :: SQL.Pool -- | Database server version , statePgVersion :: IORef PgVersion -- | Schema cache , stateSchemaCache :: IORef (Maybe SchemaCache) -- | The schema cache status , stateSCacheStatus :: IORef SchemaCacheStatus -- | State of the LISTEN channel , stateIsListenerOn :: IORef Bool -- | starts the connection worker with a debounce , debouncedSCacheLoader :: IO () -- | Config that can change at runtime , stateConf :: IORef AppConfig -- | Time used for verifying JWT expiration , stateGetTime :: IO UTCTime -- | Used for killing the main thread in case a subthread fails , stateMainThreadId :: ThreadId -- | Keeps track of the next delay for db connection retry , stateNextDelay :: IORef Int -- | Keeps track of the next delay for the listener , stateNextListenerDelay :: IORef Int -- | JWT Cache , jwtCache :: C.Cache ByteString AuthResult -- | Network socket for REST API , stateSocketREST :: NS.Socket -- | Network socket for the admin UI , stateSocketAdmin :: Maybe NS.Socket -- | Observation handler , stateObserver :: ObservationHandler , stateLogger :: Logger.LoggerState , stateMetrics :: Metrics.MetricsState } -- | Schema cache status data SchemaCacheStatus = SCLoaded | SCPending deriving Eq type AppSockets = (NS.Socket, Maybe NS.Socket) init :: AppConfig -> IO AppState init conf@AppConfig{configLogLevel, configDbPoolSize} = do loggerState <- Logger.init metricsState <- Metrics.init configDbPoolSize let observer = liftA2 (>>) (Logger.observationLogger loggerState configLogLevel) (Metrics.observationMetrics metricsState) pool <- initPool conf observer (sock, adminSock) <- initSockets conf state' <- initWithPool (sock, adminSock) pool conf loggerState metricsState observer pure state' { stateSocketREST = sock, stateSocketAdmin = adminSock} initWithPool :: AppSockets -> SQL.Pool -> AppConfig -> Logger.LoggerState -> Metrics.MetricsState -> ObservationHandler -> IO AppState initWithPool (sock, adminSock) pool conf loggerState metricsState observer = do appState <- AppState pool <$> newIORef minimumPgVersion -- assume we're in a supported version when starting, this will be corrected on a later step <*> newIORef Nothing <*> newIORef SCPending <*> newIORef False <*> pure (pure ()) <*> newIORef conf <*> mkAutoUpdate defaultUpdateSettings { updateAction = getCurrentTime } <*> myThreadId <*> newIORef 0 <*> newIORef 1 <*> C.newCache Nothing <*> pure sock <*> pure adminSock <*> pure observer <*> pure loggerState <*> pure metricsState deb <- let decisecond = 100000 in mkDebounce defaultDebounceSettings { debounceAction = retryingSchemaCacheLoad appState , debounceFreq = decisecond , debounceEdge = leadingEdge -- runs the worker at the start and the end } return appState { debouncedSCacheLoader = deb} destroy :: AppState -> IO () destroy = destroyPool initSockets :: AppConfig -> IO AppSockets initSockets AppConfig{..} = do let cfg'usp = configServerUnixSocket cfg'uspm = configServerUnixSocketMode cfg'host = configServerHost cfg'port = configServerPort cfg'adminport = configAdminServerPort sock <- case cfg'usp of -- I'm not using `streaming-commons`' bindPath function here because it's not defined for Windows, -- but we need to have runtime error if we try to use it in Windows, not compile time error Just path -> createAndBindDomainSocket path cfg'uspm Nothing -> do (_, sock) <- if cfg'port /= 0 then do sock <- bindPortTCP cfg'port (fromString $ T.unpack cfg'host) pure (cfg'port, sock) else do -- explicitly bind to a random port, returning bound port number (num, sock) <- bindRandomPortTCP (fromString $ T.unpack cfg'host) pure (num, sock) pure sock adminSock <- case cfg'adminport of Just adminPort -> do adminSock <- bindPortTCP adminPort (fromString $ T.unpack cfg'host) pure $ Just adminSock Nothing -> pure Nothing pure (sock, adminSock) initPool :: AppConfig -> ObservationHandler -> IO SQL.Pool initPool AppConfig{..} observer = SQL.acquire $ SQL.settings [ SQL.size configDbPoolSize , SQL.acquisitionTimeout $ fromIntegral configDbPoolAcquisitionTimeout , SQL.agingTimeout $ fromIntegral configDbPoolMaxLifetime , SQL.idlenessTimeout $ fromIntegral configDbPoolMaxIdletime , SQL.staticConnectionSettings (toUtf8 $ addFallbackAppName prettyVersion configDbUri) , SQL.observationHandler $ observer . HasqlPoolObs ] -- | Run an action with a database connection. usePool :: AppState -> SQL.Session a -> IO (Either SQL.UsageError a) usePool AppState{stateObserver=observer, stateMainThreadId=mainThreadId, ..} sess = do observer PoolRequest res <- SQL.use statePool sess observer PoolRequestFullfilled whenLeft res (\case SQL.AcquisitionTimeoutUsageError -> observer $ PoolAcqTimeoutObs SQL.AcquisitionTimeoutUsageError err@(SQL.ConnectionUsageError e) -> let failureMessage = BS.unpack $ fromMaybe mempty e in when (("FATAL: password authentication failed" `isInfixOf` failureMessage) || ("no password supplied" `isInfixOf` failureMessage)) $ do observer $ ExitDBFatalError ServerAuthError err killThread mainThreadId err@(SQL.SessionUsageError (SQL.QueryError tpl _ (SQL.ResultError resultErr))) -> do case resultErr of SQL.UnexpectedResult{} -> do observer $ ExitDBFatalError ServerPgrstBug err killThread mainThreadId SQL.RowError{} -> do observer $ ExitDBFatalError ServerPgrstBug err killThread mainThreadId SQL.UnexpectedAmountOfRows{} -> do observer $ ExitDBFatalError ServerPgrstBug err killThread mainThreadId -- Check for a syntax error (42601 is the pg code) only for queries that don't have `WITH pgrst_source` as prefix. -- This would mean the error is on our schema cache queries, so we treat it as fatal. -- TODO have a better way to mark this as a schema cache query SQL.ServerError "42601" _ _ _ _ -> unless ("WITH pgrst_source" `BS.isPrefixOf` tpl) $ do observer $ ExitDBFatalError ServerPgrstBug err killThread mainThreadId -- Check for a "prepared statement already exists" error (Code 42P05: duplicate_prepared_statement). -- This would mean that a connection pooler in transaction mode is being used -- while prepared statements are enabled in the PostgREST configuration, -- both of which are incompatible with each other. SQL.ServerError "42P05" _ _ _ _ -> do observer $ ExitDBFatalError ServerError42P05 err killThread mainThreadId -- Check for a "transaction blocks not allowed in statement pooling mode" error (Code 08P01: protocol_violation). -- This would mean that a connection pooler in statement mode is being used which is not supported in PostgREST. SQL.ServerError "08P01" "transaction blocks not allowed in statement pooling mode" _ _ _ -> do observer $ ExitDBFatalError ServerError08P01 err killThread mainThreadId SQL.ServerError{} -> when (Error.status (Error.PgError False err) >= HTTP.status500) $ observer $ QueryErrorCodeHighObs err err@(SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ClientError _))) -> -- An error on the client-side, usually indicates problems wth connection observer $ QueryErrorCodeHighObs err ) return res -- | Flush the connection pool so that any future use of the pool will -- use connections freshly established after this call. flushPool :: AppState -> IO () flushPool AppState{..} = SQL.release statePool -- | Destroy the pool on shutdown. destroyPool :: AppState -> IO () destroyPool AppState{..} = SQL.release statePool getPgVersion :: AppState -> IO PgVersion getPgVersion = readIORef . statePgVersion putPgVersion :: AppState -> PgVersion -> IO () putPgVersion = atomicWriteIORef . statePgVersion getSchemaCache :: AppState -> IO (Maybe SchemaCache) getSchemaCache = readIORef . stateSchemaCache putSchemaCache :: AppState -> Maybe SchemaCache -> IO () putSchemaCache appState = atomicWriteIORef (stateSchemaCache appState) schemaCacheLoader :: AppState -> IO () schemaCacheLoader = debouncedSCacheLoader getNextDelay :: AppState -> IO Int getNextDelay = readIORef . stateNextDelay getNextListenerDelay :: AppState -> IO Int getNextListenerDelay = readIORef . stateNextListenerDelay putNextListenerDelay :: AppState -> Int -> IO () putNextListenerDelay = atomicWriteIORef . stateNextListenerDelay getConfig :: AppState -> IO AppConfig getConfig = readIORef . stateConf putConfig :: AppState -> AppConfig -> IO () putConfig = atomicWriteIORef . stateConf getTime :: AppState -> IO UTCTime getTime = stateGetTime getJwtCache :: AppState -> C.Cache ByteString AuthResult getJwtCache = jwtCache getSocketREST :: AppState -> NS.Socket getSocketREST = stateSocketREST getSocketAdmin :: AppState -> Maybe NS.Socket getSocketAdmin = stateSocketAdmin getMainThreadId :: AppState -> ThreadId getMainThreadId = stateMainThreadId isConnEstablished :: AppState -> IO Bool isConnEstablished appState = do AppConfig{..} <- getConfig appState if configDbChannelEnabled then -- if the listener is enabled, we can be sure the connection is up readIORef $ stateIsListenerOn appState else -- otherwise the only way to check the connection is to make a query isRight <$> usePool appState (SQL.sql "SELECT 1") putIsListenerOn :: AppState -> Bool -> IO () putIsListenerOn = atomicWriteIORef . stateIsListenerOn isLoaded :: AppState -> IO Bool isLoaded x = do scacheStatus <- readIORef $ stateSCacheStatus x connEstablished <- isConnEstablished x return $ scacheStatus == SCLoaded && connEstablished isPending :: AppState -> IO Bool isPending x = do scacheStatus <- readIORef $ stateSCacheStatus x connEstablished <- isConnEstablished x return $ scacheStatus == SCPending || not connEstablished putSCacheStatus :: AppState -> SchemaCacheStatus -> IO () putSCacheStatus = atomicWriteIORef . stateSCacheStatus getObserver :: AppState -> ObservationHandler getObserver = stateObserver -- | Try to load the schema cache and retry if it fails. -- -- This is done by repeatedly: 1) flushing the pool, 2) querying the version and validating that the postgres version is supported by us, and 3) loading the schema cache. -- It's necessary to flush the pool: -- -- + Because connections cache the pg catalog(see #2620) -- + For rapid recovery. Otherwise, the pool idle or lifetime timeout would have to be reached for new healthy connections to be acquired. retryingSchemaCacheLoad :: AppState -> IO () retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId} = void $ retrying retryPolicy shouldRetry (\RetryStatus{rsIterNumber, rsPreviousDelay} -> do when (rsIterNumber > 0) $ do let delay = fromMaybe 0 rsPreviousDelay `div` oneSecondInUs observer $ ConnectionRetryObs delay putNextListenerDelay appState delay flushPool appState (,) <$> qPgVersion <*> (qInDbConfig *> qSchemaCache) ) where qPgVersion :: IO (Maybe PgVersion) qPgVersion = do AppConfig{..} <- getConfig appState pgVersion <- usePool appState (queryPgVersion False) -- No need to prepare the query here, as the connection might not be established case pgVersion of Left e -> do observer $ QueryPgVersionError e unless configDbPoolAutomaticRecovery $ do observer ExitDBNoRecoveryObs killThread mainThreadId return Nothing Right actualPgVersion -> do when (actualPgVersion < minimumPgVersion) $ do observer $ ExitUnsupportedPgVersion actualPgVersion minimumPgVersion killThread mainThreadId observer $ DBConnectedObs $ pgvFullName actualPgVersion putPgVersion appState actualPgVersion return $ Just actualPgVersion qInDbConfig :: IO () qInDbConfig = do AppConfig{..} <- getConfig appState when configDbConfig $ readInDbConfig False appState qSchemaCache :: IO (Maybe SchemaCache) qSchemaCache = do conf@AppConfig{..} <- getConfig appState (resultTime, result) <- let transaction = if configDbPreparedStatements then SQL.transaction else SQL.unpreparedTransaction in timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache conf) case result of Left e -> do putSCacheStatus appState SCPending putSchemaCache appState Nothing observer $ SchemaCacheErrorObs e return Nothing Right sCache -> do -- IMPORTANT: While the pending schema cache state starts from running the above querySchemaCache, only at this stage we block API requests due to the usage of an -- IORef on putSchemaCache. This is why SCacheStatus is put at SCPending here to signal the Admin server (using isPending) that we're on a recovery state. putSCacheStatus appState SCPending putSchemaCache appState $ Just sCache observer $ SchemaCacheQueriedObs resultTime (t, _) <- timeItT $ observer $ SchemaCacheSummaryObs $ showSummary sCache observer $ SchemaCacheLoadedObs t putSCacheStatus appState SCLoaded return $ Just sCache shouldRetry :: RetryStatus -> (Maybe PgVersion, Maybe SchemaCache) -> IO Bool shouldRetry _ (pgVer, sCache) = do AppConfig{..} <- getConfig appState let itShould = configDbPoolAutomaticRecovery && (isNothing pgVer || isNothing sCache) return itShould retryPolicy :: RetryPolicy retryPolicy = let delayMicroseconds = 32*oneSecondInUs {-32 seconds-} in capDelay delayMicroseconds $ exponentialBackoff oneSecondInUs oneSecondInUs = 1000000 -- one second in microseconds -- | Reads the in-db config and reads the config file again -- | We don't retry reading the in-db config after it fails immediately, because it could have user errors. We just report the error and continue. readInDbConfig :: Bool -> AppState -> IO () readInDbConfig startingUp appState@AppState{stateObserver=observer} = do AppConfig{..} <- getConfig appState pgVer <- getPgVersion appState dbSettings <- if configDbConfig then do qDbSettings <- usePool appState (queryDbSettings (dumpQi <$> configDbPreConfig) configDbPreparedStatements) case qDbSettings of Left e -> do observer $ ConfigReadErrorObs e pure mempty Right x -> pure x else pure mempty (roleSettings, roleIsolationLvl) <- if configDbConfig then do rSettings <- usePool appState (queryRoleSettings pgVer configDbPreparedStatements) case rSettings of Left e -> do observer $ QueryRoleSettingsErrorObs e pure (mempty, mempty) Right x -> pure x else pure mempty readAppConfig dbSettings configFilePath (Just configDbUri) roleSettings roleIsolationLvl >>= \case Left err -> if startingUp then panic err -- die on invalid config if the program is starting up else observer $ ConfigInvalidObs err Right newConf -> do putConfig appState newConf if startingUp then pass else observer ConfigSucceededObs