181 lines
5.0 KiB
Python
181 lines
5.0 KiB
Python
from __future__ import annotations

import abc
from typing import Generator
from urllib.parse import quote

from sqlalchemy import create_engine, Engine, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.orm import declarative_base

from app.shared.logger import setup_logger
from app.core.env_setting import EnvSetting
|
|
|
|
|
|
# Logger for this module.
logger = setup_logger(__name__)

# Declarative base whose metadata (Base.metadata) create_tables() builds from.
Base = declarative_base()

# Environment-driven settings: database credentials, session flags, environment name.
settings: EnvSetting = EnvSetting()
|
|
|
|
|
|
# def get_database_url() -> str:
|
|
# """PostgreSQL 데이터베이스 URL 생성"""
|
|
# return (
|
|
# f"postgresql://{settings.DATABASE_USERNAME}:"
|
|
# f"{settings.DATABASE_PASSWORD}@{settings.DATABASE_ADDRESS}/"
|
|
# f"{settings.DATABASE_NAME}"
|
|
# )
|
|
|
|
|
|
|
|
# Engine returned by create_engine_by_env(); built lazily on first use.
_env_engine: Engine | None = None


def create_engine_by_env() -> Engine:
    """Return the engine for the environment's database URL, creating it once.

    Previously a fresh ``Engine`` was built on every call. Each engine owns its
    own connection pool, and nothing here disposes those pools, so per-request
    callers (via ``get_session_local``) accumulated pools and connections. The
    engine is now created on the first call and reused by later calls.

    Lazy initialisation is not locked: concurrent first calls may each build an
    engine, and only the last one assigned is kept for subsequent calls.

    NOTE(review): ``echo=True`` is unconditional here, whereas
    ``create_database_engine`` only echoes when ``ENVIRONMENT == "local"``.

    Returns:
        The shared SQLAlchemy ``Engine``.
    """
    global _env_engine
    if _env_engine is None:
        # echo=True logs every emitted SQL statement (kept for debugging).
        _env_engine = create_engine(get_database_url(), echo=True)
    return _env_engine
|
|
|
|
|
|
def get_session_local():
    """Build a session factory bound to the engine from ``create_engine_by_env()``.

    Sessions produced by the returned factory have autocommit and autoflush
    disabled.
    """
    return sessionmaker(
        bind=create_engine_by_env(),
        autocommit=False,
        autoflush=False,
    )
|
|
|
|
# engine = create_engine(settings.DATABASE_URL)
|
|
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
|
def get_db() -> Generator[Session, None, None]:
    """Yield one database session and always close it afterwards.

    Written as a generator dependency (presumably used with FastAPI
    ``Depends``). A session factory is obtained from ``get_session_local()`` on
    every call.

    If an exception is thrown into the generator while the session is in use,
    the session is rolled back, the error is logged, and the exception is
    re-raised. The session is closed in every case.

    Yields:
        An open ``Session``.
    """
    session_factory = get_session_local()
    session = session_factory()
    logger.debug("DB session created")
    try:
        yield session
    except Exception as exc:
        session.rollback()
        logger.error(f"DB session error: {str(exc)}")
        raise
    finally:
        session.close()
        logger.debug("DB session closed")
|
|
|
|
|
|
def get_database_url(env_settings: EnvSetting | None = None) -> str:
    """Build the PostgreSQL connection URL from the environment settings.

    The username and password are percent-encoded. Without this, characters
    such as ``@``, ``:``, ``/``, ``#``, ``%`` or spaces in credentials break the
    URL or get misparsed. SQLAlchemy percent-decodes them again when it parses
    the URL.

    NOTE(review): a password that is already percent-encoded in the
    environment will now be encoded a second time. Store the raw value.

    Args:
        env_settings: Object exposing ``DATABASE_USERNAME``,
            ``DATABASE_PASSWORD``, ``DATABASE_ADDRESS`` and ``DATABASE_NAME``.
            Defaults to the module-level ``settings``.

    Returns:
        A URL of the form ``postgresql://user:password@address/name``.
    """
    cfg = settings if env_settings is None else env_settings
    # safe="" escapes "/" as well. quote() (unlike quote_plus) encodes a space
    # as %20, which round-trips through a plain percent-decode.
    # str() mirrors the original f-string conversion of non-str values.
    username = quote(str(cfg.DATABASE_USERNAME), safe="")
    password = quote(str(cfg.DATABASE_PASSWORD), safe="")
    return (
        f"postgresql://{username}:{password}"
        f"@{cfg.DATABASE_ADDRESS}/{cfg.DATABASE_NAME}"
    )
|
|
|
|
def create_database_engine():
    """Create the application's pooled database engine.

    SQL echo logging is enabled only when ``settings.ENVIRONMENT == "local"``.
    """
    pool_options = {
        "pool_pre_ping": True,  # test each connection before use
        "pool_recycle": 3600,   # replace connections older than one hour
        "pool_size": 10,        # connections kept in the pool
        "max_overflow": 20,     # extra connections allowed beyond pool_size
    }
    return create_engine(
        get_database_url(),
        echo=settings.ENVIRONMENT == "local",
        **pool_options,
    )
