diff --git a/server/app/adapters/postgres.py b/server/app/adapters/postgres.py index ff1bd8b..fb9cc60 100644 --- a/server/app/adapters/postgres.py +++ b/server/app/adapters/postgres.py @@ -1,64 +1,48 @@ import logging -import time import psycopg2 -from psycopg2.extensions import connection from ports import DeviceRepository, ReadingRepository -from domain.exceptions import DatabaseConnectionError, DatabaseError +from domain.exceptions import DatabaseError +from infrastructure.db import get_conn log = logging.getLogger(__name__) -def connect(uri: str, retries: int = 10, base_delay: float = 1.0) -> connection: - for attempt in range(retries): - try: - conn = psycopg2.connect(uri) - conn.autocommit = True - log.info("PostgreSQL connecté") - return conn - except psycopg2.OperationalError as e: - delay = min(base_delay * 2 ** attempt, 30.0) - log.warning("Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e) - time.sleep(delay) - raise DatabaseConnectionError(f"Impossible de se connecter après {retries} tentatives") - class PgDeviceRepository(DeviceRepository): - def __init__(self, conn: connection): - self._conn = conn - def get_or_create_device_id(self, dev_eui: str) -> str: try: - with self._conn.cursor() as cur: - cur.execute( - """ - INSERT INTO device (device_eui) - VALUES (%s) - ON CONFLICT (device_eui) DO NOTHING - """, - (dev_eui,), - ) - cur.execute( - "SELECT device_id FROM device WHERE device_eui = %s", (dev_eui,) - ) - return str(cur.fetchone()[0]) # type: ignore + with get_conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO device (device_eui) + VALUES (%s) + ON CONFLICT (device_eui) DO NOTHING + """, + (dev_eui,), + ) + cur.execute( + "SELECT device_id FROM device WHERE device_eui = %s", (dev_eui,) + ) + return str(cur.fetchone()[0]) # type: ignore except psycopg2.DatabaseError as e: raise DatabaseError(f"Erreur de création du device {dev_eui}") from e class PgReadingRepository(ReadingRepository): - def __init__(self, conn: connection): - self._conn = conn - def insert_reading(self, device_id: str, pulse_count: int) -> None: try: - with self._conn.cursor() as cur: - cur.execute( - """ - INSERT INTO reading (device_id, date, pulses) - VALUES (%s, NOW(), %s) - """, - (device_id, pulse_count), - ) + with get_conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO reading (device_id, date, pulses) + VALUES (%s, NOW(), %s) + """, + (device_id, pulse_count), + ) except psycopg2.DatabaseError as e: - raise DatabaseError(f"Erreur d'enregistrement de la télérelève sur le device {device_id}") from e + raise DatabaseError( + f"Erreur d'enregistrement de la télérelève sur le device {device_id}" + ) from e diff --git a/server/app/adapters/postgres_query.py b/server/app/adapters/postgres_query.py index 01e8f0b..b1143b6 100644 --- a/server/app/adapters/postgres_query.py +++ b/server/app/adapters/postgres_query.py @@ -2,12 +2,12 @@ from datetime import datetime from dateutil.relativedelta import relativedelta import psycopg2 -from psycopg2.extensions import connection from domain.entities import ConsumptionPoint from domain.exceptions import DatabaseError from domain.value_objects import Granularity from ports.reading_query_repository import ReadingQueryRepository +from infrastructure.db import get_conn _GRANULARITY_DELTA = { @@ -19,9 +19,6 @@ _GRANULARITY_DELTA = { class PgReadingQueryRepository(ReadingQueryRepository): - def __init__(self, conn: connection) -> None: - self._conn = conn - def get_consumption( self, dev_eui: str, @@ -54,9 +51,10 @@ class PgReadingQueryRepository(ReadingQueryRepository): ORDER BY period ASC """ try: - with self._conn.cursor() as cur: - cur.execute(query, (date_trunc, dev_eui, adjusted_start, end)) - rows = cur.fetchall() + with get_conn() as conn: + with conn.cursor() as cur: + cur.execute(query, (date_trunc, dev_eui, adjusted_start, end)) + rows = cur.fetchall() except psycopg2.DatabaseError as e: raise DatabaseError(f"Erreur requête consumption : {e}") from e diff --git a/server/app/consumer.py b/server/app/consumer.py index 28ffcf9..8787644 100644 --- a/server/app/consumer.py +++ b/server/app/consumer.py @@ -1,10 +1,9 @@ from core.logging import setup_logging -from dependencies import get_conn, get_uplink_service, get_mqtt_broker +from dependencies import get_uplink_service, get_mqtt_broker setup_logging() if __name__ == "__main__": - conn = get_conn() broker = get_mqtt_broker() uplink = get_uplink_service() diff --git a/server/app/dependencies.py b/server/app/dependencies.py index ec4d0cc..b4160d5 100644 --- a/server/app/dependencies.py +++ b/server/app/dependencies.py @@ -1,7 +1,6 @@ import os -from functools import lru_cache -from adapters.postgres import connect, PgDeviceRepository, PgReadingRepository +from adapters.postgres import PgDeviceRepository, PgReadingRepository from adapters.postgres_query import PgReadingQueryRepository from adapters.mqtt import PahoMqttBroker from services.uplink_service import UplinkService @@ -10,21 +9,17 @@ from services.consumption_service import ConsumptionService MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto") MQTT_PORT = int(os.getenv("MQTT_PORT", 1883)) MQTT_TOPIC = os.getenv("MQTT_TOPIC", "application/+/device/+/event/up") -DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz") -@lru_cache -def get_conn(): - return connect(DB_URI) ## Repositories def get_device_repo() -> PgDeviceRepository: - return PgDeviceRepository(get_conn()) + return PgDeviceRepository() def get_reading_repo() -> PgReadingRepository: - return PgReadingRepository(get_conn()) + return PgReadingRepository() def get_query_repo() -> PgReadingQueryRepository: - return PgReadingQueryRepository(get_conn()) + return PgReadingQueryRepository() ## Services diff --git a/server/app/infrastructure/db.py b/server/app/infrastructure/db.py new file mode 100644 index 0000000..0f7b6d8 --- /dev/null +++ b/server/app/infrastructure/db.py @@ -0,0 +1,57 @@ +import os +import time +import logging +import psycopg2 +from typing import Generator, Any +from psycopg2.pool import ThreadedConnectionPool +from psycopg2.extensions import connection +from contextlib import contextmanager +from functools import lru_cache + +from domain.exceptions import DatabaseConnectionError + +log = logging.getLogger(__name__) + +_MIN_CONN = 1 +_MAX_CONN = 10 + +DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz") + + +@lru_cache(maxsize=1) +def get_pool() -> ThreadedConnectionPool: + return _create_pool(DB_URI) + + +def _create_pool( + uri: str, retries: int = 10, base_delay: float = 1.0 +) -> ThreadedConnectionPool: + for attempt in range(retries): + try: + pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, uri) + log.info( + "Pool PostgreSQL initialisé (%d–%d connexions)", _MIN_CONN, _MAX_CONN + ) + return pool + except psycopg2.OperationalError as e: + delay = min(base_delay * 2**attempt, 30.0) + log.warning( + "Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e + ) + time.sleep(delay) + raise DatabaseConnectionError( + f"Impossible de se connecter après {retries} tentatives" + ) + + +@contextmanager +def get_conn() -> Generator[connection, Any, None]: + pool = get_pool() + conn = pool.getconn() # type: ignore + try: + yield conn + except Exception: + conn.rollback() # type: ignore + raise + finally: + pool.putconn(conn) # type: ignore