CIE-Unified

git clone 

CIE-Unified / tests
im_wower  ·  2026-04-01

test_kernel_sanity.py

  1"""
  2Phase 2.2: Kernel Sanity Gate Tests
  3=====================================
  4验证 Branch B 内核的 4 个关键接线:
  51. asymmetry / backward-weight 读取正确性
  62. Dirichlet per-node wiring(cat0/cat1/cat2 真正被更新)
  73. context 参数真实消费
  84. attention ledger 与 mu 对齐
  9"""
 10
 11import sys
 12import os
 13import math
 14
 15sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 16from cie import CIERuntime
 17from cie.graph import Graph
 18
 19
 20def test_asymmetry_correct():
 21    """asymmetry_at 使用正确的反向权重"""
 22    g = Graph()
 23    g.add_edge('a', 'b', weight=2.0, bwd_weight=0.5)
 24
 25    phi = {'a': 1.0, 'b': 1.0}
 26
 27    # asymmetry_at('a') 应该用 fwd('a','b')=2.0 和 bwd('a','b')
 28    # bwd_weight 是 b→a 方向的权重 = 0.5
 29    # asymmetry = (fwd - bwd) * phi(b) = (2.0 - 0.5) * 1.0 = 1.5
 30    asym = g.asymmetry_at('a', phi)
 31    expected = (2.0 - 0.5) * 1.0  # 1.5
 32
 33    assert abs(asym - expected) < 0.01, \
 34        f"asymmetry_at wrong: got {asym}, expected {expected}"
 35    print(f"  PASS: asymmetry_at('a')={asym:.2f}, expected={expected:.2f}")
 36
 37
 38def test_bwd_weight_in_to_dict():
 39    """to_dict 导出正确的反向权重"""
 40    g = Graph()
 41    g.add_edge('x', 'y', weight=3.0, bwd_weight=1.0)
 42
 43    d = g.to_dict()
 44    edges = d['edges']
 45    xy_edge = next(e for e in edges if e['src'] == 'x' and e['dst'] == 'y')
 46
 47    assert abs(xy_edge['weight'] - 3.0) < 0.01, \
 48        f"fwd weight wrong: {xy_edge['weight']}"
 49    assert abs(xy_edge['bwd_weight'] - 1.0) < 0.01, \
 50        f"bwd weight wrong: {xy_edge['bwd_weight']}, expected 1.0"
 51    print(f"  PASS: to_dict bwd_weight={xy_edge['bwd_weight']:.2f}")
 52
 53
 54def test_dirichlet_per_node_wiring():
 55    """普通 ingest 真正更新 per-node Dirichlet cat0/cat1"""
 56    rt = CIERuntime(seed=42)
 57    rt.ingest("你好世界")
 58    rt.step(n=1)  # 消费信号
 59
 60    # bigram 你→好:你应该有 cat0 更新,好应该有 cat1 更新
 61    conf_ni = rt.state.confidence.get('', None)
 62    conf_hao = rt.state.confidence.get('', None)
 63
 64    assert conf_ni is not None, "Node '' has no confidence"
 65    assert conf_hao is not None, "Node '' has no confidence"
 66
 67    # cat0(左上下文)应该 > 1.0(先验)对于 '你'
 68    assert conf_ni[0] > 1.0, \
 69        f"'' cat0 not updated: {conf_ni}"
 70
 71    # cat1(右上下文)应该 > 1.0 对于 '好'
 72    assert conf_hao[1] > 1.0, \
 73        f"'' cat1 not updated: {conf_hao}"
 74
 75    # 不是全都一样(打破了均匀先验)
 76    assert not (abs(conf_ni[0] - conf_ni[1]) < 0.01 and abs(conf_ni[1] - conf_ni[2]) < 0.01), \
 77        f"'' Dirichlet still uniform: {conf_ni}"
 78
 79    print(f"  PASS: Dirichlet per-node — ''={[round(x,2) for x in conf_ni]}, "
 80          f"''={[round(x,2) for x in conf_hao]}")
 81
 82
 83def test_context_consumed():
 84    """context 参数被真实消费——建立了 context→token 的边"""
 85    rt = CIERuntime(seed=42)
 86    rt.ingest("你好", context="背景")
 87    rt.step(n=1)
 88
 89    # context 字符应该存在于图中
 90    assert rt.graph.has_node(''), f"Context char '' not in graph"
 91    assert rt.graph.has_node(''), f"Context char '' not in graph"
 92
 93    # context 字符与主 token 之间应有边
 94    has_edge = False
 95    for ctx_char in ['', '']:
 96        for tok_char in ['', '']:
 97            if rt.graph.get_edge_weight(ctx_char, tok_char) > 0:
 98                has_edge = True
 99                break
