206 lines
6.6 KiB
Python
206 lines
6.6 KiB
Python
import uuid
|
|
import queue
|
|
import threading
|
|
import traceback
|
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
|
from typing import Callable, Any, Iterable
|
|
|
|
|
|
class WorkerManager:
|
|
"""Simple worker manager with UI queue dispatch.
|
|
|
|
- Uses ThreadPoolExecutor for IO-bound tasks by default.
|
|
- Can also use ProcessPoolExecutor for CPU-bound tasks when `kind='process'`.
|
|
- Exposes a `ui_queue` (thread-safe `queue.Queue`) that the GUI should poll.
|
|
- Supports `submit` (single task) and `map_iterable` (parallel map with per-item progress).
|
|
"""
|
|
|
|
def __init__(self, max_threads: int = None, max_processes: int = None):
|
|
self._threads = ThreadPoolExecutor(max_workers=max_threads)
|
|
self._processes = ProcessPoolExecutor(max_workers=max_processes)
|
|
self.ui_queue = queue.Queue()
|
|
self._tasks = {} # task_id -> metadata (callbacks, futures)
|
|
self._lock = threading.Lock()
|
|
|
|
def _new_task_id(self) -> str:
|
|
return uuid.uuid4().hex
|
|
|
|
def submit(
|
|
self,
|
|
func: Callable,
|
|
*args,
|
|
kind: str = "thread",
|
|
on_done: Callable[[Any], None] | None = None,
|
|
on_error: Callable[[Exception], None] | None = None
|
|
) -> str:
|
|
"""Submit a callable to the chosen executor. Returns a task_id."""
|
|
task_id = self._new_task_id()
|
|
executor = self._threads if kind == "thread" else self._processes
|
|
|
|
def _wrap():
|
|
try:
|
|
res = func(*args)
|
|
self.ui_queue.put(("done", task_id, res))
|
|
except Exception as e:
|
|
tb = traceback.format_exc()
|
|
self.ui_queue.put(("error", task_id, tb))
|
|
|
|
# announce started
|
|
self.ui_queue.put(
|
|
(
|
|
"started",
|
|
task_id,
|
|
{"func": getattr(func, "__name__", str(func)), "args": len(args)},
|
|
)
|
|
)
|
|
|
|
fut = executor.submit(_wrap)
|
|
with self._lock:
|
|
self._tasks[task_id] = {
|
|
"futures": [fut],
|
|
"on_done": on_done,
|
|
"on_error": on_error,
|
|
}
|
|
return task_id
|
|
|
|
def map_iterable(
|
|
self,
|
|
func: Callable,
|
|
items: Iterable,
|
|
kind: str = "thread",
|
|
on_progress: Callable[[Any], None] | None = None,
|
|
on_done: Callable[[list], None] | None = None,
|
|
) -> str:
|
|
"""Map func over items in parallel and emit per-item progress messages to ui_queue.
|
|
|
|
The GUI should poll `ui_queue` and call `dispatch_message` to run callbacks in main thread.
|
|
"""
|
|
task_id = self._new_task_id()
|
|
executor = self._threads if kind == "thread" else self._processes
|
|
|
|
items = list(items)
|
|
if not items:
|
|
# Immediately finish
|
|
self.ui_queue.put(("done", task_id, []))
|
|
return task_id
|
|
|
|
futures = []
|
|
|
|
def _submit_item(item_to_process):
|
|
# Use default argument to capture current value (avoid late binding issue)
|
|
def _call(item=item_to_process):
|
|
return func(item)
|
|
|
|
return executor.submit(_call)
|
|
|
|
for it in items:
|
|
fut = _submit_item(it)
|
|
futures.append((it, fut))
|
|
|
|
with self._lock:
|
|
self._tasks[task_id] = {
|
|
"futures": [f for _, f in futures],
|
|
"on_progress": on_progress,
|
|
"on_done": on_done,
|
|
}
|
|
|
|
# announce started (with estimated total)
|
|
self.ui_queue.put(
|
|
(
|
|
"started",
|
|
task_id,
|
|
{"func": getattr(func, "__name__", str(func)), "total": len(items)},
|
|
)
|
|
)
|
|
|
|
# Start a watcher thread that collects results and pushes to ui_queue
|
|
def _watcher():
|
|
results = []
|
|
for it, fut in futures:
|
|
try:
|
|
res = fut.result()
|
|
results.append(res)
|
|
# emit progress message
|
|
self.ui_queue.put(("progress", task_id, res))
|
|
except Exception:
|
|
tb = traceback.format_exc()
|
|
self.ui_queue.put(("error", task_id, tb))
|
|
# all done
|
|
self.ui_queue.put(("done", task_id, results))
|
|
|
|
threading.Thread(target=_watcher, daemon=True).start()
|
|
return task_id
|
|
|
|
def dispatch_message(self, msg: tuple):
|
|
"""Dispatch a single message (called from GUI/main thread)."""
|
|
typ, task_id, payload = msg
|
|
meta = None
|
|
with self._lock:
|
|
meta = self._tasks.get(task_id)
|
|
if not meta:
|
|
return
|
|
if typ == "progress":
|
|
cb = meta.get("on_progress")
|
|
if cb:
|
|
try:
|
|
cb(payload)
|
|
except Exception:
|
|
# swallow exceptions from UI callbacks
|
|
pass
|
|
elif typ == "done":
|
|
cb = meta.get("on_done")
|
|
if cb:
|
|
try:
|
|
cb(payload)
|
|
except Exception:
|
|
pass
|
|
# cleanup
|
|
with self._lock:
|
|
self._tasks.pop(task_id, None)
|
|
elif typ == "started":
|
|
# no-op for dispatch — GUI may handle started notifications separately
|
|
return
|
|
elif typ == "error":
|
|
cb = meta.get("on_error") or meta.get("on_done")
|
|
if cb:
|
|
try:
|
|
cb(payload)
|
|
except Exception:
|
|
pass
|
|
with self._lock:
|
|
self._tasks.pop(task_id, None)
|
|
|
|
def cancel(self, task_id: str) -> bool:
|
|
"""Attempt to cancel a task. Returns True if cancellation was requested.
|
|
|
|
This will call `cancel()` on futures where supported. For processes this
|
|
may not immediately terminate the work.
|
|
"""
|
|
with self._lock:
|
|
meta = self._tasks.get(task_id)
|
|
if not meta:
|
|
return False
|
|
futs = meta.get("futures", [])
|
|
cancelled_any = False
|
|
for f in futs:
|
|
try:
|
|
ok = f.cancel()
|
|
cancelled_any = cancelled_any or ok
|
|
except Exception:
|
|
pass
|
|
# inform UI
|
|
self.ui_queue.put(("cancelled", task_id, None))
|
|
with self._lock:
|
|
self._tasks.pop(task_id, None)
|
|
return cancelled_any
|
|
|
|
def shutdown(self):
|
|
try:
|
|
self._threads.shutdown(wait=False)
|
|
except Exception:
|
|
pass
|
|
try:
|
|
self._processes.shutdown(wait=False)
|
|
except Exception:
|
|
pass
|