From c27bb2dabeb8f8ba51d3769ea842a0dddf164b3b Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 19 Sep 2024 16:59:00 +0200 Subject: [PATCH 1/6] datacollector: Allow collecting data from Agent type Added `agenttype_reporters` to Mesa's DataCollector, enabling collection of data specific to agent types. --- mesa/datacollection.py | 143 +++++++++++++++++++++++++++++++++++------ mesa/model.py | 3 + 2 files changed, 127 insertions(+), 19 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index bf50be2a723..052700291bf 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -1,18 +1,19 @@ """Mesa Data Collection Module. DataCollector is meant to provide a simple, standard way to collect data -generated by a Mesa model. It collects three types of data: model-level data, -agent-level data, and tables. +generated by a Mesa model. It collects four types of data: model-level data, +agent-level data, agent-type-level data, and tables. -A DataCollector is instantiated with two dictionaries of reporter names and -associated variable names or functions for each, one for model-level data and -one for agent-level data; a third dictionary provides table names and columns. -Variable names are converted into functions which retrieve attributes of that -name. +A DataCollector is instantiated with three dictionaries of reporter names and +associated variable names or functions for each, one for model-level data, +one for agent-level data, and one for agent-type-level data; a fourth dictionary +provides table names and columns. Variable names are converted into functions +which retrieve attributes of that name. When the collect() method is called, each model-level function is called, with the model as the argument, and the results associated with the relevant -variable. Then the agent-level functions are called on each agent. +variable. Then the agent-level functions are called on each agent, and the +agent-type-level functions are called on each agent of the specified type. Additionally, other objects can write directly to tables by passing in an appropriate dictionary object for a table row. @@ -21,19 +22,23 @@ * model_vars maps each reporter to a list of its values * tables maps each table to a dictionary, with each column as a key with a list as its value. - * _agent_records maps each model step to a list of each agents id + * _agent_records maps each model step to a list of each agent's id and its values. + * _agenttype_records maps each model step to a dictionary of agent types, + each containing a list of each agent's id and its values. Finally, DataCollector can create a pandas DataFrame from each collection. The default DataCollector here makes several assumptions: * The model has an agent list called agents + * The model has a dictionary of AgentSets called agents_by_type * For collecting agent-level variables, agents must have a unique_id """ import contextlib import itertools import types +import warnings from copy import deepcopy from functools import partial @@ -44,24 +49,25 @@ class DataCollector: """Class for collecting data generated by a Mesa model. - A DataCollector is instantiated with dictionaries of names of model- and - agent-level variables to collect, associated with attribute names or - functions which actually collect them. When the collect(...) method is - called, it collects these attributes and executes these functions one by - one and stores the results. + A DataCollector is instantiated with dictionaries of names of model-, + agent-, and agent-type-level variables to collect, associated with + attribute names or functions which actually collect them. When the + collect(...) method is called, it collects these attributes and executes + these functions one by one and stores the results. """ def __init__( self, model_reporters=None, agent_reporters=None, + agenttype_reporters=None, tables=None, ): - """Instantiate a DataCollector with lists of model and agent reporters. + """Instantiate a DataCollector with lists of model, agent, and agent-type reporters. - Both model_reporters and agent_reporters accept a dictionary mapping a - variable name to either an attribute name, a function, a method of a class/instance, - or a function with parameters placed in a list. + Both model_reporters, agent_reporters, and agenttype_reporters accept a + dictionary mapping a variable name to either an attribute name, a function, + a method of a class/instance, or a function with parameters placed in a list. Model reporters can take four types of arguments: 1. Lambda function: @@ -85,6 +91,10 @@ def __init__( 4. Functions with parameters placed in a list: {"Agent_Function": [function, [param_1, param_2]]} + Agenttype reporters take a dictionary mapping agent types to dictionaries + of reporter names and attributes/funcs/methods, similar to agent_reporters: + {Wolf: {"energy": lambda a: a.energy}} + The tables arg accepts a dictionary mapping names of tables to lists of columns. For example, if we want to allow agents to write their age when they are destroyed (to keep track of lifespans), it might look @@ -94,6 +104,8 @@ def __init__( Args: model_reporters: Dictionary of reporter names and attributes/funcs/methods. agent_reporters: Dictionary of reporter names and attributes/funcs/methods. + agenttype_reporters: Dictionary of agent types to dictionaries of + reporter names and attributes/funcs/methods. tables: Dictionary of table names to lists of column names. Notes: @@ -103,9 +115,11 @@ def __init__( """ self.model_reporters = {} self.agent_reporters = {} + self.agenttype_reporters = {} self.model_vars = {} self._agent_records = {} + self._agenttype_records = {} self.tables = {} if model_reporters is not None: @@ -116,6 +130,11 @@ def __init__( for name, reporter in agent_reporters.items(): self._new_agent_reporter(name, reporter) + if agenttype_reporters is not None: + for agent_type, reporters in agenttype_reporters.items(): + for name, reporter in reporters.items(): + self._new_agenttype_reporter(agent_type, name, reporter) + if tables is not None: for name, columns in tables.items(): self._new_table(name, columns) @@ -163,6 +182,38 @@ def func_with_params(agent): self.agent_reporters[name] = reporter + def _new_agenttype_reporter(self, agent_type, name, reporter): + """Add a new agent-type-level reporter to collect. + + Args: + agent_type: The type of agent to collect data for. + name: Name of the agent-type-level variable to collect. + reporter: Attribute string, function object, method of a class/instance, or + function with parameters placed in a list that returns the + variable when given an agent instance. + """ + if agent_type not in self.agenttype_reporters: + self.agenttype_reporters[agent_type] = {} + + # Use the same logic as _new_agent_reporter + if isinstance(reporter, str): + attribute_name = reporter + + def attr_reporter(agent): + return getattr(agent, attribute_name, None) + + reporter = attr_reporter + + elif isinstance(reporter, list): + func, params = reporter[0], reporter[1] + + def func_with_params(agent): + return func(agent, *params) + + reporter = func_with_params + + self.agenttype_reporters[agent_type][name] = reporter + def _new_table(self, table_name, table_columns): """Add a new table that objects can write to. @@ -190,6 +241,21 @@ def get_reports(agent): ) return agent_records + def _record_agenttype(self, model, agent_type): + """Record agent-type data in a mapping of functions and agents.""" + rep_funcs = self.agenttype_reporters[agent_type].values() + + def get_reports(agent): + _prefix = (agent.model.steps, agent.unique_id) + reports = tuple(rep(agent) for rep in rep_funcs) + return _prefix + reports + + agenttype_records = map( + get_reports, + model.agents_by_type[agent_type], + ) + return agenttype_records + def collect(self, model): """Collect all the data for the given model object.""" if self.model_reporters: @@ -208,7 +274,6 @@ def collect(self, model): elif isinstance(reporter, list): self.model_vars[var].append(deepcopy(reporter[0](*reporter[1]))) # Assume it's a callable otherwise (e.g., method) - # TODO: Check if method of a class explicitly else: self.model_vars[var].append(deepcopy(reporter())) @@ -216,6 +281,14 @@ def collect(self, model): agent_records = self._record_agents(model) self._agent_records[model.steps] = list(agent_records) + if self.agenttype_reporters: + self._agenttype_records[model.steps] = {} + for agent_type in self.agenttype_reporters: + agenttype_records = self._record_agenttype(model, agent_type) + self._agenttype_records[model.steps][agent_type] = list( + agenttype_records + ) + def add_table_row(self, table_name, row, ignore_missing=False): """Add a row dictionary to a specific table. @@ -272,6 +345,38 @@ def get_agent_vars_dataframe(self): ) return df + def get_agenttype_vars_dataframe(self, agent_type): + """Create a pandas DataFrame from the agent-type variables for a specific agent type. + + The DataFrame has one column for each variable, with two additional + columns for tick and agent_id. + + Args: + agent_type: The type of agent to get the data for. + """ + # Check if self.agenttype_reporters dictionary is empty for this agent type, if so return empty DataFrame + if agent_type not in self.agenttype_reporters: + warnings.warn( + f"No agent-type reporters have been defined for {agent_type} in the DataCollector, returning empty DataFrame.", + UserWarning, + stacklevel=2, + ) + return pd.DataFrame() + + all_records = itertools.chain.from_iterable( + records[agent_type] + for records in self._agenttype_records.values() + if agent_type in records + ) + rep_names = list(self.agenttype_reporters[agent_type]) + + df = pd.DataFrame.from_records( + data=all_records, + columns=["Step", "AgentID", *rep_names], + index=["Step", "AgentID"], + ) + return df + def get_table_dataframe(self, table_name): """Create a pandas DataFrame from a particular table. diff --git a/mesa/model.py b/mesa/model.py index 1302c60e482..a3227c10cdc 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -198,6 +198,7 @@ def initialize_data_collector( self, model_reporters=None, agent_reporters=None, + agenttype_reporters=None, tables=None, ) -> None: """Initialize the data collector for the model. @@ -205,6 +206,7 @@ def initialize_data_collector( Args: model_reporters: model reporters to collect agent_reporters: agent reporters to collect + agenttype_reporters: agent type reporters to collect tables: tables to collect """ @@ -219,6 +221,7 @@ def initialize_data_collector( self.datacollector = DataCollector( model_reporters=model_reporters, agent_reporters=agent_reporters, + agenttype_reporters=agenttype_reporters, tables=tables, ) # Collect data for the first time during initialization. From 8a75fe108761321feaefa1e83d3330d52c16367d Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 19 Sep 2024 16:59:08 +0200 Subject: [PATCH 2/6] Add tests for agenttype_reporters --- tests/test_datacollector.py | 117 ++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 88792554737..57bfe4019b0 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -3,6 +3,7 @@ import unittest from mesa import Agent, Model +from mesa.datacollection import DataCollector from mesa.time import BaseScheduler @@ -28,6 +29,30 @@ def write_final_values(self): # D103 self.model.datacollector.add_table_row("Final_Values", row) +class MockAgentA(MockAgent): + """Agent subclass A for testing agent-type-specific reporters.""" + + def __init__(self, model, val=0): # noqa: D107 + super().__init__(model, val) + self.type_a_val = val * 2 + + def step(self): # noqa: D102 + super().step() + self.type_a_val = self.val * 2 + + +class MockAgentB(MockAgent): + """Agent subclass B for testing agent-type-specific reporters.""" + + def __init__(self, model, val=0): # noqa: D107 + super().__init__(model, val) + self.type_b_val = val * 3 + + def step(self): # noqa: D102 + super().step() + self.type_b_val = self.val * 3 + + def agent_function_with_params(agent, multiplier, offset): # noqa: D103 return (agent.val * multiplier) + offset @@ -74,6 +99,38 @@ def step(self): # noqa: D102 self.datacollector.collect(self) +class MockModelWithAgentTypes(Model): + """Model for testing agent-type-specific reporters.""" + + def __init__(self): # noqa: D107 + super().__init__() + self.schedule = BaseScheduler(self) + self.model_val = 100 + + for i in range(10): + if i % 2 == 0: + self.schedule.add(MockAgentA(self, val=i)) + else: + self.schedule.add(MockAgentB(self, val=i)) + + self.datacollector = DataCollector( + model_reporters={ + "total_agents": lambda m: m.schedule.get_agent_count(), + }, + agent_reporters={ + "value": lambda a: a.val, + }, + agenttype_reporters={ + MockAgentA: {"type_a_val": lambda a: a.type_a_val}, + MockAgentB: {"type_b_val": lambda a: a.type_b_val}, + }, + ) + + def step(self): # noqa: D102 + self.schedule.step() + self.datacollector.collect(self) + + class TestDataCollector(unittest.TestCase): """Tests for DataCollector.""" @@ -206,5 +263,65 @@ def test_initialize_before_agents_added_to_scheduler(self): # noqa: D102 ) +class TestDataCollectorWithAgentTypes(unittest.TestCase): + """Tests for DataCollector with agent-type-specific reporters.""" + + def setUp(self): + """Create the model and run it a set number of steps.""" + self.model = MockModelWithAgentTypes() + for _ in range(5): + self.model.step() + + def test_agenttype_vars(self): + """Test agent-type-specific variable collection.""" + data_collector = self.model.datacollector + + # Test MockAgentA data + agent_a_data = data_collector.get_agenttype_vars_dataframe(MockAgentA) + self.assertIn("type_a_val", agent_a_data.columns) + self.assertEqual(len(agent_a_data), 25) # 5 agents * 5 steps + for (step, agent_id), value in agent_a_data["type_a_val"].items(): + expected_value = (agent_id - 1) * 2 + step * 2 + self.assertEqual(value, expected_value) + + # Test MockAgentB data + agent_b_data = data_collector.get_agenttype_vars_dataframe(MockAgentB) + self.assertIn("type_b_val", agent_b_data.columns) + self.assertEqual(len(agent_b_data), 25) # 5 agents * 5 steps + for (step, agent_id), value in agent_b_data["type_b_val"].items(): + expected_value = (agent_id - 1) * 3 + step * 3 + self.assertEqual(value, expected_value) + + def test_agenttype_and_agent_vars(self): + """Test that agent-type-specific and general agent variables are collected correctly.""" + data_collector = self.model.datacollector + + agent_vars = data_collector.get_agent_vars_dataframe() + agent_a_vars = data_collector.get_agenttype_vars_dataframe(MockAgentA) + agent_b_vars = data_collector.get_agenttype_vars_dataframe(MockAgentB) + + # Check that general agent variables are present for all agents + self.assertIn("value", agent_vars.columns) + + # Check that agent-type-specific variables are only present in their respective dataframes + self.assertIn("type_a_val", agent_a_vars.columns) + self.assertNotIn("type_a_val", agent_b_vars.columns) + self.assertIn("type_b_val", agent_b_vars.columns) + self.assertNotIn("type_b_val", agent_a_vars.columns) + + def test_nonexistent_agenttype(self): + """Test that requesting data for a non-existent agent type raises a warning.""" + data_collector = self.model.datacollector + + class NonExistentAgent(Agent): + pass + + with self.assertWarns(UserWarning): + non_existent_data = data_collector.get_agenttype_vars_dataframe( + NonExistentAgent + ) + self.assertTrue(non_existent_data.empty) + + if __name__ == "__main__": unittest.main() From 5bd3f92db575f3ca14c3242491a9bb4edd6e0661 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Fri, 20 Sep 2024 08:42:58 +0200 Subject: [PATCH 3/6] Added three new test methods For AgentTyoe Added three new test methods to cover the missing codepaths: 1. `test_agenttype_reporter_string_attribute`: This test covers the case where the reporter is a string (attribute name). 2. `test_agenttype_reporter_function_with_params`: This test covers the case where the reporter is a list (function with parameters). 3. `test_agenttype_reporter_multiple_types`: This test explicitly checks that adding reporters for multiple agent types works correctly, which covers the case where `agent_type` is not initially in `self.agenttype_reporters`. --- tests/test_datacollector.py | 49 +++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 57bfe4019b0..66b1782d26f 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -322,6 +322,55 @@ class NonExistentAgent(Agent): ) self.assertTrue(non_existent_data.empty) + def test_agenttype_reporter_string_attribute(self): + """Test agent-type-specific reporter with string attribute.""" + model = MockModel() + model.datacollector._new_agenttype_reporter(MockAgentA, "string_attr", "val") + model.step() + + agent_a_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentA) + self.assertIn("string_attr", agent_a_data.columns) + for (step, agent_id), value in agent_a_data["string_attr"].items(): + expected_value = agent_id + 1 # Initial value + 1 step + self.assertEqual(value, expected_value) + + def test_agenttype_reporter_function_with_params(self): + """Test agent-type-specific reporter with function and parameters.""" + + def test_func(agent, multiplier): + return agent.val * multiplier + + model = MockModel() + model.datacollector._new_agenttype_reporter( + MockAgentB, "func_param", [test_func, [2]] + ) + model.step() + + agent_b_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentB) + self.assertIn("func_param", agent_b_data.columns) + for (step, agent_id), value in agent_b_data["func_param"].items(): + expected_value = (agent_id + 1) * 2 # (Initial value + 1 step) * 2 + self.assertEqual(value, expected_value) + + def test_agenttype_reporter_multiple_types(self): + """Test adding reporters for multiple agent types.""" + model = MockModel() + model.datacollector._new_agenttype_reporter( + MockAgentA, "type_a_val", lambda a: a.type_a_val + ) + model.datacollector._new_agenttype_reporter( + MockAgentB, "type_b_val", lambda a: a.type_b_val + ) + model.step() + + agent_a_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentA) + agent_b_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentB) + + self.assertIn("type_a_val", agent_a_data.columns) + self.assertIn("type_b_val", agent_b_data.columns) + self.assertNotIn("type_b_val", agent_a_data.columns) + self.assertNotIn("type_a_val", agent_b_data.columns) + if __name__ == "__main__": unittest.main() From 86b184dc341ee1f6578aefbbda86e159f27ee4a5 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Fri, 20 Sep 2024 10:06:26 +0200 Subject: [PATCH 4/6] Datacollection: Update _record_agenttype for Agent subclasses --- mesa/datacollection.py | 22 ++++++++++++++++++---- tests/test_datacollector.py | 37 ++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 052700291bf..2156fb462cf 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -250,10 +250,24 @@ def get_reports(agent): reports = tuple(rep(agent) for rep in rep_funcs) return _prefix + reports - agenttype_records = map( - get_reports, - model.agents_by_type[agent_type], - ) + agent_types = model.agent_types + if agent_type in agent_types: + agents = model.agents_by_type[agent_type] + else: + from mesa import Agent + + # Check if agent_type is an Agent subclass + if issubclass(agent_type, Agent): + raise NotImplementedError( + f"Agent type {agent_type} is not in model.agent_types. We might implement using superclasses in the future. For now, use one of {agent_types}." + ) + else: + # Raise error if agent_type is not in model.agent_types + raise ValueError( + f"Agent type {agent_type} is not recognized as an Agent type in the model. Use one of {agent_types}." + ) + + agenttype_records = map(get_reports, agents) return agenttype_records def collect(self, model): diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 66b1782d26f..e8aba12372c 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -104,22 +104,17 @@ class MockModelWithAgentTypes(Model): def __init__(self): # noqa: D107 super().__init__() - self.schedule = BaseScheduler(self) self.model_val = 100 for i in range(10): if i % 2 == 0: - self.schedule.add(MockAgentA(self, val=i)) + MockAgentA(self, val=i) else: - self.schedule.add(MockAgentB(self, val=i)) + MockAgentB(self, val=i) self.datacollector = DataCollector( - model_reporters={ - "total_agents": lambda m: m.schedule.get_agent_count(), - }, - agent_reporters={ - "value": lambda a: a.val, - }, + model_reporters={"total_agents": lambda m: len(m.agents)}, + agent_reporters={"value": lambda a: a.val}, agenttype_reporters={ MockAgentA: {"type_a_val": lambda a: a.type_a_val}, MockAgentB: {"type_b_val": lambda a: a.type_b_val}, @@ -127,7 +122,7 @@ def __init__(self): # noqa: D107 ) def step(self): # noqa: D102 - self.schedule.step() + self.agents.do("step") self.datacollector.collect(self) @@ -324,14 +319,14 @@ class NonExistentAgent(Agent): def test_agenttype_reporter_string_attribute(self): """Test agent-type-specific reporter with string attribute.""" - model = MockModel() + model = MockModelWithAgentTypes() model.datacollector._new_agenttype_reporter(MockAgentA, "string_attr", "val") model.step() agent_a_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentA) self.assertIn("string_attr", agent_a_data.columns) - for (step, agent_id), value in agent_a_data["string_attr"].items(): - expected_value = agent_id + 1 # Initial value + 1 step + for (_step, agent_id), value in agent_a_data["string_attr"].items(): + expected_value = agent_id self.assertEqual(value, expected_value) def test_agenttype_reporter_function_with_params(self): @@ -340,7 +335,7 @@ def test_agenttype_reporter_function_with_params(self): def test_func(agent, multiplier): return agent.val * multiplier - model = MockModel() + model = MockModelWithAgentTypes() model.datacollector._new_agenttype_reporter( MockAgentB, "func_param", [test_func, [2]] ) @@ -348,13 +343,13 @@ def test_func(agent, multiplier): agent_b_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentB) self.assertIn("func_param", agent_b_data.columns) - for (step, agent_id), value in agent_b_data["func_param"].items(): - expected_value = (agent_id + 1) * 2 # (Initial value + 1 step) * 2 + for (_step, agent_id), value in agent_b_data["func_param"].items(): + expected_value = agent_id * 2 self.assertEqual(value, expected_value) def test_agenttype_reporter_multiple_types(self): """Test adding reporters for multiple agent types.""" - model = MockModel() + model = MockModelWithAgentTypes() model.datacollector._new_agenttype_reporter( MockAgentA, "type_a_val", lambda a: a.type_a_val ) @@ -371,6 +366,14 @@ def test_agenttype_reporter_multiple_types(self): self.assertNotIn("type_b_val", agent_a_data.columns) self.assertNotIn("type_a_val", agent_b_data.columns) + def test_agenttype_reporter_not_in_model(self): + """Test NotImplementedError is raised when agent type is not in model.agents_by_type.""" + model = MockModelWithAgentTypes() + # MockAgent is a legit Agent subclass, but it is not in model.agents_by_type + model.datacollector._new_agenttype_reporter(MockAgent, "val", lambda a: a.val) + with self.assertRaises(NotImplementedError): + model.step() + if __name__ == "__main__": unittest.main() From 3dbb5b6ab1a62e245bdaa700eacd6a368e32b418 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Fri, 20 Sep 2024 10:17:02 +0200 Subject: [PATCH 5/6] Allow all Agent subclasses --- mesa/datacollection.py | 9 ++++----- tests/test_datacollector.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 2156fb462cf..fb9333da898 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -256,15 +256,14 @@ def get_reports(agent): else: from mesa import Agent - # Check if agent_type is an Agent subclass if issubclass(agent_type, Agent): - raise NotImplementedError( - f"Agent type {agent_type} is not in model.agent_types. We might implement using superclasses in the future. For now, use one of {agent_types}." - ) + agents = [ + agent for agent in model.agents if isinstance(agent, agent_type) + ] else: # Raise error if agent_type is not in model.agent_types raise ValueError( - f"Agent type {agent_type} is not recognized as an Agent type in the model. Use one of {agent_types}." + f"Agent type {agent_type} is not recognized as an Agent type in the model or Agent subclass. Use an Agent (sub)class, like {agent_types}." ) agenttype_records = map(get_reports, agents) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index e8aba12372c..b2760fc1b44 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -366,14 +366,22 @@ def test_agenttype_reporter_multiple_types(self): self.assertNotIn("type_b_val", agent_a_data.columns) self.assertNotIn("type_a_val", agent_b_data.columns) - def test_agenttype_reporter_not_in_model(self): - """Test NotImplementedError is raised when agent type is not in model.agents_by_type.""" + def test_agenttype_superclass_reporter(self): + """Test adding a reporter for a superclass of an agent type.""" model = MockModelWithAgentTypes() - # MockAgent is a legit Agent subclass, but it is not in model.agents_by_type model.datacollector._new_agenttype_reporter(MockAgent, "val", lambda a: a.val) - with self.assertRaises(NotImplementedError): + model.datacollector._new_agenttype_reporter(Agent, "val", lambda a: a.val) + for _ in range(3): model.step() + super_data = model.datacollector.get_agenttype_vars_dataframe(MockAgent) + agent_data = model.datacollector.get_agenttype_vars_dataframe(Agent) + self.assertIn("val", super_data.columns) + self.assertIn("val", agent_data.columns) + self.assertEqual(len(super_data), 30) # 10 agents * 3 steps + self.assertEqual(len(agent_data), 30) + self.assertTrue(super_data.equals(agent_data)) + if __name__ == "__main__": unittest.main() From fdbdd4a2db48c95c72ef8233519d742641dd8172 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Sat, 21 Sep 2024 09:19:20 +0200 Subject: [PATCH 6/6] DataCollector: Remove assumptions All three assumptions are now guarded by that we require the Agent and Model super classes to always be initialized. So they are not relevant anymore for the user. --- mesa/datacollection.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index fb9333da898..2fb4d241c4b 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -28,11 +28,6 @@ each containing a list of each agent's id and its values. Finally, DataCollector can create a pandas DataFrame from each collection. - -The default DataCollector here makes several assumptions: - * The model has an agent list called agents - * The model has a dictionary of AgentSets called agents_by_type - * For collecting agent-level variables, agents must have a unique_id """ import contextlib