100        if has_edge:
101            break
102
103    assert has_edge, "No edges between context and tokens"
104    print(f"  PASS: context consumed — '','' in graph with edges to tokens")
105
106
def test_attention_mu_alignment():
    """Check that attention.used tracks sum(mu) in a controlled scenario.

    After ingesting "测试" and running 5 steps, the ratio
    attention.used / sum(positive mu) must lie in the open interval
    (0.7, 2.0). When total mu is <= 0.1 the ratio is not meaningful, so the
    check is skipped and reported as SKIP rather than PASS.
    """
    rt = CIERuntime(seed=42)
    rt.ingest("测试")
    rt.step(n=5)

    mu_sum = sum(v for v in rt.state.mu.values() if v > 0)
    attn_used = rt.state.attention.used

    if mu_sum <= 0.1:
        # Too little mass for an informative ratio; report honestly instead
        # of printing "aligned" for a check that never ran.
        print(f"  SKIP: mu too small for alignment check — "
              f"attn={attn_used:.2f}, mu={mu_sum:.2f}")
        return

    # mu_sum > 0.1 here, so the division is always safe.
    ratio = attn_used / mu_sum
    # Attention may exceed mu because it can still include mass that has
    # decayed but not yet been released; decay and cleanup during propagation
    # can also pull it below. The accepted band is asymmetric: 0.7x .. 2x.
    assert 0.7 < ratio < 2.0, \
        f"attention/mu drift: attn={attn_used:.2f}, mu={mu_sum:.2f}, ratio={ratio:.2f}"

    print(f"  PASS: attention/mu aligned — attn={attn_used:.2f}, mu={mu_sum:.2f}")
125
126
def test_attention_no_unbounded_drift():
    """Check that attention stays bounded over many ingest/step rounds.

    Runs 20 rounds of ``ingest(f"轮次{i}")`` followed by ``step(n=3)``, then
    asserts (each with a 0.01 tolerance) that attention.used does not exceed
    attention.total and that attention.free is not negative. When the sum of
    positive mu exceeds 1.0, attention.used / mu_sum must also lie in
    (0.5, 2.0), i.e. within a factor of two of mu in either direction.
    """
    rt = CIERuntime(seed=42)

    for round_idx in range(20):
        rt.ingest(f"轮次{round_idx}")
        rt.step(n=3)

    attn = rt.state.attention
    assert attn.used <= attn.total + 0.01, \
        f"Attention overflow: used={attn.used:.2f} > total={attn.total:.2f}"
    assert attn.free >= -0.01, \
        f"Attention went negative: free={attn.free:.2f}"

    positive_mu = [value for value in rt.state.mu.values() if value > 0]
    mu_sum = sum(positive_mu)

    # Only compare against mu once enough mass has accumulated; the band
    # tolerates up to a factor-of-two drift in either direction.
    if mu_sum > 1.0:
        drift_ratio = attn.used / mu_sum
        assert 0.5 < drift_ratio < 2.0, \
            f"attention drifted: attn={attn.used:.2f}, mu={mu_sum:.2f}, ratio={drift_ratio:.2f}"

    print(f"  PASS: no unbounded drift after 20 rounds — "
          f"attn_used={attn.used:.2f}, mu_sum={mu_sum:.2f}, free={attn.free:.2f}")
149
150
def run_all():
    """Run every kernel sanity test, print a summary, and report success.

    Each test runs in isolation. Any exception it raises, whether an assertion
    or anything else, is caught, reported on a FAIL line, and counted, so one
    failing test does not stop the rest.

    Returns:
        True if every test passed, False otherwise.
    """
    tests = [
        ("asymmetry_correct", test_asymmetry_correct),
        ("bwd_weight_to_dict", test_bwd_weight_in_to_dict),
        ("dirichlet_per_node", test_dirichlet_per_node_wiring),
        ("context_consumed", test_context_consumed),
        ("attention_mu_align", test_attention_mu_alignment),
        ("attention_no_drift", test_attention_no_unbounded_drift),
    ]

    passed = 0
    failed = 0
    for name, fn in tests:
        print(f"[KERNEL] {name}...")
        try:
            fn()
        except Exception as e:
            # Include the exception type: a bare AssertionError or a
            # StopIteration has an empty str(), which would otherwise leave
            # an uninformative "FAIL:" line.
            print(f"  FAIL: {type(e).__name__}: {e}")
            failed += 1
        else:
            passed += 1

    print(f"\n{'='*50}")
    print(f"Kernel Sanity Gate: {passed} passed, {failed} failed, {len(tests)} total")
    print(f"{'='*50}")
    return failed == 0
176
177
if __name__ == '__main__':
    # Exit status 0 when every kernel sanity test passed, 1 otherwise.
    all_passed = run_all()
    sys.exit(int(not all_passed))