O2Sound_ver2_final/backend/app/workers/progress_tracker.py

332 lines
13 KiB
Python

import json
import time
import logging
from datetime import datetime
from typing import Optional, Dict, Any, Set
from app.core.redis.redis_manager import RedisManager
logger = logging.getLogger(__name__)
class ProgressTracker:
"""Redis를 사용하여 Celery 작업의 진행 상태를 추적하는 클래스"""
# 상수 정의
PROGRESS_KEY_PREFIX = "task_progress:"
TIMING_KEY_PREFIX = "task_timing:"
ORDER_ID_KEY_PREFIX = "task_order_id:"
DEFAULT_STATUS = {
"crawling": False,
"lyrics": False,
"music": False,
"images": False,
"video": False,
"cleanup": False
}
VALID_STEPS: Set[str] = set(DEFAULT_STATUS.keys())
def __init__(self, redis_manager: RedisManager):
self.redis_manager = redis_manager
def _get_redis_key(self, task_id: str, key_type: str = "progress") -> str:
"""Redis 키 생성"""
if key_type == "timing":
return f"{self.TIMING_KEY_PREFIX}{task_id}"
return f"{self.PROGRESS_KEY_PREFIX}{task_id}"
def _get_redis_client(self):
"""Redis 클라이언트 안전하게 가져오기"""
try:
return self.redis_manager.get_sync_client()
except Exception as e:
logger.error(f"Redis 클라이언트 연결 실패: {e}")
raise
def _init_task_status(self, task_id: str) -> None:
"""처음 상태 초기화"""
try:
client = self._get_redis_client()
# 진행률 상태 초기화
progress_key = self._get_redis_key(task_id, "progress")
client.set(progress_key, json.dumps(self.DEFAULT_STATUS))
# 🔥 타이밍 정보 초기화
timing_key = self._get_redis_key(task_id, "timing")
timing_data = {
"workflow_start_time": time.time(),
"workflow_end_time": None,
"total_duration": None,
"step_timings": {},
"step_durations": {}
}
client.set(timing_key, json.dumps(timing_data))
logger.info(f"작업 상태 초기화 완료 - Task ID: {task_id}")
except Exception as e:
logger.error(f"작업 상태 초기화 실패 - Task ID: {task_id}, Error: {e}")
raise
def start_step_timing(self, task_id: str, step_name: str) -> None:
"""🔥 단계 시작 시간 기록"""
if step_name not in self.VALID_STEPS:
return
try:
client = self._get_redis_client()
timing_key = self._get_redis_key(task_id, "timing")
data = client.get(timing_key)
if not data:
self._init_task_status(task_id)
data = client.get(timing_key)
if isinstance(data, bytes):
data = data.decode('utf-8')
timing_data = json.loads(data)
# 단계 시작 시간 기록
timing_data["step_timings"][f"{step_name}_start"] = time.time()
client.set(timing_key, json.dumps(timing_data))
logger.debug(f"단계 시작 시간 기록 - Task: {task_id}, Step: {step_name}")
except Exception as e:
logger.warning(f"단계 시작 시간 기록 실패: {e}")
def set_task_step_done(self, task_id: str, step_name: str) -> bool:
"""단계별 상태 True로 업데이트 + 실행시간 기록"""
if not task_id or not step_name:
logger.warning(f"잘못된 입력값 - task_id: {task_id}, step_name: {step_name}")
return False
if step_name not in self.VALID_STEPS:
logger.warning(f"알 수 없는 step_name: {step_name}. 유효한 값: {self.VALID_STEPS}")
return False
try:
client = self._get_redis_client()
progress_key = self._get_redis_key(task_id, "progress")
timing_key = self._get_redis_key(task_id, "timing")
# 진행률 상태 업데이트
progress_data = client.get(progress_key)
if not progress_data:
self._init_task_status(task_id)
progress_data = client.get(progress_key)
if isinstance(progress_data, bytes):
progress_data = progress_data.decode('utf-8')
status_data = json.loads(progress_data)
status_data[step_name] = True
client.set(progress_key, json.dumps(status_data))
# 🔥 타이밍 정보 업데이트
timing_data = client.get(timing_key)
if timing_data:
if isinstance(timing_data, bytes):
timing_data = timing_data.decode('utf-8')
timing_info = json.loads(timing_data)
# 단계 완료 시간 및 소요시간 계산
end_time = time.time()
timing_info["step_timings"][f"{step_name}_end"] = end_time
start_time = timing_info["step_timings"].get(f"{step_name}_start")
if start_time:
duration = end_time - start_time
timing_info["step_durations"][step_name] = duration
logger.info(f"단계 완료 - Task: {task_id}, Step: {step_name}, 소요시간: {duration:.2f}")
# 전체 워크플로우 완료 체크
if all(status_data.values()):
timing_info["workflow_end_time"] = end_time
total_duration = end_time - timing_info["workflow_start_time"]
timing_info["total_duration"] = total_duration
logger.info(f"🎉 전체 워크플로우 완료 - Task: {task_id}, 총 소요시간: {total_duration:.2f}")
client.set(timing_key, json.dumps(timing_info))
return True
except Exception as e:
logger.error(f"작업 단계 업데이트 실패 - Task ID: {task_id}, Step: {step_name}, Error: {e}")
return False
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""전체 상태 조회"""
if not task_id:
logger.warning("task_id가 비어있습니다.")
return {}
try:
client = self._get_redis_client()
progress_key = self._get_redis_key(task_id, "progress")
data = client.get(progress_key)
if not data:
return {}
if isinstance(data, bytes):
data = data.decode('utf-8')
status_data = json.loads(data)
if not isinstance(status_data, dict):
logger.warning(f"잘못된 상태 데이터 형식 - Task ID: {task_id}")
return {}
return status_data
except Exception as e:
logger.error(f"작업 상태 조회 실패 - Task ID: {task_id}, Error: {e}")
return {}
def get_task_timing_info(self, task_id: str) -> Dict[str, Any]:
"""🔥 작업 타이밍 정보 조회"""
if not task_id:
return {}
try:
client = self._get_redis_client()
timing_key = self._get_redis_key(task_id, "timing")
data = client.get(timing_key)
if not data:
return {}
if isinstance(data, bytes):
data = data.decode('utf-8')
timing_data = json.loads(data)
# 사람이 읽기 쉬운 형태로 변환
result = {
"workflow_start_time": self._format_timestamp(timing_data.get("workflow_start_time")),
"workflow_end_time": self._format_timestamp(timing_data.get("workflow_end_time")),
"total_duration": self._format_duration(timing_data.get("total_duration")),
"step_durations": {},
"is_completed": timing_data.get("workflow_end_time") is not None
}
# 각 단계별 소요시간을 읽기 쉽게 변환
for step_name, duration in timing_data.get("step_durations", {}).items():
result["step_durations"][step_name] = self._format_duration(duration)
return result
except Exception as e:
logger.error(f"타이밍 정보 조회 실패 - Task ID: {task_id}, Error: {e}")
return {}
def get_progress_percentage(self, task_id: str) -> float:
"""작업 진행률을 퍼센트로 반환"""
status = self.get_task_status(task_id)
if not status:
return 0.0
completed_steps = sum(1 for step_done in status.values() if step_done)
total_steps = len(self.DEFAULT_STATUS)
return (completed_steps / total_steps) * 100.0
def is_task_completed(self, task_id: str) -> bool:
"""작업이 완전히 완료되었는지 확인"""
status = self.get_task_status(task_id)
if not status:
return False
return all(status.values())
def delete_task_status(self, task_id: str) -> bool:
"""작업 상태 삭제 (타이밍 정보와 order_id도 함께 삭제)"""
if not task_id:
return False
try:
client = self._get_redis_client()
progress_key = self._get_redis_key(task_id, "progress")
timing_key = self._get_redis_key(task_id, "timing")
order_id_key = self._get_order_id_key(task_id)
result1 = client.delete(progress_key)
result2 = client.delete(timing_key)
result3 = client.delete(order_id_key)
if result1 or result2 or result3:
logger.info(f"작업 상태 삭제 완료 - Task ID: {task_id}")
return True
else:
logger.warning(f"삭제할 작업 상태가 없습니다 - Task ID: {task_id}")
return False
except Exception as e:
logger.error(f"작업 상태 삭제 실패 - Task ID: {task_id}, Error: {e}")
return False
def _format_timestamp(self, timestamp: Optional[float]) -> Optional[str]:
"""타임스탬프를 읽기 쉬운 형태로 변환"""
if timestamp is None:
return None
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
def _format_duration(self, duration: Optional[float]) -> Optional[str]:
"""소요시간을 읽기 쉬운 형태로 변환"""
if duration is None:
return None
if duration < 60:
return f"{duration:.2f}"
elif duration < 3600:
minutes = int(duration // 60)
seconds = duration % 60
return f"{minutes}{seconds:.2f}"
else:
hours = int(duration // 3600)
minutes = int((duration % 3600) // 60)
seconds = duration % 60
return f"{hours}시간 {minutes}{seconds:.2f}"
def _get_order_id_key(self, task_id: str) -> str:
"""order_id Redis 키 생성"""
return f"{self.ORDER_ID_KEY_PREFIX}{task_id}"
def set_order_id(self, task_id: str, order_id: str) -> bool:
"""task_id에 대응하는 order_id 저장"""
if not task_id or not order_id or str(order_id) == "None":
logger.warning(f"잘못된 입력값 - task_id: {task_id}, order_id: {order_id}")
return False
try:
client = self._get_redis_client()
key = self._get_order_id_key(task_id)
client.set(key, str(order_id))
logger.info(f"Order ID 저장 완료 - Task ID: {task_id}, Order ID: {order_id}")
return True
except Exception as e:
logger.error(f"Order ID 저장 실패 - Task ID: {task_id}, Error: {e}")
return False
def get_order_id(self, task_id: str) -> Optional[str]:
"""task_id에 대응하는 order_id 조회"""
if not task_id:
logger.warning("task_id가 비어있습니다.")
return None
try:
client = self._get_redis_client()
key = self._get_order_id_key(task_id)
data = client.get(key)
if not data:
logger.warning(f"Order ID를 찾을 수 없음 - Task ID: {task_id}")
return None
if isinstance(data, bytes):
data = data.decode('utf-8')
return data
except Exception as e:
logger.error(f"Order ID 조회 실패 - Task ID: {task_id}, Error: {e}")
return None