# Mirror of https://github.com/vegu-ai/talemate.git
# Synced 2025-09-02 10:29:17 +00:00
# 387 lines · No EOL · 12 KiB · Python
from talemate.game.engine.nodes.core import Node, Graph, Socket, GraphState, Loop, Entry, Router, GraphContext
|
|
import networkx as nx
|
|
import structlog
|
|
import pytest
|
|
from talemate.util.async_tools import cleanup_pending_tasks
|
|
|
|
# Module-level structlog logger; not referenced by any of the tests visible in this file.
log = structlog.get_logger()
|
|
|
|
class Counter(Node):
    """Test node that emits its running count, then bumps it by one.

    Each ``run`` publishes the current ``counter`` property on the ``value``
    output socket and afterwards stores ``counter + 1`` back into the
    property, passing the graph state through to ``set_property``.
    """

    def __init__(self, title="Counter", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        # Declare a ``state`` input socket and a ``value`` output socket.
        self.add_input("state")
        self.add_output("value")
        # Counting starts at zero.
        self.set_property("counter", 0)

    async def run(self, state: GraphState):
        current = self.get_property("counter")
        outputs = {"value": current}
        self.set_output_values(outputs)
        # Increment only after the current value has been emitted.
        self.set_property("counter", current + 1, state)


@pytest.mark.asyncio
async def test_simple_graph():
    """Build a diamond graph (A -> B -> D, A -> C -> D) and check its topology.

    Inspects the graph returned by ``Graph.build()`` with networkx:

    - The shortest A->D path contains exactly three nodes (A, one of B/C, D).
    - The topological order starts at the only source (A) and ends at the
      only sink (D).
    """
    # Create nodes
    node_a = Node(title="A")
    node_b = Node(title="B")
    node_c = Node(title="C")
    node_d = Node(title="D")

    # Add sockets to nodes
    out_a1 = node_a.add_output("out1")
    out_a2 = node_a.add_output("out2")

    in_b = node_b.add_input("in")
    out_b = node_b.add_output("out")

    in_c = node_c.add_input("in")
    out_c = node_c.add_output("out")

    in_d1 = node_d.add_input("in1")
    in_d2 = node_d.add_input("in2")

    # Create graph
    graph = Graph()
    graph.add_node(node_a)
    graph.add_node(node_b)
    graph.add_node(node_c)
    graph.add_node(node_d)

    # Connect nodes via sockets
    graph.connect(out_a1, in_b)  # A -> B
    graph.connect(out_a2, in_c)  # A -> C
    graph.connect(out_b, in_d1)  # B -> D
    graph.connect(out_c, in_d2)  # C -> D

    nxgraph = graph.build()

    # Compute each ordering once, then print it (for debugging) and assert on it.
    # Previously both were computed twice: once to print, once to assert.
    shortest_path = [graph.node(n).title for n in nx.shortest_path(nxgraph, node_a.id, node_d.id)]
    topo_sort = [graph.node(n).title for n in nx.topological_sort(nxgraph)]

    print(shortest_path)
    print(topo_sort)

    assert len(shortest_path) == 3, "Shortest path should have 3 nodes"
    assert shortest_path[0] == "A", "Path should start with A"
    assert shortest_path[-1] == "D", "Path should end with D"

    assert len(topo_sort) == 4, "Should have all 4 nodes in topological sort"
    assert topo_sort[0] == "A", "Topological sort should start with A"
    assert topo_sort[-1] == "D", "Topological sort should end with D"

    await cleanup_pending_tasks()


@pytest.mark.asyncio
async def test_data_flow():
    """Check that values propagate along socket connections during execution.

    Wiring: A.out1 -> B.in -> D.in1 and A.out2 -> C.in -> D.in2.
    A emits constants (5, 10), B doubles its input, C adds one, and D sums
    its two inputs (10 + 11 = 21). The assertions live in a callback that is
    appended to ``graph.callbacks`` before ``graph.execute()`` is awaited.
    """

    # Create nodes with specific behaviors
    class NodeA(Node):
        def __init__(self):
            super().__init__(title="A")
            self.add_output("out1")
            self.add_output("out2")

        async def run(self, state: GraphState):
            # Output constant values for testing
            self.set_output_values({
                "out1": 5,
                "out2": 10
            })

    class NodeB(Node):
        def __init__(self):
            super().__init__(title="B")
            self.add_input("in")
            self.add_output("out")

        async def run(self, state: GraphState):
            inputs = self.get_input_values()
            # Double the input value
            self.set_output_values({
                "out": inputs["in"] * 2
            })

    class NodeC(Node):
        def __init__(self):
            super().__init__(title="C")
            self.add_input("in")
            self.add_output("out")

        async def run(self, state: GraphState):
            inputs = self.get_input_values()
            # Add 1 to the input value
            self.set_output_values({
                "out": inputs["in"] + 1
            })

    class NodeD(Node):
        # Sum of the two inputs, recorded by run() so the test can inspect it.
        result: int = 0

        def __init__(self):
            super().__init__(title="D")
            self.add_input("in1")
            self.add_input("in2")

        async def run(self, state: GraphState):
            inputs = self.get_input_values()
            # Store sum for testing
            self.result = inputs["in1"] + inputs["in2"]

    # Create nodes
    node_a = NodeA()
    node_b = NodeB()
    node_c = NodeC()
    node_d = NodeD()

    # Create graph
    graph = Graph()
    graph.add_node(node_a)
    graph.add_node(node_b)
    graph.add_node(node_c)
    graph.add_node(node_d)

    # Connect nodes via sockets
    graph.connect(node_a.outputs[0], node_b.inputs[0])  # A.out1 -> B.in
    graph.connect(node_a.outputs[1], node_c.inputs[0])  # A.out2 -> C.in
    graph.connect(node_b.outputs[0], node_d.inputs[0])  # B.out -> D.in1
    graph.connect(node_c.outputs[0], node_d.inputs[1])  # C.out -> D.in2

    async def assert_state(state: GraphState):
        print(state.data)
        # Test data flow
        # NodeA outputs: out1=5, out2=10
        assert node_a.outputs[0].value == 5, "NodeA out1 should be 5"
        assert node_a.outputs[1].value == 10, "NodeA out2 should be 10"

        # NodeB doubles input: 5 * 2 = 10
        assert node_b.outputs[0].value == 10, "NodeB should double input value"

        # NodeC adds 1: 10 + 1 = 11
        assert node_c.outputs[0].value == 11, "NodeC should add 1 to input value"

        # NodeD sums inputs: 10 + 11 = 21
        assert node_d.result == 21, "NodeD should sum its inputs"

    # Execute graph
    graph.callbacks.append(assert_state)
    try:
        await graph.execute()
    finally:
        # Run pending-task cleanup even when execution (or a callback
        # assertion) raises, so leftovers cannot leak into later tests.
        await cleanup_pending_tasks()


@pytest.mark.asyncio
async def test_property_flow():
    """Check that nodes fall back to their properties for unconnected inputs.

    Source emits its ``value`` property (5) to both Multiplier and Adder.
    Multiplier reads ``multiplier`` (property default 2). Adder reads
    ``addend`` (property default 1). Collector sums both results:
    5 * 2 + (5 + 1) = 16. The assertions live in a callback appended to
    ``graph.callbacks`` before ``graph.execute()`` is awaited.
    """

    # Create nodes with property-driven behaviors
    class NumberSource(Node):
        def __init__(self):
            super().__init__(title="Source")
            self.add_output("value")
            # Set default property
            self.set_property("value", 5)

        async def run(self, state: GraphState):
            # Output property value
            self.set_output_values({
                "value": self.get_property("value")
            })

    class Multiplier(Node):
        def __init__(self):
            super().__init__(title="Multiplier")
            self.add_input("value")
            self.add_output("result")
            # Set default multiplier
            self.set_property("multiplier", 2)

        async def run(self, state: GraphState):
            inputs = self.get_input_values()
            multiplier = self.get_input_value("multiplier")  # Will fall back to property

            print("Multiplier input:", inputs["value"], "Multiplier:", multiplier)
            self.set_output_values({
                "result": (inputs["value"] or 0) * multiplier
            })

    class Adder(Node):
        def __init__(self):
            super().__init__(title="Adder")
            self.add_input("value")
            self.add_output("result")
            # Set default addend
            self.set_property("addend", 1)

        async def run(self, state: GraphState):
            inputs = self.get_input_values()
            addend = self.get_input_value("addend")  # Will fall back to property
            self.set_output_values({
                "result": inputs["value"] + addend
            })

    class Collector(Node):
        # Sum of both inputs, recorded by run() so the test can inspect it.
        result: float = 0

        def __init__(self):
            super().__init__(title="Collector")
            self.add_input("value1")
            self.add_input("value2")
            # Set default values
            self.set_property("value1", 0)
            self.set_property("value2", 0)

        async def run(self, state: GraphState):
            inputs = self.get_input_values()
            self.result = inputs["value1"] + inputs["value2"]

    # Create nodes and graph setup...
    source = NumberSource()
    mult = Multiplier()
    add = Adder()
    collect = Collector()

    # Create graph
    graph = Graph()
    graph.add_node(source)
    graph.add_node(mult)
    graph.add_node(add)
    graph.add_node(collect)

    # Connect nodes
    graph.connect(source.outputs[0], mult.inputs[0])  # Source -> Multiplier
    graph.connect(source.outputs[0], add.inputs[0])  # Source -> Adder
    graph.connect(mult.outputs[0], collect.inputs[0])  # Multiplier -> Collector.value1
    graph.connect(add.outputs[0], collect.inputs[1])  # Adder -> Collector.value2

    async def assert_state(state: GraphState):
        # Run assertions...
        assert source.outputs[0].value == 5, "Source should output property value"
        assert mult.outputs[0].value == 10, "Multiplier should use property multiplier"
        assert add.outputs[0].value == 6, "Adder should use property addend"
        assert collect.result == 16, "Collector should sum multiplier and adder outputs"

    # Test property defaults
    graph.callbacks.append(assert_state)
    try:
        await graph.execute()
    finally:
        # Run pending-task cleanup even when execution (or a callback
        # assertion) raises, so leftovers cannot leak into later tests.
        await cleanup_pending_tasks()


@pytest.mark.asyncio
async def test_simple_loop():
    """Run a Counter inside a Loop until its ``counter`` property exceeds 10.

    Counter emits its current value before incrementing, so the value left
    on its output when the exit condition trips is expected to be 10. The
    check lives in a callback appended to ``loop.callbacks``.
    """
    entry_loop = Entry()
    counter = Counter()

    loop = Loop(exit_condition=lambda state: counter.get_property("counter") > 10)
    loop.add_node(entry_loop)
    loop.add_node(counter)
    loop.connect(entry_loop.outputs[0], counter.inputs[0])

    entry = Entry()
    graph = Graph()

    graph.add_node(entry)
    graph.add_node(loop)

    graph.connect(entry.outputs[0], loop.inputs[0])

    async def assert_state(state: GraphState):
        assert counter.outputs[0].value == 10, "Counter should count to 10"

    loop.callbacks.append(assert_state)

    try:
        await graph.execute()
    finally:
        # Every other execution test in this module cleans up pending tasks;
        # do the same here, even when execution (or the callback) raises.
        await cleanup_pending_tasks()


@pytest.mark.asyncio
async def test_simple_fork():
    """Fork a Loop's flow through a Router into two branch counters.

    Inside the loop, the main counter feeds a two-way Router. Its selector
    picks output 0 (-> Counter A) when the main ``counter`` property is even
    and output 1 (-> Counter B) when it is odd. The loop exits once the main
    counter property exceeds 10. A callback appended to ``loop.callbacks``
    asserts the resulting counts.
    """
    entry = Entry(title="Entry")
    entry_loop = Entry(title="Entry Loop")

    counter_main = Counter("CNT Main")
    counter_a = Counter("CNT A")
    counter_b = Counter("CNT B")
    router = Router(2, selector=lambda state: 0 if counter_main.get_property("counter") % 2 == 0 else 1)

    loop = Loop(title="Loop", exit_condition=lambda state: counter_main.get_property("counter") > 10)

    loop.add_node(entry_loop)
    loop.add_node(counter_main)
    loop.add_node(counter_a)
    loop.add_node(counter_b)
    loop.add_node(router)

    loop.connect(entry_loop.outputs[0], counter_main.inputs[0])
    loop.connect(counter_main.outputs[0], router.inputs[0])
    loop.connect(router.outputs[0], counter_a.inputs[0])
    loop.connect(router.outputs[1], counter_b.inputs[0])

    graph = Graph()
    graph.add_node(entry)
    graph.add_node(loop)

    graph.connect(entry.outputs[0], loop.inputs[0])

    async def assert_state_loop(state: GraphState):
        assert counter_main.get_property("counter") == 11, "Main counter should count to 11"
        assert counter_a.get_property("counter") == 5, "Counter A should count to 5"
        assert counter_b.get_property("counter") == 5, "Counter B should count to 5"

    loop.callbacks.append(assert_state_loop)

    try:
        await graph.execute()
    finally:
        # Run pending-task cleanup even when execution (or a callback
        # assertion) raises, so leftovers cannot leak into later tests.
        await cleanup_pending_tasks()


@pytest.mark.asyncio
async def test_visited_paths():
    """Check that a node stays available while any outgoing path is active.

    Topology (both paths start at Node A and end at Node B):

        A.out1 -> B.in1                    (direct)
        A.out2 -> C.in,  C.out -> B.in2    (via C)

    Deactivating ``out1`` alone must leave A available. Deactivating
    ``out2`` as well must make it unavailable.
    """
    graph = Graph()

    node_a, node_b, node_c = Node(title="Node A"), Node(title="Node B"), Node(title="Node C")
    for member in (node_a, node_b, node_c):
        graph.add_node(member)

    # Sockets are attached after the nodes are registered with the graph.
    a_out1, a_out2 = node_a.add_output("out1"), node_a.add_output("out2")
    b_in1, b_in2 = node_b.add_input("in1"), node_b.add_input("in2")
    c_in, c_out = node_c.add_input("in"), node_c.add_output("out")

    edges = (
        (a_out1, b_in1),  # A -> B (direct path)
        (a_out2, c_in),   # A -> C ...
        (c_out, b_in2),   # ... -> B (indirect path)
    )
    for source, target in edges:
        graph.connect(source, target)

    with GraphContext() as ctx_state:
        # Cut only the direct path; A still reaches B through C.
        a_out1.deactivated = True
        assert node_a.check_is_available(ctx_state), "Node A should be available through path via C"

        # Cut the indirect path as well; no active outgoing path remains.
        a_out2.deactivated = True
        assert not node_a.check_is_available(ctx_state), "Node A should be unavailable when all paths are deactivated"

    await cleanup_pending_tasks()