from __future__ import annotations

from .grammar import COMPARATOR_VERBS, LOGICAL_VERBS
from .parser import Node, Strategy

# Number of numeric parameters accepted after the indicator name, as an
# inclusive (min, max) range. Only numeric literals count: Node arguments are
# rejected by the validator (indicators cannot be nested in Phase 1).
INDICATOR_ARITY: dict[str, tuple[int, int]] = {
    "sma": (1, 1),  # length
    "rsi": (1, 1),  # length
    "atr": (1, 1),  # length
    "realized_vol": (1, 1),  # window
    "macd": (0, 3),  # fast, slow, signal (all optional, with defaults)
}

# Derived from INDICATOR_ARITY so the whitelist and the arity table cannot
# drift apart; a name whitelisted here but missing from the table would
# otherwise surface as a KeyError instead of a ValidationError.
KNOWN_INDICATORS: frozenset[str] = frozenset(INDICATOR_ARITY)
KNOWN_FEATURES: frozenset[str] = frozenset({"open", "high", "low", "close", "volume"})

# Verbs that compare two value series and yield a boolean event.
_CROSS_VERBS: frozenset[str] = frozenset({"crossover", "crossunder"})


class ValidationError(Exception):
    """Raised when an AST violates Phase 1 protocol semantics."""


def validate_strategy(strategy: Strategy) -> None:
    """Check semantic constraints on a parsed Strategy AST.

    The parser already enforces verb-set membership; this pass adds:

    * arity checks for logical/comparator/data verbs,
    * known-indicator / known-feature whitelists,
    * boolean-vs-value checks: every rule condition and every Node operand of
      a logical verb must yield a boolean (logical, comparator or
      crossover/crossunder node), while every Node operand of a comparator or
      crossover/crossunder must yield a value (indicator or feature node).

    Non-Node operands (literals) are passed through without a type check.

    Raises:
        ValidationError: on the first violation found.
    """
    for rule in strategy.rules:
        _validate_node(rule.condition, _expect_bool=True)


def _validate_node(node: Node, _expect_bool: bool) -> None:
    """Recursively validate one expression node.

    Args:
        node: The expression node to check.
        _expect_bool: ``True`` when the parent needs a boolean condition
            (rule condition, logical-verb operand); ``False`` when it needs a
            value (comparator or crossover/crossunder operand).

    Raises:
        ValidationError: on an arity, whitelist or boolean/value mismatch,
            or when ``node.kind`` is not a recognised expression kind.
    """
    kind = node.kind
    # Branch order matches the original checks: logical verbs first, then
    # comparators, then crosses (relevant if the verb sets ever overlap).
    if kind in LOGICAL_VERBS:
        _check_result_type(node, yields_bool=True, expect_bool=_expect_bool)
        _validate_logical(node)
    elif kind in COMPARATOR_VERBS:
        _check_result_type(node, yields_bool=True, expect_bool=_expect_bool)
        if len(node.args) != 2:
            raise ValidationError(f"'{kind}' needs 2 args, got {len(node.args)}")
        _validate_operands(node.args, expect_bool=False)
    elif kind in _CROSS_VERBS:
        _check_result_type(node, yields_bool=True, expect_bool=_expect_bool)
        if len(node.args) != 2:
            raise ValidationError(f"'{kind}' needs 2 args")
        _validate_operands(node.args, expect_bool=False)
    elif kind == "indicator":
        _check_result_type(node, yields_bool=False, expect_bool=_expect_bool)
        _validate_indicator(node)
    elif kind == "feature":
        _check_result_type(node, yields_bool=False, expect_bool=_expect_bool)
        _validate_feature(node)
    else:
        raise ValidationError(f"unexpected node kind in expression: {kind}")


def _check_result_type(node: Node, *, yields_bool: bool, expect_bool: bool) -> None:
    """Raise if *node* yields a boolean where a value is expected, or vice versa."""
    if yields_bool == expect_bool:
        return
    got = "a boolean" if yields_bool else "a value"
    wanted = "a boolean condition" if expect_bool else "a value"
    raise ValidationError(
        f"'{node.kind}' yields {got} but {wanted} is expected here"
    )


def _validate_operands(args, expect_bool: bool) -> None:
    """Recurse into every Node operand; literal operands are left unchecked."""
    for arg in args:
        if isinstance(arg, Node):
            _validate_node(arg, _expect_bool=expect_bool)


def _validate_logical(node: Node) -> None:
    """Check a logical verb's arity and validate its operands as booleans."""
    if node.kind == "not":
        if len(node.args) != 1:
            raise ValidationError(f"'not' needs 1 arg, got {len(node.args)}")
    elif len(node.args) < 2:
        raise ValidationError(f"'{node.kind}' needs >=2 args")
    _validate_operands(node.args, expect_bool=True)


def _validate_indicator(node: Node) -> None:
    """Check an ``indicator`` node: known name, literal-only params, arity."""
    if len(node.args) < 1:
        raise ValidationError("'indicator' needs >=1 args (name [, params...])")
    name_node, *params = node.args
    # The name may arrive either as a bare Node (use its kind) or a literal.
    ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node)
    if ind_name not in KNOWN_INDICATORS:
        raise ValidationError(f"unknown indicator: {ind_name}")
    # Indicators do not accept Node params (no nesting in Phase 1).
    for param in params:
        if isinstance(param, Node):
            raise ValidationError(
                f"indicator '{ind_name}' does not accept nested expressions; "
                f"only numeric literals (got node {param.kind})"
            )
    n_params = len(params)
    min_p, max_p = INDICATOR_ARITY[ind_name]
    if not (min_p <= n_params <= max_p):
        raise ValidationError(
            f"indicator '{ind_name}' arity {n_params} out of [{min_p},{max_p}]"
        )


def _validate_feature(node: Node) -> None:
    """Check a ``feature`` node: exactly one argument naming a known feature."""
    if len(node.args) != 1:
        raise ValidationError("'feature' needs 1 arg")
    feat_node = node.args[0]
    feat_name = feat_node.kind if isinstance(feat_node, Node) else str(feat_node)
    if feat_name not in KNOWN_FEATURES:
        raise ValidationError(f"unknown feature: {feat_name}")