This is useful for multitask environments (like the API), to avoid conflicts
57 lines
1.5 KiB
Python
57 lines
1.5 KiB
Python
import os
|
||
import time
|
||
import logging
|
||
import psycopg2
|
||
from typing import Generator, Any
|
||
from psycopg2.pool import ThreadedConnectionPool
|
||
from psycopg2.extensions import connection
|
||
from contextlib import contextmanager
|
||
from functools import lru_cache
|
||
|
||
from domain.exceptions import DatabaseConnectionError
|
||
|
||
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Lower and upper bounds handed to ThreadedConnectionPool when the pool is built.
_MIN_CONN = 1
_MAX_CONN = 10

# Connection URI, taken from the DATABASE_URI environment variable when set.
# NOTE(review): the fallback embeds credentials — presumably meant for a
# local/dev setup; confirm it is never relied on in production.
DB_URI = os.environ.get(
    "DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz"
)
|
||
|
||
|
||
@lru_cache(maxsize=1)
def get_pool() -> ThreadedConnectionPool:
    """Return the shared connection pool, building it on the first call.

    The result is memoized by ``lru_cache``, so later calls return the same
    pool object instead of opening a new one.

    Returns:
        The pool created by ``_create_pool`` for ``DB_URI``.
    """
    # NOTE(review): lru_cache does not serialize concurrent first calls, so two
    # threads racing here could each build a pool — confirm the pool is warmed
    # up before worker threads start.
    shared_pool = _create_pool(uri=DB_URI)
    return shared_pool
|
||
|
||
|
||
def _create_pool(
    uri: str, retries: int = 10, base_delay: float = 1.0
) -> ThreadedConnectionPool:
    """Build the PostgreSQL connection pool, retrying while the server is unreachable.

    Each failed attempt is logged as a warning. Between attempts the function
    sleeps for ``base_delay * 2**attempt`` seconds, capped at 30 seconds.

    Args:
        uri: Connection URI passed to ``ThreadedConnectionPool``.
        retries: Maximum number of connection attempts.
        base_delay: Delay in seconds before the second attempt; doubled after
            every further failure.

    Returns:
        A pool sized between ``_MIN_CONN`` and ``_MAX_CONN`` connections.

    Raises:
        DatabaseConnectionError: If every attempt fails with
            ``psycopg2.OperationalError``. The last such error is attached as
            the cause.
    """
    last_error = None
    for attempt in range(retries):
        try:
            pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, uri)
        except psycopg2.OperationalError as e:
            last_error = e
            log.warning(
                "Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e
            )
            # Sleeping after the final attempt would only delay the error
            # below, so wait only when another attempt will follow.
            if attempt + 1 < retries:
                time.sleep(min(base_delay * 2**attempt, 30.0))
        else:
            log.info(
                "Pool PostgreSQL initialisé (%d–%d connexions)", _MIN_CONN, _MAX_CONN
            )
            return pool
    # Chain the last driver error so the root cause shows up in the traceback.
    raise DatabaseConnectionError(
        f"Impossible de se connecter après {retries} tentatives"
    ) from last_error
|
||
|
||
|
||
@contextmanager
def get_conn() -> Generator[connection, Any, None]:
    """Borrow a connection from the shared pool for the duration of a ``with`` block.

    The connection is always handed back to the pool on exit. If the body of
    the ``with`` block raises an ``Exception``, the current transaction is
    rolled back before that exception propagates. This helper never commits;
    callers that want their changes persisted must commit themselves.

    Yields:
        The connection obtained from the pool.
    """
    pool = get_pool()
    conn = pool.getconn()  # type: ignore
    try:
        yield conn
    except Exception:
        try:
            conn.rollback()  # type: ignore
        except psycopg2.Error:
            # A failed rollback (e.g. the connection was lost) must not
            # replace the error that brought us here; log it and carry on.
            log.exception("Échec du rollback de la connexion")
        raise
    finally:
        pool.putconn(conn)  # type: ignore
|