Source code for mesa_llm.parallel_stepping

"""
Automatic parallel stepping for Mesa-LLM simulations.
"""

from __future__ import annotations

import asyncio
import concurrent.futures
import logging
from typing import TYPE_CHECKING

from mesa.agent import Agent, AgentSet

if TYPE_CHECKING:
    from .llm_agent import LLMAgent

logger = logging.getLogger(__name__)

# Global variable to control parallel stepping mode
_PARALLEL_STEPPING_MODE = "asyncio"  # or "threading"


[docs] async def step_agents_parallel(agents: list[Agent | LLMAgent]) -> None: """Step all agents in parallel using async/await.""" tasks = [] for agent in agents: if hasattr(agent, "astep"): tasks.append(agent.astep()) elif hasattr(agent, "step"): tasks.append(_sync_step(agent)) await asyncio.gather(*tasks)
async def _sync_step(agent: Agent) -> None: """Run synchronous step in async context.""" agent.step()
[docs] def step_agents_multithreaded(agents: list[Agent | LLMAgent]) -> None: """Step all agents in parallel using threads.""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [] for agent in agents: if hasattr(agent, "astep"): # run async steps in the event loop in a thread futures.append( executor.submit(lambda agent=agent: asyncio.run(agent.astep())) ) elif hasattr(agent, "step"): futures.append(executor.submit(agent.step)) for future in futures: future.result()
[docs] def step_agents_parallel_sync(agents: list[Agent | LLMAgent]) -> None: """Synchronous wrapper for parallel stepping using the global mode.""" if _PARALLEL_STEPPING_MODE == "asyncio": try: asyncio.get_running_loop() # If in event loop, use thread with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit( lambda: asyncio.run(step_agents_parallel(agents)) ) future.result() except RuntimeError: # No event loop - create one asyncio.run(step_agents_parallel(agents)) elif _PARALLEL_STEPPING_MODE == "threading": step_agents_multithreaded(agents) else: raise ValueError(f"Unknown parallel stepping mode: {_PARALLEL_STEPPING_MODE}")
# Patch Mesa's shuffle_do for automatic parallel detection _original_shuffle_do = AgentSet.shuffle_do def _enhanced_shuffle_do(self, method: str, *args, **kwargs): """Enhanced shuffle_do with automatic parallel stepping.""" if method == "step" and self: agent = next(iter(self)) if hasattr(agent, "model") and getattr(agent.model, "parallel_stepping", False): step_agents_parallel_sync(list(self)) return _original_shuffle_do(self, method, *args, **kwargs)
[docs] def enable_automatic_parallel_stepping(mode: str = "asyncio"): """Enable automatic parallel stepping with selectable mode ('asyncio' or 'threading').""" global _PARALLEL_STEPPING_MODE # noqa: PLW0603 if mode not in ("asyncio", "threading"): raise ValueError("mode must be either 'asyncio' or 'threading'") _PARALLEL_STEPPING_MODE = mode AgentSet.shuffle_do = _enhanced_shuffle_do
[docs] def disable_automatic_parallel_stepping(): """Restore original shuffle_do behavior.""" AgentSet.shuffle_do = _original_shuffle_do
# --- Monkey-patch AgentSet with do_async for async parallel method calls --- def _agentset_do_async(self, method: str, *args, **kwargs): """ Call the given async method on all agents in the set in parallel. Usage: await agents.do_async("async_function") """ logger.info("Running async method '%s' on %d agents", method, len(self)) async def _run(): tasks = [] for agent in self: fn = getattr(agent, method, None) if fn is not None and asyncio.iscoroutinefunction(fn): tasks.append(fn(*args, **kwargs)) else: raise AttributeError( f"Agent {agent} does not have async method '{method}'" ) return await asyncio.gather(*tasks) return _run() AgentSet.do_async = _agentset_do_async