feat: use connection pool instead of a single shared connection

This is useful for multitask environments (like the API), to avoid
conflicts
This commit is contained in:
Alexis Fourmaux 2026-05-10 14:55:50 +02:00
parent a22413bb0b
commit aa72971627
5 changed files with 95 additions and 62 deletions

View file

@ -1,64 +1,48 @@
import logging
import time
import psycopg2
from psycopg2.extensions import connection
from ports import DeviceRepository, ReadingRepository
from domain.exceptions import DatabaseConnectionError, DatabaseError
from domain.exceptions import DatabaseError
from infrastructure.db import get_conn
log = logging.getLogger(__name__)
def connect(uri: str, retries: int = 10, base_delay: float = 1.0) -> connection:
for attempt in range(retries):
try:
conn = psycopg2.connect(uri)
conn.autocommit = True
log.info("PostgreSQL connecté")
return conn
except psycopg2.OperationalError as e:
delay = min(base_delay * 2 ** attempt, 30.0)
log.warning("Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e)
time.sleep(delay)
raise DatabaseConnectionError(f"Impossible de se connecter après {retries} tentatives")
class PgDeviceRepository(DeviceRepository):
def __init__(self, conn: connection):
self._conn = conn
def get_or_create_device_id(self, dev_eui: str) -> str:
try:
with self._conn.cursor() as cur:
cur.execute(
"""
INSERT INTO device (device_eui)
VALUES (%s)
ON CONFLICT (device_eui) DO NOTHING
""",
(dev_eui,),
)
cur.execute(
"SELECT device_id FROM device WHERE device_eui = %s", (dev_eui,)
)
return str(cur.fetchone()[0]) # type: ignore
with get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO device (device_eui)
VALUES (%s)
ON CONFLICT (device_eui) DO NOTHING
""",
(dev_eui,),
)
cur.execute(
"SELECT device_id FROM device WHERE device_eui = %s", (dev_eui,)
)
return str(cur.fetchone()[0]) # type: ignore
except psycopg2.DatabaseError as e:
raise DatabaseError(f"Erreur de création du device {dev_eui}") from e
class PgReadingRepository(ReadingRepository):
def __init__(self, conn: connection):
self._conn = conn
def insert_reading(self, device_id: str, pulse_count: int) -> None:
try:
with self._conn.cursor() as cur:
cur.execute(
"""
INSERT INTO reading (device_id, date, pulses)
VALUES (%s, NOW(), %s)
""",
(device_id, pulse_count),
)
with get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO reading (device_id, date, pulses)
VALUES (%s, NOW(), %s)
""",
(device_id, pulse_count),
)
except psycopg2.DatabaseError as e:
raise DatabaseError(f"Erreur d'enregistrement de la télérelève sur le device {device_id}") from e
raise DatabaseError(
f"Erreur d'enregistrement de la télérelève sur le device {device_id}"
) from e

View file

@ -2,12 +2,12 @@ from datetime import datetime
from dateutil.relativedelta import relativedelta
import psycopg2
from psycopg2.extensions import connection
from domain.entities import ConsumptionPoint
from domain.exceptions import DatabaseError
from domain.value_objects import Granularity
from ports.reading_query_repository import ReadingQueryRepository
from infrastructure.db import get_conn
_GRANULARITY_DELTA = {
@ -19,9 +19,6 @@ _GRANULARITY_DELTA = {
class PgReadingQueryRepository(ReadingQueryRepository):
def __init__(self, conn: connection) -> None:
self._conn = conn
def get_consumption(
self,
dev_eui: str,
@ -54,9 +51,10 @@ class PgReadingQueryRepository(ReadingQueryRepository):
ORDER BY period ASC
"""
try:
with self._conn.cursor() as cur:
cur.execute(query, (date_trunc, dev_eui, adjusted_start, end))
rows = cur.fetchall()
with get_conn() as conn:
with conn.cursor() as cur:
cur.execute(query, (date_trunc, dev_eui, adjusted_start, end))
rows = cur.fetchall()
except psycopg2.DatabaseError as e:
raise DatabaseError(f"Erreur requête consumption : {e}") from e

View file

@ -1,10 +1,9 @@
from core.logging import setup_logging
from dependencies import get_conn, get_uplink_service, get_mqtt_broker
from dependencies import get_uplink_service, get_mqtt_broker
setup_logging()
if __name__ == "__main__":
conn = get_conn()
broker = get_mqtt_broker()
uplink = get_uplink_service()

View file

@ -1,7 +1,6 @@
import os
from functools import lru_cache
from adapters.postgres import connect, PgDeviceRepository, PgReadingRepository
from adapters.postgres import PgDeviceRepository, PgReadingRepository
from adapters.postgres_query import PgReadingQueryRepository
from adapters.mqtt import PahoMqttBroker
from services.uplink_service import UplinkService
@ -10,21 +9,17 @@ from services.consumption_service import ConsumptionService
MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto")
MQTT_PORT = int(os.getenv("MQTT_PORT", 1883))
MQTT_TOPIC = os.getenv("MQTT_TOPIC", "application/+/device/+/event/up")
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
@lru_cache
def get_conn():
return connect(DB_URI)
## Repositories
def get_device_repo() -> PgDeviceRepository:
return PgDeviceRepository(get_conn())
return PgDeviceRepository()
def get_reading_repo() -> PgReadingRepository:
return PgReadingRepository(get_conn())
return PgReadingRepository()
def get_query_repo() -> PgReadingQueryRepository:
return PgReadingQueryRepository(get_conn())
return PgReadingQueryRepository()
## Services

View file

@ -0,0 +1,57 @@
import os
import time
import logging
import psycopg2
from typing import Generator, Any
from psycopg2.pool import ThreadedConnectionPool
from psycopg2.extensions import connection
from contextlib import contextmanager
from functools import lru_cache
from domain.exceptions import DatabaseConnectionError
log = logging.getLogger(__name__)
_MIN_CONN = 1
_MAX_CONN = 10
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
@lru_cache(maxsize=1)
def get_pool() -> ThreadedConnectionPool:
return _create_pool(DB_URI)
def _create_pool(
uri: str, retries: int = 10, base_delay: float = 1.0
) -> ThreadedConnectionPool:
for attempt in range(retries):
try:
pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, uri)
log.info(
"Pool PostgreSQL initialisé (%d%d connexions)", _MIN_CONN, _MAX_CONN
)
return pool
except psycopg2.OperationalError as e:
delay = min(base_delay * 2**attempt, 30.0)
log.warning(
"Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e
)
time.sleep(delay)
raise DatabaseConnectionError(
f"Impossible de se connecter après {retries} tentatives"
)
@contextmanager
def get_conn() -> Generator[connection, Any, None]:
pool = get_pool()
conn = pool.getconn() # type: ignore
try:
yield conn
except Exception:
conn.rollback() # type: ignore
raise
finally:
pool.putconn(conn) # type: ignore