185 lines
6.0 KiB
Python
185 lines
6.0 KiB
Python
"""
Scheduler module for managing task execution on cron schedules.
"""
|
|
|
|
import signal
|
|
import threading
|
|
from datetime import datetime
|
|
from types import FrameType
|
|
from typing import Any, Callable, Dict, List, ParamSpec, TypeVar
|
|
|
|
from croniter import croniter
|
|
|
|
from pyworker.database import close_db_pool
|
|
from .tasks import (
|
|
CommentModerationTask,
|
|
ForceTriggersTask,
|
|
ServiceScoreRecalculationTask,
|
|
TosReviewTask,
|
|
UserSentimentTask,
|
|
)
|
|
from pyworker.utils.app_logging import setup_logging
|
|
|
|
# Module-level logger shared by every TaskScheduler instance.
logger = setup_logging(__name__)

# Generic placeholders used to type the task callable accepted by
# TaskScheduler.register_task together with its forwarded arguments.
P = ParamSpec("P")
R = TypeVar("R")
|
|
|
|
|
|
class TaskScheduler:
|
|
"""Task scheduler for running tasks on a cron schedule."""
|
|
|
|
def __init__(self):
|
|
"""Initialize the task scheduler."""
|
|
self.tasks: Dict[str, Dict[str, Any]] = {}
|
|
self.running = False
|
|
self.threads: List[threading.Thread] = []
|
|
self.stop_event = threading.Event()
|
|
self.logger = logger
|
|
|
|
# Set up signal handlers
|
|
signal.signal(signal.SIGINT, self._handle_signal)
|
|
signal.signal(signal.SIGTERM, self._handle_signal)
|
|
|
|
def _handle_signal(self, signum: int, frame: FrameType | None) -> None:
|
|
"""Handle termination signals."""
|
|
self.logger.info(f"Received signal {signum}, shutting down...")
|
|
self.stop()
|
|
|
|
def register_task(
|
|
self,
|
|
task_name: str,
|
|
cron_expression: str,
|
|
task_func: Callable[P, R],
|
|
*args: P.args,
|
|
**kwargs: P.kwargs,
|
|
) -> None:
|
|
"""
|
|
Register a task to be scheduled.
|
|
|
|
Args:
|
|
task_name: Name of the task.
|
|
cron_expression: Cron expression defining the schedule.
|
|
task_func: Function to execute.
|
|
*args: Arguments to pass to the task function.
|
|
**kwargs: Keyword arguments to pass to the task function.
|
|
"""
|
|
# Declare task_instance variable with type annotation upfront
|
|
task_instance: Any = None
|
|
|
|
# Initialize the appropriate task class based on the task name
|
|
if task_name.lower() == "tosreview":
|
|
task_instance = TosReviewTask()
|
|
elif task_name.lower() == "user_sentiment":
|
|
task_instance = UserSentimentTask()
|
|
elif task_name.lower() == "comment_moderation":
|
|
task_instance = CommentModerationTask()
|
|
elif task_name.lower() == "force_triggers":
|
|
task_instance = ForceTriggersTask()
|
|
elif task_name.lower() == "service_score_recalc":
|
|
task_instance = ServiceScoreRecalculationTask()
|
|
else:
|
|
self.logger.warning(f"Unknown task '{task_name}', skipping")
|
|
return
|
|
|
|
self.tasks[task_name] = {
|
|
"cron": cron_expression,
|
|
"func": task_func,
|
|
"instance": task_instance,
|
|
"args": args,
|
|
"kwargs": kwargs,
|
|
}
|
|
self.logger.info(
|
|
f"Registered task '{task_name}' with schedule: {cron_expression}"
|
|
)
|
|
|
|
def _run_task(self, task_name: str, task_info: Dict[str, Any]):
|
|
"""
|
|
Run a task on its schedule.
|
|
|
|
Args:
|
|
task_name: Name of the task.
|
|
task_info: Task information including function and schedule.
|
|
"""
|
|
self.logger.info(f"Starting scheduler for task '{task_name}'")
|
|
|
|
# Parse the cron expression
|
|
cron = croniter(task_info["cron"], datetime.now())
|
|
|
|
while not self.stop_event.is_set():
|
|
# Get the next run time
|
|
next_run = cron.get_next(datetime)
|
|
self.logger.info(f"Next run for task '{task_name}': {next_run}")
|
|
|
|
# Sleep until the next run time
|
|
now = datetime.now()
|
|
sleep_seconds = (next_run - now).total_seconds()
|
|
|
|
if sleep_seconds > 0:
|
|
# Wait until next run time or until stop event is set
|
|
if self.stop_event.wait(sleep_seconds):
|
|
break
|
|
|
|
# Run the task if we haven't been stopped
|
|
if not self.stop_event.is_set():
|
|
try:
|
|
self.logger.info(f"Running task '{task_name}'")
|
|
# Use task instance as a context manager to ensure
|
|
# a single database connection is used for the entire task
|
|
with task_info["instance"]:
|
|
# Execute the registered task function with its arguments
|
|
task_info["func"](*task_info["args"], **task_info["kwargs"])
|
|
self.logger.info(f"Task '{task_name}' completed")
|
|
except Exception as e:
|
|
self.logger.exception(f"Error running task '{task_name}': {e}")
|
|
finally:
|
|
# Close the database pool after task execution
|
|
close_db_pool()
|
|
|
|
def start(self):
|
|
"""Start the scheduler."""
|
|
if self.running:
|
|
self.logger.warning("Scheduler is already running")
|
|
return
|
|
|
|
self.logger.info("Starting scheduler")
|
|
self.running = True
|
|
self.stop_event.clear()
|
|
|
|
# Start a thread for each task
|
|
for task_name, task_info in self.tasks.items():
|
|
thread = threading.Thread(
|
|
target=self._run_task,
|
|
args=(task_name, task_info),
|
|
name=f"scheduler-{task_name}",
|
|
)
|
|
thread.daemon = True
|
|
thread.start()
|
|
self.threads.append(thread)
|
|
|
|
self.logger.info(f"Started {len(self.threads)} scheduler threads")
|
|
|
|
def stop(self):
|
|
"""Stop the scheduler."""
|
|
if not self.running:
|
|
return
|
|
|
|
self.logger.info("Stopping scheduler")
|
|
self.running = False
|
|
self.stop_event.set()
|
|
|
|
# Wait for all threads to terminate
|
|
for thread in self.threads:
|
|
thread.join(timeout=5.0)
|
|
|
|
self.threads = []
|
|
|
|
# Close database pool when the scheduler stops
|
|
close_db_pool()
|
|
|
|
self.logger.info("Scheduler stopped")
|
|
|
|
def is_running(self) -> bool:
|
|
"""Check if the scheduler is running."""
|
|
return self.running
|