"""
Infrastructure Layer

This module provides utilities for parallel task execution, real-time progress
monitoring, and robust error handling in multiprocessing contexts. It is
designed for scientific computations (e.g., FTLE batch processing) where
multiple independent tasks need to be executed concurrently with visual
feedback through progress bars.

Features
--------
- Seamless integration with both terminal and Jupyter environments via `tqdm`.
- Graceful handling of errors during multiprocessing (with traceback display).
- Live monitoring of per-task and global progress using multiple progress bars.
- Clean shutdown and interruption handling across processes.

Classes
-------
ParallelExecutor
    Handles parallel task execution, progress tracking, and error reporting.

Functions
---------
get_tqdm()
    Detects the environment and returns the appropriate tqdm class
    (terminal or notebook).
"""
import multiprocessing as mp
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
from colorama import Fore, Style
from colorama import init as colorama_init
from pyftle.data_source import BatchSource
# Initialise colorama once at import time. With autoreset=True the colour/style
# codes emitted via Fore/Style are reset automatically after each print call.
colorama_init(autoreset=True)
def get_tqdm():
    """
    Pick the tqdm progress-bar class that suits the running environment.

    The function probes for an IPython kernel (as used by Jupyter). When one
    is found, the widget-based notebook bar is returned; in every other
    situation the plain terminal bar is returned.

    Returns
    -------
    tqdm_class : type
        - `tqdm.notebook.tqdm` when an IPython kernel application is detected.
        - `tqdm.tqdm` otherwise, including when IPython is not installed or
          the detection itself fails.

    Notes
    -----
    Detection is best-effort: any exception raised while probing IPython or
    importing the notebook variant is swallowed and the terminal bar is used.
    """
    notebook_cls = None
    try:
        from IPython.core.getipython import get_ipython

        ipython = get_ipython()
        kernel_detected = (
            bool(ipython)
            and hasattr(ipython, "config")
            and "IPKernelApp" in ipython.config
        )
        if kernel_detected:
            from tqdm.notebook import tqdm as notebook_cls
    except Exception:
        # Deliberate fallback: a failed probe must never prevent progress
        # reporting, so fall through to the terminal implementation.
        notebook_cls = None

    if notebook_cls is not None:
        return notebook_cls

    from tqdm import tqdm as terminal_cls

    return terminal_cls