|
|
|
|
# Shared engine instance, created at import time.
engine = create_database_engine()

# Session factory bound to the shared engine. expire_on_commit=False keeps
# loaded attributes readable after commit without triggering a reload.
# NOTE(review): SQLAlchemy 2.x Session only accepts autocommit=False; confirm
# DATABASE_AUTO_COMMIT is always False.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=settings.DATABASE_AUTO_FLUSH,
    autocommit=settings.DATABASE_AUTO_COMMIT,
    expire_on_commit=False,
)
|
|
|
|
# def get_db() -> Generator[Session, None, None]:
|
|
# """
|
|
# 데이터베이스 세션 의존성 주입 함수
|
|
# FastAPI Depends와 함께 사용
|
|
# """
|
|
# db = SessionLocal()
|
|
# try:
|
|
# yield db
|
|
# finally:
|
|
# db.close()
|
|
|
|
|
|
class DatabaseManager:
    """Static convenience helpers for database sessions."""

    @staticmethod
    def get_session() -> Session:
        """Open and return a new session from the module-level ``SessionLocal``."""
        return SessionLocal()

    @staticmethod
    def close_session(db: Session) -> None:
        """Close the given session."""
        db.close()

    @staticmethod
    def commit_and_refresh(db: Session, instance: object) -> None:
        """Commit *db*, then refresh *instance* from the database."""
        db.commit()
        db.refresh(instance)

    @staticmethod
    def rollback(db: Session) -> None:
        """Roll back the current transaction of *db*."""
        db.rollback()
|
|
|
|
def create_session_maker(engine: Engine, settings):
    """Build a ``sessionmaker`` bound to *engine*.

    Args:
        engine: Engine that produced sessions will use.
        settings: Object providing ``DATABASE_AUTO_COMMIT`` and
            ``DATABASE_AUTO_FLUSH``. It shadows the module-level ``settings``.

    Returns:
        The configured ``sessionmaker``.
    """
    auto_commit = settings.DATABASE_AUTO_COMMIT
    auto_flush = settings.DATABASE_AUTO_FLUSH
    return sessionmaker(bind=engine, autocommit=auto_commit, autoflush=auto_flush)
|
|
|
|
def create_persistence_by_env():
    """Return an ``(engine, session_maker)`` pair built from the module settings.

    The engine comes from ``create_engine_by_env()``. The session factory comes
    from ``create_session_maker()``, configured with the module-level
    ``settings``.
    """
    db_engine = create_engine_by_env()
    factory = create_session_maker(db_engine, settings)
    return db_engine, factory
|
|
|
|
|
|
|
|
def create_tables():
    """Create every table registered on ``Base.metadata``.

    Uses the engine returned by ``create_engine_by_env()``. Success is logged
    at INFO level. Any failure is logged at ERROR level and re-raised.
    """
    try:
        Base.metadata.create_all(bind=create_engine_by_env())
        logger.info("✅ All tables created successfully.")
    except Exception as exc:
        logger.error(f"❌ Failed to create tables: {str(exc)}")
        raise
|
|
|
|
|
|
# # 🔥 FK 무시하고 안전하게 전체 테이블 삭제 (PostgreSQL 전용)
|
|
# def drop_tables():
|
|
# try:
|
|
# engine = create_engine_by_env()
|
|
# with engine.connect() as conn:
|
|
# conn.execute(text("SET session_replication_role = replica;"))
|
|
# BaseTable.metadata.drop_all(bind=engine)
|
|
# conn.execute(text("SET session_replication_role = DEFAULT;"))
|
|
# conn.commit()
|
|
# logger.warning("⚠️ All tables dropped successfully with FK constraints disabled temporarily.")
|
|
# except Exception as e:
|
|
# logger.error(f"❌ Failed to drop tables: {str(e)}")
|
|
# raise
|
|
|
|
# # =======
|
|
|
|
class AbstractUnitOfWork(abc.ABC):
    """Abstract unit of work that can be used as a context manager.

    Subclasses must implement ``commit``, ``rollback`` and ``flush``. Leaving
    the ``with`` block always calls ``rollback()``. Exceptions raised inside
    the block are not suppressed.
    """

    def __enter__(self) -> AbstractUnitOfWork:
        """Enter the context and return this unit of work."""
        return self

    def __exit__(self, *exc_info):
        """Call ``rollback()`` on exit, whether or not the block raised."""
        # Implicitly returns None, so any in-flight exception propagates.
        self.rollback()

    @abc.abstractmethod
    def commit(self):
        """Commit the unit of work (subclass responsibility)."""
        raise NotImplementedError

    @abc.abstractmethod
    def rollback(self):
        """Roll back the unit of work (subclass responsibility)."""
        raise NotImplementedError

    @abc.abstractmethod
    def flush(self):
        """Flush the unit of work (subclass responsibility)."""
        raise NotImplementedError