feat: use connection pool instead of a single shared connection
This is useful for multitask environments (like the API), to avoid conflicts
This commit is contained in:
parent
a22413bb0b
commit
aa72971627
5 changed files with 95 additions and 62 deletions
|
|
@ -1,35 +1,19 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from psycopg2.extensions import connection
|
|
||||||
|
|
||||||
from ports import DeviceRepository, ReadingRepository
|
from ports import DeviceRepository, ReadingRepository
|
||||||
from domain.exceptions import DatabaseConnectionError, DatabaseError
|
from domain.exceptions import DatabaseError
|
||||||
|
from infrastructure.db import get_conn
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
def connect(uri: str, retries: int = 10, base_delay: float = 1.0) -> connection:
|
|
||||||
for attempt in range(retries):
|
|
||||||
try:
|
|
||||||
conn = psycopg2.connect(uri)
|
|
||||||
conn.autocommit = True
|
|
||||||
log.info("PostgreSQL connecté")
|
|
||||||
return conn
|
|
||||||
except psycopg2.OperationalError as e:
|
|
||||||
delay = min(base_delay * 2 ** attempt, 30.0)
|
|
||||||
log.warning("Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e)
|
|
||||||
time.sleep(delay)
|
|
||||||
raise DatabaseConnectionError(f"Impossible de se connecter après {retries} tentatives")
|
|
||||||
|
|
||||||
|
|
||||||
class PgDeviceRepository(DeviceRepository):
|
class PgDeviceRepository(DeviceRepository):
|
||||||
def __init__(self, conn: connection):
|
|
||||||
self._conn = conn
|
|
||||||
|
|
||||||
def get_or_create_device_id(self, dev_eui: str) -> str:
|
def get_or_create_device_id(self, dev_eui: str) -> str:
|
||||||
try:
|
try:
|
||||||
with self._conn.cursor() as cur:
|
with get_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO device (device_eui)
|
INSERT INTO device (device_eui)
|
||||||
|
|
@ -47,12 +31,10 @@ class PgDeviceRepository(DeviceRepository):
|
||||||
|
|
||||||
|
|
||||||
class PgReadingRepository(ReadingRepository):
|
class PgReadingRepository(ReadingRepository):
|
||||||
def __init__(self, conn: connection):
|
|
||||||
self._conn = conn
|
|
||||||
|
|
||||||
def insert_reading(self, device_id: str, pulse_count: int) -> None:
|
def insert_reading(self, device_id: str, pulse_count: int) -> None:
|
||||||
try:
|
try:
|
||||||
with self._conn.cursor() as cur:
|
with get_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO reading (device_id, date, pulses)
|
INSERT INTO reading (device_id, date, pulses)
|
||||||
|
|
@ -61,4 +43,6 @@ class PgReadingRepository(ReadingRepository):
|
||||||
(device_id, pulse_count),
|
(device_id, pulse_count),
|
||||||
)
|
)
|
||||||
except psycopg2.DatabaseError as e:
|
except psycopg2.DatabaseError as e:
|
||||||
raise DatabaseError(f"Erreur d'enregistrement de la télérelève sur le device {device_id}") from e
|
raise DatabaseError(
|
||||||
|
f"Erreur d'enregistrement de la télérelève sur le device {device_id}"
|
||||||
|
) from e
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@ from datetime import datetime
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from psycopg2.extensions import connection
|
|
||||||
|
|
||||||
from domain.entities import ConsumptionPoint
|
from domain.entities import ConsumptionPoint
|
||||||
from domain.exceptions import DatabaseError
|
from domain.exceptions import DatabaseError
|
||||||
from domain.value_objects import Granularity
|
from domain.value_objects import Granularity
|
||||||
from ports.reading_query_repository import ReadingQueryRepository
|
from ports.reading_query_repository import ReadingQueryRepository
|
||||||
|
from infrastructure.db import get_conn
|
||||||
|
|
||||||
|
|
||||||
_GRANULARITY_DELTA = {
|
_GRANULARITY_DELTA = {
|
||||||
|
|
@ -19,9 +19,6 @@ _GRANULARITY_DELTA = {
|
||||||
|
|
||||||
|
|
||||||
class PgReadingQueryRepository(ReadingQueryRepository):
|
class PgReadingQueryRepository(ReadingQueryRepository):
|
||||||
def __init__(self, conn: connection) -> None:
|
|
||||||
self._conn = conn
|
|
||||||
|
|
||||||
def get_consumption(
|
def get_consumption(
|
||||||
self,
|
self,
|
||||||
dev_eui: str,
|
dev_eui: str,
|
||||||
|
|
@ -54,7 +51,8 @@ class PgReadingQueryRepository(ReadingQueryRepository):
|
||||||
ORDER BY period ASC
|
ORDER BY period ASC
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with self._conn.cursor() as cur:
|
with get_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
cur.execute(query, (date_trunc, dev_eui, adjusted_start, end))
|
cur.execute(query, (date_trunc, dev_eui, adjusted_start, end))
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
except psycopg2.DatabaseError as e:
|
except psycopg2.DatabaseError as e:
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,9 @@
|
||||||
from core.logging import setup_logging
|
from core.logging import setup_logging
|
||||||
from dependencies import get_conn, get_uplink_service, get_mqtt_broker
|
from dependencies import get_uplink_service, get_mqtt_broker
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
conn = get_conn()
|
|
||||||
broker = get_mqtt_broker()
|
broker = get_mqtt_broker()
|
||||||
uplink = get_uplink_service()
|
uplink = get_uplink_service()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
from adapters.postgres import connect, PgDeviceRepository, PgReadingRepository
|
from adapters.postgres import PgDeviceRepository, PgReadingRepository
|
||||||
from adapters.postgres_query import PgReadingQueryRepository
|
from adapters.postgres_query import PgReadingQueryRepository
|
||||||
from adapters.mqtt import PahoMqttBroker
|
from adapters.mqtt import PahoMqttBroker
|
||||||
from services.uplink_service import UplinkService
|
from services.uplink_service import UplinkService
|
||||||
|
|
@ -10,21 +9,17 @@ from services.consumption_service import ConsumptionService
|
||||||
MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto")
|
MQTT_HOST = os.getenv("MQTT_HOST", "mosquitto")
|
||||||
MQTT_PORT = int(os.getenv("MQTT_PORT", 1883))
|
MQTT_PORT = int(os.getenv("MQTT_PORT", 1883))
|
||||||
MQTT_TOPIC = os.getenv("MQTT_TOPIC", "application/+/device/+/event/up")
|
MQTT_TOPIC = os.getenv("MQTT_TOPIC", "application/+/device/+/event/up")
|
||||||
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
|
|
||||||
|
|
||||||
@lru_cache
|
|
||||||
def get_conn():
|
|
||||||
return connect(DB_URI)
|
|
||||||
|
|
||||||
## Repositories
|
## Repositories
|
||||||
def get_device_repo() -> PgDeviceRepository:
|
def get_device_repo() -> PgDeviceRepository:
|
||||||
return PgDeviceRepository(get_conn())
|
return PgDeviceRepository()
|
||||||
|
|
||||||
def get_reading_repo() -> PgReadingRepository:
|
def get_reading_repo() -> PgReadingRepository:
|
||||||
return PgReadingRepository(get_conn())
|
return PgReadingRepository()
|
||||||
|
|
||||||
def get_query_repo() -> PgReadingQueryRepository:
|
def get_query_repo() -> PgReadingQueryRepository:
|
||||||
return PgReadingQueryRepository(get_conn())
|
return PgReadingQueryRepository()
|
||||||
|
|
||||||
|
|
||||||
## Services
|
## Services
|
||||||
|
|
|
||||||
57
server/app/infrastructure/db.py
Normal file
57
server/app/infrastructure/db.py
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import psycopg2
|
||||||
|
from typing import Generator, Any
|
||||||
|
from psycopg2.pool import ThreadedConnectionPool
|
||||||
|
from psycopg2.extensions import connection
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from domain.exceptions import DatabaseConnectionError
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MIN_CONN = 1
|
||||||
|
_MAX_CONN = 10
|
||||||
|
|
||||||
|
DB_URI = os.getenv("DATABASE_URI", "postgresql://simugaz:simugaz@db/simugaz")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_pool() -> ThreadedConnectionPool:
|
||||||
|
return _create_pool(DB_URI)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_pool(
|
||||||
|
uri: str, retries: int = 10, base_delay: float = 1.0
|
||||||
|
) -> ThreadedConnectionPool:
|
||||||
|
for attempt in range(retries):
|
||||||
|
try:
|
||||||
|
pool = ThreadedConnectionPool(_MIN_CONN, _MAX_CONN, uri)
|
||||||
|
log.info(
|
||||||
|
"Pool PostgreSQL initialisé (%d–%d connexions)", _MIN_CONN, _MAX_CONN
|
||||||
|
)
|
||||||
|
return pool
|
||||||
|
except psycopg2.OperationalError as e:
|
||||||
|
delay = min(base_delay * 2**attempt, 30.0)
|
||||||
|
log.warning(
|
||||||
|
"Attente PostgreSQL (tentative %d/%d) : %s", attempt + 1, retries, e
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
raise DatabaseConnectionError(
|
||||||
|
f"Impossible de se connecter après {retries} tentatives"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_conn() -> Generator[connection, Any, None]:
|
||||||
|
pool = get_pool()
|
||||||
|
conn = pool.getconn() # type: ignore
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
except Exception:
|
||||||
|
conn.rollback() # type: ignore
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
pool.putconn(conn) # type: ignore
|
||||||
Loading…
Add table
Add a link
Reference in a new issue