# Progress-bar class used by the rest of this module, selected once at import
# time by get_tqdm().
tqdm = get_tqdm()
class ParallelExecutor:
    """
    Manage multiprocessing execution with live progress monitoring.

    The `ParallelExecutor` executes multiple independent tasks in parallel
    using `concurrent.futures.ProcessPoolExecutor`, while a separate monitor
    process renders one global progress bar plus one bar per active task.

    Parameters
    ----------
    n_processes : int, optional
        Number of parallel worker processes to launch (default is 4).

    Attributes
    ----------
    n_processes : int
        Number of concurrent worker processes. Also the number of per-task
        progress-bar rows shown by the monitor.
    progress_queue : multiprocessing.Queue
        Manager-backed queue used to communicate task progress between worker
        processes and the monitor process.
    _stop_event : multiprocessing.Event
        Manager-backed event used to signal the monitor process to terminate.
    """

    def __init__(self, n_processes: int = 4):
        self.n_processes = n_processes
        # Manager-backed proxies can be passed to pool workers as ordinary
        # arguments and to the monitor process.
        manager = mp.Manager()
        self.progress_queue = manager.Queue()
        self._stop_event = manager.Event()

    def _monitor_progress(self, total_tasks: int, steps_per_task: int):
        """
        Display real-time progress for all parallel tasks.

        This method is meant to run in a dedicated process. It polls the
        shared progress queue and maintains:

        - One global progress bar tracking total completed tasks.
        - Individual progress bars for active tasks (at most `n_processes`).

        Parameters
        ----------
        total_tasks : int
            Total number of tasks to monitor.
        steps_per_task : int
            Expected number of progress updates (e.g., time steps) per task.

        Notes
        -----
        - Each task reports its progress as `(task_id, step)` tuples.
        - When a task is complete, it sends `(task_id, "done")`.
        - A progress message for a task that has no bar while every row is
          occupied is skipped instead of raising.
        - The loop ends when all tasks are done or `_stop_event` is set.
        """
        global_bar = tqdm(
            total=total_tasks, desc="Global", position=0, dynamic_ncols=True
        )
        # task_id -> (bar, row). The row is tracked here rather than read back
        # from the bar, because tqdm keeps fixed positions sign-flipped in
        # its `pos` attribute.
        active_bars = {}
        free_rows = list(range(1, self.n_processes + 1))
        finished = 0
        try:
            while not self._stop_event.is_set() and finished < total_tasks:
                while not self.progress_queue.empty():
                    task_id, status = self.progress_queue.get()

                    if status == "done":
                        entry = active_bars.pop(task_id, None)
                        if entry is not None:
                            bar, row = entry
                            bar.close()
                            free_rows.append(row)
                        global_bar.update(1)
                        finished += 1
                        continue

                    entry = active_bars.get(task_id)
                    if entry is None:
                        if not free_rows:
                            # No row to draw on; drop this update rather than
                            # crash the monitor.
                            continue
                        row = free_rows.pop(0)
                        bar = tqdm(
                            total=steps_per_task,
                            desc=str(task_id),
                            position=row,
                            leave=False,
                            dynamic_ncols=True,
                        )
                        entry = (bar, row)
                        active_bars[task_id] = entry

                    bar = entry[0]
                    bar.n = status
                    bar.refresh()
                time.sleep(0.05)
        finally:
            global_bar.close()
            for bar, _row in active_bars.values():
                bar.close()

    def _stop_monitor(self, monitor_proc, drain_timeout: float = 0.0):
        """Stop the monitor process, optionally letting it finish on its own first."""
        if drain_timeout > 0:
            # Give the monitor a chance to render the final "done" messages
            # and close its bars before it is told to stop.
            monitor_proc.join(timeout=drain_timeout)
        self._stop_event.set()
        monitor_proc.join(timeout=0.5)
        if monitor_proc.is_alive():
            monitor_proc.terminate()
            monitor_proc.join(timeout=0.5)

    # The annotation is quoted so it is not evaluated at definition time.
    def run(self, tasks: "list[BatchSource]", worker_fn):
        """
        Execute multiple tasks in parallel and collect results.

        Each task is executed in a worker process via the provided
        `worker_fn`. Progress updates are collected asynchronously and
        displayed via tqdm bars by a separate monitor process.

        Parameters
        ----------
        tasks : list of BatchSource
            Task objects to process. Each task must expose `id` (used in
            error messages) and the first task must expose `num_steps`, which
            is used as the bar length for every task.
        worker_fn : callable
            Function with signature `worker_fn(task, queue)` that performs
            the actual work for each task. It must be picklable and should:

            - Report progress by placing `(task.id, step)` or
              `(task.id, "done")` messages into the shared queue.
            - Return a result (e.g., NumPy array) or raise on failure.

        Returns
        -------
        results : list of np.ndarray or None
            Task results in the same order as the input tasks. An empty
            `tasks` list yields an empty list without starting any process.

        Raises
        ------
        RuntimeError
            If a task fails. The failing task's traceback is printed, pending
            tasks are cancelled, and the worker's exception is attached as
            `__cause__`.

        Notes
        -----
        - On the first failure, futures that have not started are cancelled;
          leaving the pool context still waits for tasks already running.
        - The monitor process is always stopped and joined, including when
          the call is interrupted.
        """
        if not tasks:
            return []

        # Reset state left over from a previous (possibly failed) run so the
        # new monitor neither exits immediately nor counts stale messages.
        self._stop_event.clear()
        while not self.progress_queue.empty():
            self.progress_queue.get()

        steps_per_task = tasks[0].num_steps  # num snapshots in flow map period
        monitor_proc = mp.Process(
            target=self._monitor_progress,
            args=(len(tasks), steps_per_task),
        )
        monitor_proc.start()

        results: list[np.ndarray | None] = [None] * len(tasks)
        exceptions = []
        succeeded = False
        try:
            with ProcessPoolExecutor(max_workers=self.n_processes) as executor:
                futures = {
                    executor.submit(worker_fn, task, self.progress_queue): i
                    for i, task in enumerate(tasks)
                }
                for future in as_completed(futures):
                    i = futures[future]
                    task = tasks[i]
                    try:
                        results[i] = future.result()  # preserve task order
                    except Exception as e:
                        error_msg = (
                            f"\n{Fore.RED}❌ Error in task "
                            f"{task.id}:{Style.RESET_ALL}\n"
                            f"{traceback.format_exc()}"
                        )
                        print(error_msg, flush=True)
                        exceptions.append((task, e))
                        # Signal stop and cancel everything not yet started.
                        self._stop_event.set()
                        executor.shutdown(wait=False, cancel_futures=True)
                        # Kill the monitor right away so it stops redrawing
                        # over the error output.
                        if monitor_proc.is_alive():
                            monitor_proc.terminate()
                        # Leave the loop; the failure is reported below.
                        break
            succeeded = not exceptions
        finally:
            self._stop_monitor(
                monitor_proc, drain_timeout=1.0 if succeeded else 0.0
            )

        if exceptions:
            print(
                f"{Fore.RED}\n⚠️ {len(exceptions)} task(s) failed. "
                "See messages above for details."
                f"{Style.RESET_ALL}",
                flush=True,
            )
            raise RuntimeError(
                "One or more FTLE batches failed."
            ) from exceptions[0][1]
        return results