talemate/tests/test_nodes.py
veguAI 483e07fd48
Nodegraph (#7)
Nodegraph refactor
2025-04-03 01:03:28 +03:00

387 lines
No EOL
12 KiB
Python

from talemate.game.engine.nodes.core import Node, Graph, Socket, GraphState, Loop, Entry, Router, GraphContext
import networkx as nx
import structlog
import pytest
from talemate.util.async_tools import cleanup_pending_tasks
# Module-level structlog logger; none of the definitions visible in this file use it.
log = structlog.get_logger()
class Counter(Node):
    """Test node that publishes a running count on its ``value`` output.

    Every ``run`` reads the ``counter`` property, emits that number on the
    ``value`` output and then writes the number plus one back to the property.
    """

    def __init__(self, title="Counter", **kwargs):
        # Only supplies a default title; all other arguments go to Node as-is.
        super().__init__(title=title, **kwargs)

    def setup(self):
        """Declare the ``state`` input and ``value`` output; start the count at 0."""
        self.add_input("state")
        self.add_output("value")
        self.set_property("counter", 0)

    async def run(self, state: GraphState):
        """Emit the current count, then store the incremented count."""
        current = self.get_property("counter")
        self.set_output_values({"value": current})
        # Unlike in setup(), the graph state is passed along with the new value.
        self.set_property("counter", current + 1, state)
@pytest.mark.asyncio
async def test_simple_graph():
    """Build a diamond graph (A -> B -> D and A -> C -> D) and check its topology.

    The graph is only built, never executed: the test inspects the networkx
    graph returned by ``graph.build()`` through a shortest-path query and a
    topological sort.
    """
    # Create nodes
    node_a = Node(title="A")
    node_b = Node(title="B")
    node_c = Node(title="C")
    node_d = Node(title="D")

    # Add sockets to nodes
    out_a1 = node_a.add_output("out1")
    out_a2 = node_a.add_output("out2")
    in_b = node_b.add_input("in")
    out_b = node_b.add_output("out")
    in_c = node_c.add_input("in")
    out_c = node_c.add_output("out")
    in_d1 = node_d.add_input("in1")
    in_d2 = node_d.add_input("in2")

    # Create graph
    graph = Graph()
    graph.add_node(node_a)
    graph.add_node(node_b)
    graph.add_node(node_c)
    graph.add_node(node_d)

    # Connect nodes via sockets
    graph.connect(out_a1, in_b)  # A -> B
    graph.connect(out_a2, in_c)  # A -> C
    graph.connect(out_b, in_d1)  # B -> D
    graph.connect(out_c, in_d2)  # C -> D

    nxgraph = graph.build()

    # Resolve node ids to titles once and reuse the lists for both the debug
    # output and the assertions, rather than querying networkx twice.
    shortest_path = [
        graph.node(n).title
        for n in nx.shortest_path(nxgraph, node_a.id, node_d.id)
    ]
    topo_sort = [graph.node(n).title for n in nx.topological_sort(nxgraph)]

    # Print paths
    print(shortest_path)
    print(topo_sort)

    # Either branch of the diamond is a valid shortest path, so only its
    # length and endpoints are pinned.
    assert len(shortest_path) == 3, "Shortest path should have 3 nodes"
    assert shortest_path[0] == "A", "Path should start with A"
    assert shortest_path[-1] == "D", "Path should end with D"

    # B and C may appear in either order between A and D.
    assert len(topo_sort) == 4, "Should have all 4 nodes in topological sort"
    assert topo_sort[0] == "A", "Topological sort should start with A"
    assert topo_sort[-1] == "D", "Topological sort should end with D"

    await cleanup_pending_tasks()
@pytest.mark.asyncio
async def test_data_flow():
    """Execute a diamond graph and check the values produced along each edge.

    A emits 5 and 10, B doubles its input, C adds one to its input and D
    stores the sum of its two inputs. The checks live in a callback that is
    registered on the graph before it is executed.
    """

    class NodeA(Node):
        """Source node emitting two fixed numbers."""

        def __init__(self):
            super().__init__(title="A")
            self.add_output("out1")
            self.add_output("out2")

        async def run(self, state: GraphState):
            # Constant outputs keep the downstream expectations predictable.
            self.set_output_values({"out1": 5, "out2": 10})

    class NodeB(Node):
        """Emits twice the value received on ``in``."""

        def __init__(self):
            super().__init__(title="B")
            self.add_input("in")
            self.add_output("out")

        async def run(self, state: GraphState):
            received = self.get_input_values()
            doubled = received["in"] * 2
            self.set_output_values({"out": doubled})

    class NodeC(Node):
        """Emits the value received on ``in`` plus one."""

        def __init__(self):
            super().__init__(title="C")
            self.add_input("in")
            self.add_output("out")

        async def run(self, state: GraphState):
            received = self.get_input_values()
            incremented = received["in"] + 1
            self.set_output_values({"out": incremented})

    class NodeD(Node):
        """Sink node that keeps the sum of its two inputs in ``result``."""

        result: int = 0

        def __init__(self):
            super().__init__(title="D")
            self.add_input("in1")
            self.add_input("in2")

        async def run(self, state: GraphState):
            received = self.get_input_values()
            # Kept on the instance so the callback below can inspect it.
            self.result = received["in1"] + received["in2"]

    # Instantiate one node of each kind
    node_a = NodeA()
    node_b = NodeB()
    node_c = NodeC()
    node_d = NodeD()

    # Register the nodes in the same order as they were created
    graph = Graph()
    for member in (node_a, node_b, node_c, node_d):
        graph.add_node(member)

    # Wire the sockets together
    wiring = (
        (node_a.outputs[0], node_b.inputs[0]),  # A.out1 -> B.in
        (node_a.outputs[1], node_c.inputs[0]),  # A.out2 -> C.in
        (node_b.outputs[0], node_d.inputs[0]),  # B.out -> D.in1
        (node_c.outputs[0], node_d.inputs[1]),  # C.out -> D.in2
    )
    for from_socket, to_socket in wiring:
        graph.connect(from_socket, to_socket)

    async def assert_state(state: GraphState):
        """Graph callback verifying every intermediate value."""
        print(state.data)

        # A: out1=5, out2=10
        assert node_a.outputs[0].value == 5, "NodeA out1 should be 5"
        assert node_a.outputs[1].value == 10, "NodeA out2 should be 10"

        # B: 5 * 2 = 10
        assert node_b.outputs[0].value == 10, "NodeB should double input value"

        # C: 10 + 1 = 11
        assert node_c.outputs[0].value == 11, "NodeC should add 1 to input value"

        # D: 10 + 11 = 21
        assert node_d.result == 21, "NodeD should sum its inputs"

    # Register the checks, then run the graph
    graph.callbacks.append(assert_state)
    await graph.execute()

    await cleanup_pending_tasks()
@pytest.mark.asyncio
async def test_property_flow():
    """Execute a graph whose nodes take part of their configuration from properties.

    A source emits its ``value`` property (5). A multiplier and an adder each
    combine that number with a property of their own (2 and 1), and a
    collector stores the sum of both results. The checks run in a callback
    registered on the graph before execution.
    """

    class NumberSource(Node):
        """Emits whatever is stored in its ``value`` property."""

        def __init__(self):
            super().__init__(title="Source")
            self.add_output("value")
            # Default number to emit
            self.set_property("value", 5)

        async def run(self, state: GraphState):
            emitted = self.get_property("value")
            self.set_output_values({"value": emitted})

    class Multiplier(Node):
        """Multiplies the incoming ``value`` by ``multiplier``."""

        def __init__(self):
            super().__init__(title="Multiplier")
            self.add_input("value")
            self.add_output("result")
            # Default factor
            self.set_property("multiplier", 2)

        async def run(self, state: GraphState):
            received = self.get_input_values()
            # No "multiplier" input socket is declared on this node, so the
            # lookup is expected to resolve to the property of the same name.
            factor = self.get_input_value("multiplier")
            print("Multiplier input:", received["value"], "Multiplier:", factor)
            # A falsy incoming value is treated as 0 before multiplying.
            product = (received["value"] or 0) * factor
            self.set_output_values({"result": product})

    class Adder(Node):
        """Adds ``addend`` to the incoming ``value``."""

        def __init__(self):
            super().__init__(title="Adder")
            self.add_input("value")
            self.add_output("result")
            # Default addend
            self.set_property("addend", 1)

        async def run(self, state: GraphState):
            received = self.get_input_values()
            # No "addend" input socket is declared on this node, so the
            # lookup is expected to resolve to the property of the same name.
            increment = self.get_input_value("addend")
            total = received["value"] + increment
            self.set_output_values({"result": total})

    class Collector(Node):
        """Sink node that keeps the sum of ``value1`` and ``value2`` in ``result``."""

        result: float = 0

        def __init__(self):
            super().__init__(title="Collector")
            self.add_input("value1")
            self.add_input("value2")
            # Properties named like the inputs, both defaulting to 0
            self.set_property("value1", 0)
            self.set_property("value2", 0)

        async def run(self, state: GraphState):
            received = self.get_input_values()
            self.result = received["value1"] + received["value2"]

    # Instantiate the nodes
    source = NumberSource()
    mult = Multiplier()
    add = Adder()
    collect = Collector()

    # Register them in creation order
    graph = Graph()
    for member in (source, mult, add, collect):
        graph.add_node(member)

    # Wire the sockets together
    wiring = (
        (source.outputs[0], mult.inputs[0]),  # Source -> Multiplier
        (source.outputs[0], add.inputs[0]),  # Source -> Adder
        (mult.outputs[0], collect.inputs[0]),  # Multiplier -> Collector.value1
        (add.outputs[0], collect.inputs[1]),  # Adder -> Collector.value2
    )
    for from_socket, to_socket in wiring:
        graph.connect(from_socket, to_socket)

    async def assert_state(state: GraphState):
        """Graph callback verifying each node's result."""
        assert source.outputs[0].value == 5, "Source should output property value"
        assert mult.outputs[0].value == 10, "Multiplier should use property multiplier"
        assert add.outputs[0].value == 6, "Adder should use property addend"
        assert collect.result == 16, "Collector should sum multiplier and adder outputs"

    # Register the checks, then run the graph
    graph.callbacks.append(assert_state)
    await graph.execute()

    await cleanup_pending_tasks()
@pytest.mark.asyncio
async def test_simple_loop():
    """Run a Counter inside a Loop and check the last value it emitted.

    The loop's exit condition is true once the counter property exceeds 10.
    The check is registered as a callback on the loop, not on the outer graph.
    """
    # Loop contents: Entry -> Counter
    entry_loop = Entry()
    counter = Counter()
    loop = Loop(exit_condition=lambda state: counter.get_property("counter") > 10)
    loop.add_node(entry_loop)
    loop.add_node(counter)
    loop.connect(entry_loop.outputs[0], counter.inputs[0])

    # Outer graph: Entry -> Loop
    entry = Entry()
    graph = Graph()
    graph.add_node(entry)
    graph.add_node(loop)
    graph.connect(entry.outputs[0], loop.inputs[0])

    async def assert_state(state: GraphState):
        """Loop callback checking the counter's last emitted value."""
        assert counter.outputs[0].value == 10, "Counter should count to 10"

    loop.callbacks.append(assert_state)

    await graph.execute()

    # Every other async test in this module finishes by cleaning up pending
    # tasks; do the same here so this test does not leave any behind.
    await cleanup_pending_tasks()
@pytest.mark.asyncio
async def test_simple_fork():
    """Run a loop in which a Router alternates between two counters.

    The main counter feeds a two-way router whose selector picks output 0
    when the main count is even and output 1 when it is odd. The loop's exit
    condition is true once the main count exceeds 10. The final property
    values of all three counters are checked in a callback on the loop.
    """
    entry = Entry(title="Entry")
    entry_loop = Entry(title="Entry Loop")
    counter_main = Counter("CNT Main")
    counter_a = Counter("CNT A")
    counter_b = Counter("CNT B")

    def pick_branch(state):
        # Even main count -> router output 0, odd -> router output 1.
        if counter_main.get_property("counter") % 2 == 0:
            return 0
        return 1

    def main_count_exceeded(state):
        return counter_main.get_property("counter") > 10

    router = Router(2, selector=pick_branch)
    loop = Loop(title="Loop", exit_condition=main_count_exceeded)

    # Populate the loop, keeping the original registration order
    for member in (entry_loop, counter_main, counter_a, counter_b, router):
        loop.add_node(member)

    # Entry -> main counter -> router -> (counter A | counter B)
    loop.connect(entry_loop.outputs[0], counter_main.inputs[0])
    loop.connect(counter_main.outputs[0], router.inputs[0])
    loop.connect(router.outputs[0], counter_a.inputs[0])
    loop.connect(router.outputs[1], counter_b.inputs[0])

    # Outer graph: Entry -> Loop
    graph = Graph()
    graph.add_node(entry)
    graph.add_node(loop)
    graph.connect(entry.outputs[0], loop.inputs[0])

    async def assert_state_loop(state: GraphState):
        """Loop callback checking the final count of every counter."""
        assert counter_main.get_property("counter") == 11, "Main counter should count to 11"
        assert counter_a.get_property("counter") == 5, "Counter A should count to 5"
        assert counter_b.get_property("counter") == 5, "Counter B should count to 5"

    loop.callbacks.append(assert_state_loop)

    await graph.execute()

    await cleanup_pending_tasks()
@pytest.mark.asyncio
async def test_visited_paths():
    """Check node availability while outgoing paths are deactivated one by one.

    Node A reaches node B twice:

        A -> B        (direct)
        A -> C -> B   (via C)

    A must still be reported as available after only the direct output is
    deactivated, and as unavailable once both outputs are deactivated.
    """
    graph = Graph()

    # Nodes
    node_a = Node(title="Node A")
    node_b = Node(title="Node B")
    node_c = Node(title="Node C")

    # Register them before any socket exists, as in the original setup
    for member in (node_a, node_b, node_c):
        graph.add_node(member)

    # Sockets
    direct_out = node_a.add_output("out1")
    indirect_out = node_a.add_output("out2")
    b_direct_in = node_b.add_input("in1")
    b_indirect_in = node_b.add_input("in2")
    c_in = node_c.add_input("in")
    c_out = node_c.add_output("out")

    # Direct path: A -> B
    graph.connect(direct_out, b_direct_in)

    # Indirect path: A -> C -> B
    graph.connect(indirect_out, c_in)
    graph.connect(c_out, b_indirect_in)

    with GraphContext() as state:
        # Step 1: switch off only the direct path; the one via C remains.
        direct_out.deactivated = True
        assert node_a.check_is_available(state), "Node A should be available through path via C"

        # Step 2: switch off the indirect path as well; nothing is left.
        indirect_out.deactivated = True
        assert not node_a.check_is_available(state), "Node A should be unavailable when all paths are deactivated"

    await cleanup_pending_tasks()