feat: use connection pool instead of a single shared connection

This is useful for multitask environments (like the API), to avoid
conflicts
This commit is contained in:
Alexis Fourmaux 2026-05-10 14:55:50 +02:00
parent a22413bb0b
commit aa72971627
5 changed files with 95 additions and 62 deletions

View file

@ -1,64 +1,48 @@
import logging import logging
import time
import psycopg2 import psycopg2
from psycopg2.extensions import connection
from ports import DeviceRepository, ReadingRepository from ports import DeviceRepository, ReadingRepository
from domain.exceptions import DatabaseConnectionError, DatabaseError from domain.exceptions import DatabaseError
from infrastructure.db import get_conn
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def connect(uri: str, retries: int = 10, base_delay: float = 1.0) -> connection:
for attempt in range(retries):
try:
conn = psycopg2.connect(uri)
conn.autocommit = True
log.info("PostgreSQL connecté")
return conn
except psycopg2.OperationalError as e:
delay = min(base_delay * 2 ** attempt, 30.0)
log.warning("Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e)
time.sleep(delay)
raise DatabaseConnectionError(f"Impossible de se connecter après {retries} tentatives")
class PgDeviceRepository(DeviceRepository): class PgDeviceRepository(DeviceRepository):
def __init__(self, conn: connection):
self._conn = conn
def get_or_create_device_id(self, dev_eui: str) -> str: def get_or_create_device_id(self, dev_eui: str) -> str:
try: try:
with self._conn.cursor() as cur: with get_conn() as conn:
cur.execute( with conn.cursor() as cur:
""" cur.execute(
INSERT INTO device (device_eui) """
VALUES (%s) INSERT INTO device (device_eui)
ON CONFLICT (device_eui) DO NOTHING VALUES (%s)
""", ON CONFLICT (device_eui) DO NOTHING
(dev_eui,), """,
) (dev_eui,),
cur.execute( )
"SELECT device_id FROM device WHERE device_eui = %s", (dev_eui,) cur.execute(
) "SELECT device_id FROM device WHERE device_eui = %s", (dev_eui,)
return str(cur.fetchone()[0]) # type: ignore )
return str(cur.fetchone()[0]) # type: ignore
except psycopg2.DatabaseError as e: except psycopg2.DatabaseError as e:
raise DatabaseError(f"Erreur de création du device {dev_eui}") from e raise DatabaseError(f"Erreur de création du device {dev_eui}") from e
class PgReadingRepository(ReadingRepository): class PgReadingRepository(ReadingRepository):
def __init__(self, conn: connection):
self._conn = conn
def insert_reading(self, device_id: str, pulse_count: int) -> None: def insert_reading(self, device_id: str, pulse_count: int) -> None:
try: try:
with self._conn.cursor() as cur: with get_conn() as conn:
cur.execute( with conn.cursor() as cur:
""" cur.execute(
INSERT INTO reading (device_id, date, pulses) """
VALUES (%s, NOW(), %s) INSERT INTO reading (device_id, date, pulses)
""", VALUES (%s, NOW(), %s)
(device_id, pulse_count), """,
) (device_id, pulse_count),
)
except psycopg2.DatabaseError as e: except psycopg2.DatabaseError as e:
raise DatabaseError(f"Erreur d'enregistrement de la télérelève sur le device {device_id}") from e raise DatabaseError(
f"Erreur d'enregistrement de la télérelève sur le device {device_id}"
) from e

View file

