# test_kernel_sanity.py
"""
Phase 2.2: Kernel Sanity Gate Tests
=====================================
Verifies 4 key wirings of the Branch B kernel:
1. asymmetry / backward-weight read correctness
2. Dirichlet per-node wiring (cat0/cat1/cat2 are actually updated)
3. the context parameter is genuinely consumed
4. attention ledger aligned with mu
"""
10
11import sys
12import os
13import math
14
15sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16from cie import CIERuntime
17from cie.graph import Graph
18
19
20def test_asymmetry_correct():
21 """asymmetry_at 使用正确的反向权重"""
22 g = Graph()
23 g.add_edge('a', 'b', weight=2.0, bwd_weight=0.5)
24
25 phi = {'a': 1.0, 'b': 1.0}
26
27 # asymmetry_at('a') 应该用 fwd('a','b')=2.0 和 bwd('a','b')
28 # bwd_weight 是 b→a 方向的权重 = 0.5
29 # asymmetry = (fwd - bwd) * phi(b) = (2.0 - 0.5) * 1.0 = 1.5
30 asym = g.asymmetry_at('a', phi)
31 expected = (2.0 - 0.5) * 1.0 # 1.5
32
33 assert abs(asym - expected) < 0.01, \
34 f"asymmetry_at wrong: got {asym}, expected {expected}"
35 print(f" PASS: asymmetry_at('a')={asym:.2f}, expected={expected:.2f}")
36
37
38def test_bwd_weight_in_to_dict():
39 """to_dict 导出正确的反向权重"""
40 g = Graph()
41 g.add_edge('x', 'y', weight=3.0, bwd_weight=1.0)
42
43 d = g.to_dict()
44 edges = d['edges']
45 xy_edge = next(e for e in edges if e['src'] == 'x' and e['dst'] == 'y')
46
47 assert abs(xy_edge['weight'] - 3.0) < 0.01, \
48 f"fwd weight wrong: {xy_edge['weight']}"
49 assert abs(xy_edge['bwd_weight'] - 1.0) < 0.01, \
50 f"bwd weight wrong: {xy_edge['bwd_weight']}, expected 1.0"
51 print(f" PASS: to_dict bwd_weight={xy_edge['bwd_weight']:.2f}")
52
53
54def test_dirichlet_per_node_wiring():
55 """普通 ingest 真正更新 per-node Dirichlet cat0/cat1"""
56 rt = CIERuntime(seed=42)
57 rt.ingest("你好世界")
58 rt.step(n=1) # 消费信号
59
60 # bigram 你→好:你应该有 cat0 更新,好应该有 cat1 更新
61 conf_ni = rt.state.confidence.get('你', None)
62 conf_hao = rt.state.confidence.get('好', None)
63
64 assert conf_ni is not None, "Node '你' has no confidence"
65 assert conf_hao is not None, "Node '好' has no confidence"
66
67 # cat0(左上下文)应该 > 1.0(先验)对于 '你'
68 assert conf_ni[0] > 1.0, \
69 f"'你' cat0 not updated: {conf_ni}"
70
71 # cat1(右上下文)应该 > 1.0 对于 '好'
72 assert conf_hao[1] > 1.0, \
73 f"'好' cat1 not updated: {conf_hao}"
74
75 # 不是全都一样(打破了均匀先验)
76 assert not (abs(conf_ni[0] - conf_ni[1]) < 0.01 and abs(conf_ni[1] - conf_ni[2]) < 0.01), \
77 f"'你' Dirichlet still uniform: {conf_ni}"
78
79 print(f" PASS: Dirichlet per-node — '你'={[round(x,2) for x in conf_ni]}, "
80 f"'好'={[round(x,2) for x in conf_hao]}")
81
82
83def test_context_consumed():
84 """context 参数被真实消费——建立了 context→token 的边"""
85 rt = CIERuntime(seed=42)
86 rt.ingest("你好", context="背景")
87 rt.step(n=1)
88
89 # context 字符应该存在于图中
90 assert rt.graph.has_node('背'), f"Context char '背' not in graph"
91 assert rt.graph.has_node('景'), f"Context char '景' not in graph"
92
93 # context 字符与主 token 之间应有边
94 has_edge = False
95 for ctx_char in ['背', '景']:
96 for tok_char in ['你', '好']:
97 if rt.graph.get_edge_weight(ctx_char, tok_char) > 0:
98 has_edge = True
99 break
100 if has_edge:
101 break
102
103 assert has_edge, "No edges between context and tokens"
104 print(f" PASS: context consumed — '背','景' in graph with edges to tokens")
105
106
107def test_attention_mu_alignment():
108 """attention.used 与 sum(mu) 在受控场景下对齐"""
109 rt = CIERuntime(seed=42)
110 rt.ingest("测试")
111 rt.step(n=5)
112
113 mu_sum = sum(v for v in rt.state.mu.values() if v > 0)
114 attn_used = rt.state.attention.used
115
116 # 允许 20% 偏差(传播中有衰减和清理)
117 if mu_sum > 0.1:
118 ratio = attn_used / mu_sum if mu_sum > 0 else float('inf')
119 # attention 应该 >= mu(因为还包含已衰减但未释放的部分)
120 # 但不应该差太多
121 assert 0.7 < ratio < 2.0, \
122 f"attention/mu drift: attn={attn_used:.2f}, mu={mu_sum:.2f}, ratio={ratio:.2f}"
123
124 print(f" PASS: attention/mu aligned — attn={attn_used:.2f}, mu={mu_sum:.2f}")
125
126
127def test_attention_no_unbounded_drift():
128 """多轮运行后 attention 不会无界漂移"""
129 rt = CIERuntime(seed=42)
130
131 for i in range(20):
132 rt.ingest(f"轮次{i}")
133 rt.step(n=3)
134
135 attn = rt.state.attention
136 assert attn.used <= attn.total + 0.01, \
137 f"Attention overflow: used={attn.used:.2f} > total={attn.total:.2f}"
138 assert attn.free >= -0.01, \
139 f"Attention went negative: free={attn.free:.2f}"
140
141 mu_sum = sum(v for v in rt.state.mu.values() if v > 0)
142 # Stricter: attention should track mu within 50%
143 if mu_sum > 1.0:
144 drift_ratio = attn.used / mu_sum
145 assert 0.5 < drift_ratio < 2.0, f"attention drifted: attn={attn.used:.2f}, mu={mu_sum:.2f}, ratio={drift_ratio:.2f}"
146
147 print(f" PASS: no unbounded drift after 20 rounds — "
148 f"attn_used={attn.used:.2f}, mu_sum={mu_sum:.2f}, free={attn.free:.2f}")
149
150
151def run_all():
152 tests = [
153 ("asymmetry_correct", test_asymmetry_correct),
154 ("bwd_weight_to_dict", test_bwd_weight_in_to_dict),
155 ("dirichlet_per_node", test_dirichlet_per_node_wiring),
156 ("context_consumed", test_context_consumed),
157 ("attention_mu_align", test_attention_mu_alignment),
158 ("attention_no_drift", test_attention_no_unbounded_drift),
159 ]
160
161 passed = 0
162 failed = 0
163 for name, fn in tests:
164 try:
165 print(f"[KERNEL] {name}...")
166 fn()
167 passed += 1
168 except Exception as e:
169 print(f" FAIL: {e}")
170 failed += 1
171
172 print(f"\n{'='*50}")
173 print(f"Kernel Sanity Gate: {passed} passed, {failed} failed, {len(tests)} total")
174 print(f"{'='*50}")
175 return failed == 0
176
177
178if __name__ == '__main__':
179 success = run_all()
180 sys.exit(0 if success else 1)