feat: use connection pool instead of a single shared connection
This is useful for multitask environments (like the API), to avoid conflicts
This commit is contained in:
parent
a22413bb0b
commit
aa72971627
5 changed files with 95 additions and 62 deletions
|
|
@ -1,35 +1,19 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extensions import connection
|
||||
|
||||
from ports import DeviceRepository, ReadingRepository
|
||||
from domain.exceptions import DatabaseConnectionError, DatabaseError
|
||||
from domain.exceptions import DatabaseError
|
||||
from infrastructure.db import get_conn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
def connect(uri: str, retries: int = 10, base_delay: float = 1.0) -> connection:
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
conn = psycopg2.connect(uri)
|
||||
conn.autocommit = True
|
||||
log.info("PostgreSQL connecté")
|
||||
return conn
|
||||
except psycopg2.OperationalError as e:
|
||||
delay = min(base_delay * 2 ** attempt, 30.0)
|
||||
log.warning("Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e)
|
||||
time.sleep(delay)
|
||||
raise DatabaseConnectionError(f"Impossible de se connecter après {retries} tentatives")
|
||||
|
||||
|
||||
class PgDeviceRepository(DeviceRepository):
|
||||
def __init__(self, conn: connection):
|
||||
self._conn = conn
|
||||
|
||||
def get_or_create_device_id(self, dev_eui: str) -> str:
|
||||
try:
|
||||
with self._conn.cursor() as cur:
|
||||
with get_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device (device_eui)
|
||||
|
|
@ -47,12 +31,10 @@ class PgDeviceRepository(DeviceRepository):
|
|||
|
||||
|
||||
class PgReadingRepository(ReadingRepository):
|
||||
def __init__(self, conn: connection):
|
||||
self._conn = conn
|
||||
|
||||
def insert_reading(self, device_id: str, pulse_count: int) -> None:
|
||||
try:
|
||||
with self._conn.cursor() as cur:
|
||||
with get_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO reading (device_id, date, pulses)
|
||||
|
|
@ -61,4 +43,6 @@ class PgReadingRepository(ReadingRepository):
|
|||
(device_id, pulse_count),
|
||||
)
|
||||
except psycopg2.DatabaseError as e:
|
||||
raise DatabaseError(f"Erreur d'enregistrement de la télérelève sur le device {device_id}") from e
|
||||
raise DatabaseError(
|
||||
f"Erreur d'enregistrement de la télérelève sur le device {device_id}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ from datetime import datetime
|
|||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extensions import connection
|
||||
|
||||
from domain.entities import ConsumptionPoint
|
||||
from domain.exceptions import DatabaseError
|
||||
from domain.value_objects import Granularity
|
||||
from ports.reading_query_repository import ReadingQueryRepository
|
||||
from infrastructure.db import get_conn
|
||||
|
||||
|
||||
_GRANULARITY_DELTA = {
|
||||
|
|
@ -19,9 +19,6 @@ _GRANULARITY_DELTA = {
|
|||
|
||||
|
||||
class PgReadingQueryRepository(ReadingQueryRepository):
|
||||
def __init__(self, conn: connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def get_consumption(
|
||||
self,
|
||||
dev_eui: str,
|
||||
|
|
@ -54,7 +51,8 @@ class PgReadingQueryRepository(ReadingQueryRepository):
|
|||
ORDER BY period ASC
|
||||
"""
|
||||
try:
|
||||
with self._conn.cursor() as cur:
|
||||
with get_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(query, (date_trunc, dev_eui, adjusted_start, end))
|
||||
rows = cur.fetchall()
|
||||
except psycopg2.DatabaseError as e:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from core.logging import setup_logging
|
||||
from dependencies import get_conn, get_uplink_service, get_mqtt_broker
|
||||
from dependencies import get_uplink_service, get_mqtt_broker
|
||||
|
||||
setup_logging()
|
||||
|
||||
if __name__ == "__main__":
|
||||
conn = get_conn()
|
||||
broker = get_mqtt_broker()
|
||||
uplink = get_uplink_service()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from adapters.postgres import connect, PgDeviceRepository, PgReadingRepository
|
||||
from adapters.postgres import PgDeviceRepository, PgReadingRepository
|
||||
from adapters.postgres_query import PgReadingQueryRepository
|
||||
from adapters.mqtt import PahoMqttBroker
|
||||
from services.uplink_service import UplinkService
|
||||
|
|
@ -10,21 +9,17 @@ from services.consumption_service import ConsumptionService
|
|||
MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto")
|
||||
MQTT_PORT = int(os.getenv("MQTT_PORT", 1883))
|
||||
MQTT_TOPIC = os.getenv("MQTT_TOPIC", "application/+/device/+/event/up")
|
||||
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
|
||||
|
||||
@lru_cache
|
||||
def get_conn():
|
||||
return connect(DB_URI)
|
||||
|
||||
## Repositories
|
||||
def get_device_repo() -> PgDeviceRepository:
|
||||
return PgDeviceRepository(get_conn())
|
||||
return PgDeviceRepository()
|
||||
|
||||
def get_reading_repo() -> PgReadingRepository:
|
||||
return PgReadingRepository(get_conn())
|
||||
return PgReadingRepository()
|
||||
|
||||
def get_query_repo() -> PgReadingQueryRepository:
|
||||
return PgReadingQueryRepository(get_conn())
|
||||
return PgReadingQueryRepository()
|
||||
|
||||
|
||||
## Services
|
||||
|
|
|
|||
57
server/app/infrastructure/db.py
Normal file
57
server/app/infrastructure/db.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
import time
|
||||
import logging
|
||||
import psycopg2
|
||||
from typing import Generator, Any
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
from psycopg2.extensions import connection
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
|
||||
from domain.exceptions import DatabaseConnectionError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_MIN_CONN = 1
|
||||
_MAX_CONN = 10
|
||||
|
||||
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_pool() -> ThreadedConnectionPool:
|
||||
return _create_pool(DB_URI)
|
||||
|
||||
|
||||
def _create_pool(
|
||||
uri: str, retries: int = 10, base_delay: float = 1.0
|
||||
) -> ThreadedConnectionPool:
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, uri)
|
||||
log.info(
|
||||
"Pool PostgreSQL initialisé (%d–%d connexions)", _MIN_CONN, _MAX_CONN
|
||||
)
|
||||
return pool
|
||||
except psycopg2.OperationalError as e:
|
||||
delay = min(base_delay * 2**attempt, 30.0)
|
||||
log.warning(
|
||||
"Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e
|
||||
)
|
||||
time.sleep(delay)
|
||||
raise DatabaseConnectionError(
|
||||
f"Impossible de se connecter après {retries} tentatives"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_conn() -> Generator[connection, Any, None]:
|
||||
pool = get_pool()
|
||||
conn = pool.getconn() # type: ignore
|
||||
try:
|
||||
yield conn
|
||||
except Exception:
|
||||
conn.rollback() # type: ignore
|
||||
raise
|
||||
finally:
|
||||
pool.putconn(conn) # type: ignore
|
||||
Loading…
Add table
Add a link
Reference in a new issue