# O2Sound_ver2_final/backend/app/core/database.py
#
# 181 lines
# 5.0 KiB
# Python

from __future__ import annotations
from sqlalchemy import create_engine, Engine, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from app.shared.logger import setup_logger
from typing import Generator
from sqlalchemy.orm import declarative_base
import abc
from app.core.env_setting import EnvSetting
# Module-wide logger used by the session/engine helpers in this file.
logger = setup_logger(__name__)
# Declarative base for ORM models. NOTE: `declarative_base` here is the one from
# `sqlalchemy.orm` -- that later import shadows the deprecated
# `sqlalchemy.ext.declarative` import at the top of the file.
Base = declarative_base()
# Environment-driven settings (DB credentials, address, name, ENVIRONMENT and
# autocommit/autoflush flags) read by the helpers in this module.
settings = EnvSetting()
# def get_database_url() -> str:
# """PostgreSQL 데이터베이스 URL 생성"""
# return (
# f"postgresql://{settings.DATABASE_USERNAME}:"
# f"{settings.DATABASE_PASSWORD}@{settings.DATABASE_ADDRESS}/"
# f"{settings.DATABASE_NAME}"
# )
def create_engine_by_env():
    """Create a new SQLAlchemy engine for the env-configured database.

    Delegates to ``create_database_engine`` so that every engine built in this
    module shares one configuration: SQL echo only when
    ``settings.ENVIRONMENT == "local"``, plus connection pre-ping, recycling and
    pool sizing. Echo is intentionally not forced on here, because
    ``echo=True`` logs every statement together with its bound parameters,
    which is unwanted outside local development.

    Each call returns a distinct engine (with its own connection pool).
    """
    return create_database_engine()
# Lazily-created session factory shared by all callers of get_session_local().
_session_local_factory = None


def get_session_local():
    """Return the shared session factory bound to the env-configured engine.

    The engine and its sessionmaker are created on the first call and reused
    afterwards. ``get_db`` calls this once per session, so building a fresh
    engine each time would open a new connection pool (and a new database
    connection) for every request.

    Sessions produced by the factory use ``autocommit=False`` and
    ``autoflush=False``.
    """
    global _session_local_factory
    if _session_local_factory is None:
        # Two threads racing on the very first call may each build a factory;
        # one simply wins and the other is discarded. All later calls reuse it.
        engine = create_engine_by_env()
        _session_local_factory = sessionmaker(
            autocommit=False, autoflush=False, bind=engine
        )
    return _session_local_factory
# engine = create_engine(settings.DATABASE_URL)
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db() -> Generator[Session, None, None]:
    """Yield one database session and clean it up afterwards.

    Intended as a generator dependency (e.g. with FastAPI ``Depends``). The
    session comes from the factory returned by ``get_session_local()``.

    If an exception is thrown back into the generator while the session is in
    use, the session is rolled back and the error is logged with its traceback
    before being re-raised. The session is always closed, whether or not an
    error occurred.
    """
    session_factory = get_session_local()
    db = session_factory()
    logger.debug("DB session created")
    try:
        yield db
    except Exception as e:
        db.rollback()
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception(f"DB session error: {e}")
        raise
    finally:
        db.close()
        logger.debug("DB session closed")
def get_database_url() -> str:
    """Build the PostgreSQL connection URL from ``settings``.

    The username and password are percent-encoded, so characters that are
    meaningful inside a URL (``@``, ``:``, ``/``, ``%``, spaces, ...) cannot
    corrupt the host or database part. SQLAlchemy decodes them again when it
    parses the URL. ``DATABASE_ADDRESS`` (which may include ``host:port``) and
    ``DATABASE_NAME`` are inserted as-is.
    """
    from urllib.parse import quote

    # safe="" so that "/" is escaped too; str() mirrors the f-string formatting
    # used for the remaining fields.
    username = quote(str(settings.DATABASE_USERNAME), safe="")
    password = quote(str(settings.DATABASE_PASSWORD), safe="")
    return (
        f"postgresql://{username}:{password}"
        f"@{settings.DATABASE_ADDRESS}/{settings.DATABASE_NAME}"
    )
def create_database_engine():
    """Build the application's pooled SQLAlchemy engine.

    SQL statement echo is turned on only when ``settings.ENVIRONMENT`` equals
    ``"local"``. Connections are health-checked before use (pre-ping) and
    recycled after 3600 seconds. The pool holds 10 connections with up to 20
    extra overflow connections.
    """
    url = get_database_url()
    pool_options = dict(
        pool_pre_ping=True,   # test each connection before handing it out
        pool_recycle=3600,    # replace connections older than one hour
        pool_size=10,         # steady-state pool size
        max_overflow=20,      # additional connections allowed under load
    )
    return create_engine(
        url,
        echo=settings.ENVIRONMENT == "local",
        **pool_options,
    )
