Files
kycnotme/pyworker/pyworker/scheduler.py
2025-05-19 22:13:13 +00:00

185 lines
6.0 KiB
Python

"""
Scheduler module for managing task execution with cron.
"""
import signal
import threading
from datetime import datetime
from types import FrameType
from typing import Any, Callable, Dict, List, ParamSpec, TypeVar
from croniter import croniter
from pyworker.database import close_db_pool
from .tasks import (
CommentModerationTask,
ForceTriggersTask,
ServiceScoreRecalculationTask,
TosReviewTask,
UserSentimentTask,
)
from pyworker.utils.app_logging import setup_logging
# Type variables that tie register_task's *args/**kwargs to the parameter
# list and return type of the task_func it is given.
P = ParamSpec("P")
R = TypeVar("R")

# Module logger; every TaskScheduler stores it as self.logger.
logger = setup_logging(__name__)
class TaskScheduler:
    """Task scheduler for running tasks on a cron schedule.

    ``start`` creates one daemon worker thread per registered task. Each
    worker sleeps until the task's next cron fire time and then runs it.
    ``stop`` (also triggered by SIGINT/SIGTERM) sets a shared
    ``threading.Event`` that wakes every worker, then waits briefly for the
    workers to exit.
    """

    # Lower-cased task names accepted by ``register_task``, mapped to the task
    # class instantiated for that name. Each run of the registered function is
    # wrapped in a ``with`` block on that instance (see ``_run_task``).
    _TASK_CLASSES: Dict[str, Callable[[], Any]] = {
        "tosreview": TosReviewTask,
        "user_sentiment": UserSentimentTask,
        "comment_moderation": CommentModerationTask,
        "force_triggers": ForceTriggersTask,
        "service_score_recalc": ServiceScoreRecalculationTask,
    }

    def __init__(self):
        """Initialize the task scheduler and install SIGINT/SIGTERM handlers.

        ``signal.signal`` may only be called from the main thread of the main
        interpreter, so construct the scheduler there; otherwise Python
        raises ``ValueError``.
        """
        # task name -> registration record built by register_task
        self.tasks: Dict[str, Dict[str, Any]] = {}
        self.running = False
        self.threads: List[threading.Thread] = []
        # Set by stop(); workers wait on it so they wake up immediately.
        self.stop_event = threading.Event()
        self.logger = logger
        # Set up signal handlers
        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)

    def _handle_signal(self, signum: int, frame: FrameType | None) -> None:
        """Handle SIGINT/SIGTERM by stopping the scheduler."""
        self.logger.info(f"Received signal {signum}, shutting down...")
        self.stop()

    def register_task(
        self,
        task_name: str,
        cron_expression: str,
        task_func: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """
        Register a task to be scheduled.

        The task name is matched case-insensitively against the known task
        types. An unknown name is logged as a warning and ignored.
        Registering an existing name again replaces the earlier
        registration. Tasks registered after ``start`` has been called do
        not get a worker thread until the scheduler is restarted.

        Args:
            task_name: Name of the task.
            cron_expression: Cron expression defining the schedule.
            task_func: Function to execute.
            *args: Arguments to pass to the task function.
            **kwargs: Keyword arguments to pass to the task function.
        """
        task_class = self._TASK_CLASSES.get(task_name.lower())
        if task_class is None:
            self.logger.warning(f"Unknown task '{task_name}', skipping")
            return
        self.tasks[task_name] = {
            "cron": cron_expression,
            "func": task_func,
            "instance": task_class(),
            "args": args,
            "kwargs": kwargs,
        }
        self.logger.info(
            f"Registered task '{task_name}' with schedule: {cron_expression}"
        )

    def _run_task(self, task_name: str, task_info: Dict[str, Any]) -> None:
        """
        Run a task on its cron schedule until the scheduler is stopped.

        This is the target of each worker thread created by ``start``. Each
        iteration works out the next fire time and waits for it, waking early
        if ``stop_event`` is set. It then runs the task function inside the
        task instance's context manager. Any exception from the task is
        logged and the loop keeps going.

        If a run overruns one or more scheduled slots, the missed slots are
        skipped and the schedule is re-anchored at the current time. The task
        is not fired back-to-back to "catch up".

        Args:
            task_name: Name of the task (used for logging).
            task_info: Registration record built by ``register_task``
                (keys ``cron``, ``func``, ``instance``, ``args``, ``kwargs``).
        """
        self.logger.info(f"Starting scheduler for task '{task_name}'")
        cron_expression = task_info["cron"]
        # Parse the cron expression (naive local time, as before)
        cron = croniter(cron_expression, datetime.now())
        while not self.stop_event.is_set():
            # Get the next run time
            next_run = cron.get_next(datetime)
            now = datetime.now()
            if next_run < now:
                # croniter advances from the slot it returned last, so after a
                # run that overran its interval get_next() yields times that
                # are already past. Re-anchor at "now" to skip them instead of
                # running the task repeatedly with no delay.
                cron = croniter(cron_expression, now)
                next_run = cron.get_next(datetime)
            self.logger.info(f"Next run for task '{task_name}': {next_run}")
            # Sleep until the next run time, or until the stop event is set
            sleep_seconds = (next_run - datetime.now()).total_seconds()
            if sleep_seconds > 0 and self.stop_event.wait(sleep_seconds):
                break
            # Run the task if we haven't been stopped
            if self.stop_event.is_set():
                break
            try:
                self.logger.info(f"Running task '{task_name}'")
                # Use the task instance as a context manager so that a single
                # database connection is used for the entire task
                with task_info["instance"]:
                    task_info["func"](*task_info["args"], **task_info["kwargs"])
                self.logger.info(f"Task '{task_name}' completed")
            except Exception as e:
                self.logger.exception(f"Error running task '{task_name}': {e}")
            finally:
                # Close the database pool after task execution.
                # NOTE(review): other task threads may be mid-run at this point;
                # confirm close_db_pool() is safe to call concurrently and that
                # the pool is re-created lazily on the next use.
                close_db_pool()

    def start(self) -> None:
        """Start one daemon worker thread per registered task.

        Does nothing except log a warning if the scheduler is already running.
        """
        if self.running:
            self.logger.warning("Scheduler is already running")
            return
        self.logger.info("Starting scheduler")
        self.running = True
        self.stop_event.clear()
        # Start a thread for each task
        for task_name, task_info in self.tasks.items():
            thread = threading.Thread(
                target=self._run_task,
                args=(task_name, task_info),
                name=f"scheduler-{task_name}",
            )
            # Daemon threads do not keep the interpreter alive at exit.
            thread.daemon = True
            thread.start()
            self.threads.append(thread)
        self.logger.info(f"Started {len(self.threads)} scheduler threads")

    def stop(self) -> None:
        """Stop the scheduler.

        Sets the stop event and waits up to 5 seconds per worker thread.
        Workers still running after that are abandoned; they are daemon
        threads, so they will not block process exit. Then closes the
        database pool. Calling this when the scheduler is not running does
        nothing.
        """
        if not self.running:
            return
        self.logger.info("Stopping scheduler")
        self.running = False
        self.stop_event.set()
        # Wait for all threads to terminate
        for thread in self.threads:
            thread.join(timeout=5.0)
        self.threads = []
        # Close database pool when the scheduler stops.
        # NOTE(review): a worker that outlived its join timeout may still be
        # using the pool here.
        close_db_pool()
        self.logger.info("Scheduler stopped")

    def is_running(self) -> bool:
        """Return True between a successful ``start`` and the next ``stop``."""
        return self.running