diff --git a/src/multi_swarm/protocol/validator.py b/src/multi_swarm/protocol/validator.py new file mode 100644 index 0000000..6802d13 --- /dev/null +++ b/src/multi_swarm/protocol/validator.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from .grammar import COMPARATOR_VERBS, LOGICAL_VERBS +from .parser import Node, Strategy + +KNOWN_INDICATORS: frozenset[str] = frozenset({"sma", "rsi", "atr", "macd", "realized_vol"}) +KNOWN_FEATURES: frozenset[str] = frozenset({"open", "high", "low", "close", "volume"}) + + +class ValidationError(Exception): + """Raised when an AST violates Phase 1 protocol semantics.""" + + +def validate_strategy(strategy: Strategy) -> None: + """Check semantic constraints on a parsed Strategy AST. + + The parser already enforces verb-set membership; this pass adds: + * arity checks for logical/comparator/data verbs, + * known-indicator / known-feature whitelists. + """ + for rule in strategy.rules: + _validate_node(rule.condition, _expect_bool=True) + + +def _validate_node(node: Node, _expect_bool: bool) -> None: + if node.kind in LOGICAL_VERBS: + if node.kind == "not": + if len(node.args) != 1: + raise ValidationError(f"'not' needs 1 arg, got {len(node.args)}") + arg = node.args[0] + if isinstance(arg, Node): + _validate_node(arg, _expect_bool=True) + else: + if len(node.args) < 2: + raise ValidationError(f"'{node.kind}' needs >=2 args") + for a in node.args: + if isinstance(a, Node): + _validate_node(a, _expect_bool=True) + return + + if node.kind in COMPARATOR_VERBS: + if len(node.args) != 2: + raise ValidationError(f"'{node.kind}' needs 2 args, got {len(node.args)}") + for a in node.args: + if isinstance(a, Node): + _validate_node(a, _expect_bool=False) + return + + if node.kind in {"crossover", "crossunder"}: + if len(node.args) != 2: + raise ValidationError(f"'{node.kind}' needs 2 args") + for a in node.args: + if isinstance(a, Node): + _validate_node(a, _expect_bool=False) + return + + if node.kind == "indicator": + if len(node.args) < 2: + raise ValidationError("'indicator' needs >=2 args (name, length)") + name_node = node.args[0] + ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node) + if ind_name not in KNOWN_INDICATORS: + raise ValidationError(f"unknown indicator: {ind_name}") + return + + if node.kind == "feature": + if len(node.args) != 1: + raise ValidationError("'feature' needs 1 arg") + feat_node = node.args[0] + feat_name = feat_node.kind if isinstance(feat_node, Node) else str(feat_node) + if feat_name not in KNOWN_FEATURES: + raise ValidationError(f"unknown feature: {feat_name}") + return + + raise ValidationError(f"unexpected node kind in expression: {node.kind}") diff --git a/tests/unit/test_protocol_validator.py b/tests/unit/test_protocol_validator.py new file mode 100644 index 0000000..2aa6b4f --- /dev/null +++ b/tests/unit/test_protocol_validator.py @@ -0,0 +1,38 @@ +import pytest + +from multi_swarm.protocol.parser import parse_strategy +from multi_swarm.protocol.validator import ValidationError, validate_strategy + + +def test_valid_strategy_passes() -> None: + src = "(strategy (when (gt (indicator rsi 14) 70.0) (entry-short)))" + ast = parse_strategy(src) + validate_strategy(ast) # no exception + + +def test_indicator_unknown_name_fails() -> None: + src = "(strategy (when (gt (indicator wibble 14) 70.0) (entry-short)))" + ast = parse_strategy(src) + with pytest.raises(ValidationError, match="unknown indicator"): + validate_strategy(ast) + + +def test_indicator_wrong_arity_fails() -> None: + src = "(strategy (when (gt (indicator rsi) 70.0) (entry-short)))" + ast = parse_strategy(src) + with pytest.raises(ValidationError): + validate_strategy(ast) + + +def test_comparator_wrong_arity_fails() -> None: + src = "(strategy (when (gt 1.0) (entry-long)))" + ast = parse_strategy(src) + with pytest.raises(ValidationError): + validate_strategy(ast) + + +def test_feature_unknown_column_fails() -> None: + src = "(strategy (when (gt (feature wibble) 100.0) (entry-long)))" + ast = parse_strategy(src) + with pytest.raises(ValidationError, match="unknown feature"): + validate_strategy(ast)