# Process-wide pooled engine built from the environment settings.
engine = create_database_engine()
# Session factory bound to that engine. Commit/flush behaviour comes from the
# settings; expire_on_commit=False keeps already-loaded attribute values
# available on instances after commit().
SessionLocal = sessionmaker(
    autocommit=settings.DATABASE_AUTO_COMMIT,
    autoflush=settings.DATABASE_AUTO_FLUSH,
    expire_on_commit=False,
    bind=engine,
)
# def get_db() -> Generator[Session, None, None]:
# """
# 데이터베이스 세션 의존성 주입 함수
# FastAPI Depends와 함께 사용
# """
# db = SessionLocal()
# try:
# yield db
# finally:
# db.close()
class DatabaseManager:
    """Static convenience helpers for sessions from the module-level ``SessionLocal``."""

    @staticmethod
    def get_session() -> Session:
        """Open and return a new session from ``SessionLocal``.

        Nothing here closes it; release it with ``close_session`` when done.
        """
        session_factory = SessionLocal
        return session_factory()

    @staticmethod
    def close_session(db: Session) -> None:
        """Close *db*."""
        db.close()

    @staticmethod
    def commit_and_refresh(db: Session, instance) -> None:
        """Commit *db*, then refresh *instance* from the database.

        If ``commit`` raises, ``refresh`` is skipped and no rollback is attempted
        here; handling the failure is left to the caller.
        """
        db.commit()
        db.refresh(instance)

    @staticmethod
    def rollback(db: Session) -> None:
        """Roll back the current transaction of *db*."""
        db.rollback()
def create_session_maker(engine: Engine, settings):
    """Return a sessionmaker bound to *engine*.

    The autocommit and autoflush flags are taken from
    ``settings.DATABASE_AUTO_COMMIT`` and ``settings.DATABASE_AUTO_FLUSH``.
    """
    options = {
        "autocommit": settings.DATABASE_AUTO_COMMIT,
        "autoflush": settings.DATABASE_AUTO_FLUSH,
    }
    return sessionmaker(bind=engine, **options)
def create_persistence_by_env():
    """Build a fresh engine plus a session factory bound to it.

    Returns:
        A ``(engine, session_maker)`` tuple. The session maker is configured
        from the module-level ``settings``.
    """
    new_engine = create_engine_by_env()
    return new_engine, create_session_maker(new_engine, settings)
def create_tables():
    """Create every table registered on ``Base.metadata``.

    A dedicated engine is created for the operation and disposed afterwards,
    so its pooled connection does not stay open. Any failure, including
    failure to build the engine, is logged with its traceback and re-raised.
    """
    engine = None
    try:
        engine = create_engine_by_env()
        Base.metadata.create_all(bind=engine)
        logger.info("✅ All tables created successfully.")
    except Exception as e:
        # logger.exception records the traceback, which logger.error would drop.
        logger.exception(f"❌ Failed to create tables: {e}")
        raise
    finally:
        if engine is not None:
            # This engine is single-use; close its connection pool.
            engine.dispose()
# # 🔥 FK 무시하고 안전하게 전체 테이블 삭제 (PostgreSQL 전용)
# def drop_tables():
# try:
# engine = create_engine_by_env()
# with engine.connect() as conn:
# conn.execute(text("SET session_replication_role = replica;"))
# BaseTable.metadata.drop_all(bind=engine)
# conn.execute(text("SET session_replication_role = DEFAULT;"))
# conn.commit()
# logger.warning("⚠️ All tables dropped successfully with FK constraints disabled temporarily.")
# except Exception as e:
# logger.error(f"❌ Failed to drop tables: {str(e)}")
# raise
# # =======
class AbstractUnitOfWork(abc.ABC):
    """Abstract base for unit-of-work implementations.

    Usable as a context manager. Entering returns the unit of work itself.
    Leaving always calls ``rollback()``, even after a successful ``commit()``
    and whether or not an exception occurred. Exceptions raised inside the
    ``with`` block are not suppressed.

    Subclasses must implement ``commit``, ``flush`` and ``rollback``.
    """

    def __enter__(self) -> AbstractUnitOfWork:
        return self

    def __exit__(self, *exc_info):
        # Unconditional rollback; the return value is None, so exceptions propagate.
        self.rollback()

    @abc.abstractmethod
    def commit(self):
        """Persist the pending work."""
        raise NotImplementedError

    @abc.abstractmethod
    def flush(self):
        """Send pending changes without committing them."""
        raise NotImplementedError

    @abc.abstractmethod
    def rollback(self):
        """Discard the pending work."""
        raise NotImplementedError