feat(protocol): semantic validator for AST

Aggiunge validatore semantico per AST Strategy: arity check su
verbi logici/comparatori/data, whitelist indicatori
(sma, rsi, atr, macd, realized_vol) e feature OHLCV.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-09 19:34:03 +02:00
parent 19035812c3
commit 052f323790
2 changed files with 113 additions and 0 deletions
+75
View File
@@ -0,0 +1,75 @@
from __future__ import annotations
from .grammar import COMPARATOR_VERBS, LOGICAL_VERBS
from .parser import Node, Strategy
KNOWN_INDICATORS: frozenset[str] = frozenset({"sma", "rsi", "atr", "macd", "realized_vol"})
KNOWN_FEATURES: frozenset[str] = frozenset({"open", "high", "low", "close", "volume"})
class ValidationError(Exception):
"""Raised when an AST violates Phase 1 protocol semantics."""
def validate_strategy(strategy: Strategy) -> None:
"""Check semantic constraints on a parsed Strategy AST.
The parser already enforces verb-set membership; this pass adds:
* arity checks for logical/comparator/data verbs,
* known-indicator / known-feature whitelists.
"""
for rule in strategy.rules:
_validate_node(rule.condition, _expect_bool=True)
def _validate_node(node: Node, _expect_bool: bool) -> None:
if node.kind in LOGICAL_VERBS:
if node.kind == "not":
if len(node.args) != 1:
raise ValidationError(f"'not' needs 1 arg, got {len(node.args)}")
arg = node.args[0]
if isinstance(arg, Node):
_validate_node(arg, _expect_bool=True)
else:
if len(node.args) < 2:
raise ValidationError(f"'{node.kind}' needs >=2 args")
for a in node.args:
if isinstance(a, Node):
_validate_node(a, _expect_bool=True)
return
if node.kind in COMPARATOR_VERBS:
if len(node.args) != 2:
raise ValidationError(f"'{node.kind}' needs 2 args, got {len(node.args)}")
for a in node.args:
if isinstance(a, Node):
_validate_node(a, _expect_bool=False)
return
if node.kind in {"crossover", "crossunder"}:
if len(node.args) != 2:
raise ValidationError(f"'{node.kind}' needs 2 args")
for a in node.args:
if isinstance(a, Node):
_validate_node(a, _expect_bool=False)
return
if node.kind == "indicator":
if len(node.args) < 2:
raise ValidationError("'indicator' needs >=2 args (name, length)")
name_node = node.args[0]
ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node)
if ind_name not in KNOWN_INDICATORS:
raise ValidationError(f"unknown indicator: {ind_name}")
return
if node.kind == "feature":
if len(node.args) != 1:
raise ValidationError("'feature' needs 1 arg")
feat_node = node.args[0]
feat_name = feat_node.kind if isinstance(feat_node, Node) else str(feat_node)
if feat_name not in KNOWN_FEATURES:
raise ValidationError(f"unknown feature: {feat_name}")
return
raise ValidationError(f"unexpected node kind in expression: {node.kind}")
+38
View File
@@ -0,0 +1,38 @@
import pytest
from multi_swarm.protocol.parser import parse_strategy
from multi_swarm.protocol.validator import ValidationError, validate_strategy
def test_valid_strategy_passes() -> None:
src = "(strategy (when (gt (indicator rsi 14) 70.0) (entry-short)))"
ast = parse_strategy(src)
validate_strategy(ast) # no exception
def test_indicator_unknown_name_fails() -> None:
src = "(strategy (when (gt (indicator wibble 14) 70.0) (entry-short)))"
ast = parse_strategy(src)
with pytest.raises(ValidationError, match="unknown indicator"):
validate_strategy(ast)
def test_indicator_wrong_arity_fails() -> None:
src = "(strategy (when (gt (indicator rsi) 70.0) (entry-short)))"
ast = parse_strategy(src)
with pytest.raises(ValidationError):
validate_strategy(ast)
def test_comparator_wrong_arity_fails() -> None:
src = "(strategy (when (gt 1.0) (entry-long)))"
ast = parse_strategy(src)
with pytest.raises(ValidationError):
validate_strategy(ast)
def test_feature_unknown_column_fails() -> None:
src = "(strategy (when (gt (feature wibble) 100.0) (entry-long)))"
ast = parse_strategy(src)
with pytest.raises(ValidationError, match="unknown feature"):
validate_strategy(ast)