@ -2,12 +2,12 @@ from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
import psycopg2 import psycopg2
from psycopg2.extensions import connection
from domain.entities import ConsumptionPoint from domain.entities import ConsumptionPoint
from domain.exceptions import DatabaseError from domain.exceptions import DatabaseError
from domain.value_objects import Granularity from domain.value_objects import Granularity
from ports.reading_query_repository import ReadingQueryRepository from ports.reading_query_repository import ReadingQueryRepository
from infrastructure.db import get_conn
_GRANULARITY_DELTA = { _GRANULARITY_DELTA = {
@ -19,9 +19,6 @@ _GRANULARITY_DELTA = {
class PgReadingQueryRepository(ReadingQueryRepository): class PgReadingQueryRepository(ReadingQueryRepository):
def __init__(self, conn: connection) -> None:
self._conn = conn
def get_consumption( def get_consumption(
self, self,
dev_eui: str, dev_eui: str,
@ -54,9 +51,10 @@ class PgReadingQueryRepository(ReadingQueryRepository):
ORDER BY period ASC ORDER BY period ASC
""" """
try: try:
with self._conn.cursor() as cur: with get_conn() as conn:
cur.execute(query, (date_trunc, dev_eui, adjusted_start, end)) with conn.cursor() as cur:
rows = cur.fetchall() cur.execute(query, (date_trunc, dev_eui, adjusted_start, end))
rows = cur.fetchall()
except psycopg2.DatabaseError as e: except psycopg2.DatabaseError as e:
raise DatabaseError(f"Erreur requête consumption : {e}") from e raise DatabaseError(f"Erreur requête consumption : {e}") from e

View file

@ -1,10 +1,9 @@
from core.logging import setup_logging from core.logging import setup_logging
from dependencies import get_conn, get_uplink_service, get_mqtt_broker from dependencies import get_uplink_service, get_mqtt_broker
setup_logging() setup_logging()
if __name__ == "__main__": if __name__ == "__main__":
conn = get_conn()
broker = get_mqtt_broker() broker = get_mqtt_broker()
uplink = get_uplink_service() uplink = get_uplink_service()

View file

@ -1,7 +1,6 @@
import os import os
from functools import lru_cache
from adapters.postgres import connect, PgDeviceRepository, PgReadingRepository from adapters.postgres import PgDeviceRepository, PgReadingRepository
from adapters.postgres_query import PgReadingQueryRepository from adapters.postgres_query import PgReadingQueryRepository
from adapters.mqtt import PahoMqttBroker from adapters.mqtt import PahoMqttBroker
from services.uplink_service import UplinkService from services.uplink_service import UplinkService
@ -10,21 +9,17 @@ from services.consumption_service import ConsumptionService
MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto") MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto")
MQTT_PORT = int(os.getenv("MQTT_PORT", 1883)) MQTT_PORT = int(os.getenv("MQTT_PORT", 1883))
MQTT_TOPIC = os.getenv("MQTT_TOPIC", "application/+/device/+/event/up") MQTT_TOPIC = os.getenv("MQTT_TOPIC", "application/+/device/+/event/up")
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
@lru_cache
def get_conn():
return connect(DB_URI)
## Repositories ## Repositories
def get_device_repo() -> PgDeviceRepository: def get_device_repo() -> PgDeviceRepository:
return PgDeviceRepository(get_conn()) return PgDeviceRepository()
def get_reading_repo() -> PgReadingRepository: def get_reading_repo() -> PgReadingRepository:
return PgReadingRepository(get_conn()) return PgReadingRepository()
def get_query_repo() -> PgReadingQueryRepository: def get_query_repo() -> PgReadingQueryRepository:
return PgReadingQueryRepository(get_conn()) return PgReadingQueryRepository()
## Services ## Services

View file

@ -0,0 +1,57 @@
import os
import time
import logging
import psycopg2
from typing import Generator, Any
from psycopg2.pool import ThreadedConnectionPool
from psycopg2.extensions import connection
from contextlib import contextmanager
from functools import lru_cache
from domain.exceptions import DatabaseConnectionError
log = logging.getLogger(__name__)
_MIN_CONN = 1
_MAX_CONN = 10
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
@lru_cache(maxsize=1)
def get_pool() -> ThreadedConnectionPool:
return _create_pool(DB_URI)
def _create_pool(
uri: str, retries: int = 10, base_delay: float = 1.0
) -> ThreadedConnectionPool:
for attempt in range(retries):
try:
pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, uri)
log.info(
"Pool PostgreSQL initialisé (%d%d connexions)", _MIN_CONN, _MAX_CONN
)
return pool
except psycopg2.OperationalError as e:
delay = min(base_delay * 2**attempt, 30.0)
log.warning(
"Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e
)
time.sleep(delay)
raise DatabaseConnectionError(
f"Impossible de se connecter après {retries} tentatives"
)
@contextmanager
def get_conn() -> Generator[connection, Any, None]:
pool = get_pool()
conn = pool.getconn() # type: ignore
try:
yield conn
except Exception:
conn.rollback() # type: ignore
raise
finally:
pool.putconn(conn) # type: ignore