diff --git a/pyproject.toml b/pyproject.toml index e2be236..771921b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ dependencies = [ "pydantic>=2.9", "pydantic-settings>=2.6", "sqlmodel>=0.0.22", - "sexpdata>=1.0.2", "openai>=1.55", "httpx>=0.28", "requests>=2.32", diff --git a/scripts/smoke_run.py b/scripts/smoke_run.py index 2c98a2d..16da344 100644 --- a/scripts/smoke_run.py +++ b/scripts/smoke_run.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from pathlib import Path import numpy as np @@ -9,19 +10,40 @@ from multi_swarm.genome.hypothesis import HypothesisAgentGenome, ModelTier from multi_swarm.llm.client import CompletionResult from multi_swarm.orchestrator.run import RunConfig, run_phase1 +_MOCK_STRATEGY = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + }, + { + "condition": { + "op": "lt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 30.0}, + ], + }, + "action": "entry-long", + }, + ] + } +) + class MockLLMClient: def complete( self, genome: HypothesisAgentGenome, system: str, user: str, max_tokens: int = 2000, ) -> CompletionResult: - text = ( - "```lisp\n" - "(strategy" - " (when (gt (indicator rsi 14) 70.0) (entry-short))" - " (when (lt (indicator rsi 14) 30.0) (entry-long)))\n" - "```" - ) + text = "```json\n" + _MOCK_STRATEGY + "\n```" return CompletionResult( text=text, input_tokens=120, output_tokens=60, tier=genome.model_tier, model="mock", diff --git a/src/multi_swarm/agents/hypothesis.py b/src/multi_swarm/agents/hypothesis.py index 6db27cb..b49a6d0 100644 --- a/src/multi_swarm/agents/hypothesis.py +++ b/src/multi_swarm/agents/hypothesis.py @@ -35,42 +35,76 @@ Sei un agente generatore di ipotesi di trading quantitativo per un sistema swarm Il tuo stile cognitivo: {cognitive_style} Direttiva personale: {system_prompt} -Devi proporre una strategia di trading espressa nel linguaggio S-expression -con i seguenti verbi disponibili: +Devi proporre una strategia di trading espressa in JSON STRETTO. +La risposta deve essere un singolo oggetto JSON dentro fence ```json...``` +con questa shape: - Azioni: entry-long, entry-short, exit, flat - Logici: and, or, not - Comparatori: gt, lt, eq - Dati: feature, indicator, crossover, crossunder +```json +{{ + "rules": [ + {{"condition": , "action": "entry-long|entry-short|exit|flat"}} + ] +}} +``` -Indicatori disponibili (calcolati implicitamente sul prezzo close): - sma , rsi , atr , macd, realized_vol . -Feature disponibili: open, high, low, close, volume. +NODI DISPONIBILI -REGOLE STRETTE DI SINTASSI: -- (indicator ) restituisce una serie numerica. Es. - (indicator rsi 14), (indicator sma 50), (indicator macd 12 26 9). -- (feature ) restituisce la colonna OHLCV. Es. (feature close). -- Gli indicatori NON sono annidabili: NON puoi scrivere - (sma (indicator realized_vol 30) 150) o (indicator rsi (feature high) 14). - Le funzioni sma/rsi/etc. ESISTONO SOLO come argomenti di indicator, - non sono verbi indipendenti. -- Costanti numeriche (es. 70.0, 30, 0.02) sono valide come 2° operando di gt/lt/eq. -- crossover/crossunder accettano due espressioni-serie: - (crossover (feature close) (indicator sma 20)) — corretto. - (crossover (sma close 20) (sma close 50)) — ERRATO (sma non è verbo). +Operatori logici: + {{"op": "and", "args": [, , ...]}} // >=2 nodi + {{"op": "or", "args": [, , ...]}} // >=2 nodi + {{"op": "not", "args": []}} // 1 nodo -Le regole sono valutate in ordine; la prima che matcha vince per ogni timestamp. -La default action se nessuna regola matcha è 'flat'. +Comparatori (ritornano boolean series): + {{"op": "gt", "args": [, ]}} // a > b + {{"op": "lt", "args": [, ]}} // a < b + {{"op": "eq", "args": [, ]}} // a == b -Rispondi SOLO con la S-expression in un fence ```lisp ... ```, senza prosa, -senza spiegazioni. Esempio formato: +Crossover (eventi su 2 serie): + {{"op": "crossover", "args": [, ]}} + {{"op": "crossunder", "args": [, ]}} -```lisp -(strategy - (when (gt (indicator rsi 14) 70.0) (entry-short)) - (when (lt (indicator rsi 14) 30.0) (entry-long)) - (when (crossover (feature close) (indicator sma 50)) (entry-long))) +Leaf - indicatori (calcolati su close): + {{"kind": "indicator", "name": "sma", "params": []}} + {{"kind": "indicator", "name": "rsi", "params": []}} + {{"kind": "indicator", "name": "atr", "params": []}} + {{"kind": "indicator", "name": "realized_vol", "params": []}} + {{"kind": "indicator", "name": "macd", "params": [, , ]}} + // 0-3 numeri (tutti opzionali con default 12, 26, 9) + +Leaf - feature OHLCV: + {{"kind": "feature", "name": "open|high|low|close|volume"}} + +Leaf - letterale numerico: + {{"kind": "literal", "value": 70.0}} + +VINCOLI +- Gli indicator NON sono annidabili: 'params' accetta solo numeri, mai altri nodi. +- Le regole sono valutate in ordine; la prima che matcha vince per ogni timestamp. +- Default action se nessuna regola matcha = flat. +- 'op' e 'kind' sono mutuamente esclusivi sullo stesso nodo. + +Rispondi SOLO con il fence ```json...``` contenente l'oggetto strategy. +Esempio: + +```json +{{ + "rules": [ + {{ + "condition": {{"op": "gt", "args": [ + {{"kind": "indicator", "name": "rsi", "params": [14]}}, + {{"kind": "literal", "value": 70.0}} + ]}}, + "action": "entry-short" + }}, + {{ + "condition": {{"op": "lt", "args": [ + {{"kind": "indicator", "name": "rsi", "params": [14]}}, + {{"kind": "literal", "value": 30.0}} + ]}}, + "action": "entry-long" + }} + ] +}} ``` """ @@ -79,7 +113,7 @@ USER_TEMPLATE = """\ Mercato: {symbol} timeframe {timeframe}, {n_bars} barre osservate. Statistiche return: mean={return_mean:.5f}, std={return_std:.5f}, \ skew={skew:.3f}, kurt={kurtosis:.3f}. -Regime volatilità: {volatility_regime}. +Regime volatilità : {volatility_regime}. Feature accessibili dal tuo genoma: {feature_access}. Lookback massimo che puoi usare nel ragionamento: {lookback_window} barre. @@ -88,19 +122,57 @@ Genera una strategia che cerchi anomalie sfruttabili in questo regime. """ -_SEXP_FENCE_RE = re.compile( - r"```(?:lisp|scheme|sexp)?\s*(\(strategy[\s\S]*?\))\s*```", +_JSON_FENCE_RE = re.compile( + r"```(?:json)?\s*(\{[\s\S]*\})\s*```", re.MULTILINE, ) -def _extract_sexp(text: str) -> str | None: - m = _SEXP_FENCE_RE.search(text) +def _balance_braces(s: str) -> str | None: + """Ritorna il prefix di ``s`` che chiude la prima ``{`` con bilanciamento. + + Usato come fallback quando l'LLM ritorna JSON top-level senza fence ma + seguito da prosa: troviamo dove finisce il primo oggetto e tagliamo. + """ + if not s.startswith("{"): + return None + depth = 0 + in_string = False + escape = False + for i, ch in enumerate(s): + if in_string: + if escape: + escape = False + elif ch == "\\": + escape = True + elif ch == '"': + in_string = False + continue + if ch == '"': + in_string = True + elif ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + return s[: i + 1] + return None + + +def _extract_json(text: str) -> str | None: + """Estrai un oggetto JSON dal testo del completion. + + Strategie di estrazione, in ordine: + 1. Fence ```json...``` (greedy: cattura fino all'ultimo ``}`` prima della + chiusura del fence). + 2. Testo che inizia direttamente con ``{`` (dopo strip), bilanciato a + livello di parentesi graffe. + """ + m = _JSON_FENCE_RE.search(text) if m: return m.group(1) - if text.strip().startswith("(strategy"): - return text.strip() - return None + stripped = text.strip() + return _balance_braces(stripped) class HypothesisAgent: @@ -131,16 +203,16 @@ class HypothesisAgent: completion = self._llm.complete(genome, system=system, user=user) - sexp = _extract_sexp(completion.text) - if sexp is None: + payload = _extract_json(completion.text) + if payload is None: return HypothesisProposal( strategy=None, raw_text=completion.text, completion=completion, - parse_error="no s-expression found in output", + parse_error="no JSON object found in output", ) try: - ast = parse_strategy(sexp) + ast = parse_strategy(payload) validate_strategy(ast) return HypothesisProposal( strategy=ast, diff --git a/src/multi_swarm/protocol/__init__.py b/src/multi_swarm/protocol/__init__.py index e69de29..2fbc422 100644 --- a/src/multi_swarm/protocol/__init__.py +++ b/src/multi_swarm/protocol/__init__.py @@ -0,0 +1,30 @@ +"""Protocol layer: JSON-based strategy grammar + parser + validator + compiler.""" + +from .compiler import compile_strategy +from .parser import ( + FeatureNode, + IndicatorNode, + LiteralNode, + Node, + OpNode, + ParseError, + Rule, + Strategy, + parse_strategy, +) +from .validator import ValidationError, validate_strategy + +__all__ = [ + "FeatureNode", + "IndicatorNode", + "LiteralNode", + "Node", + "OpNode", + "ParseError", + "Rule", + "Strategy", + "ValidationError", + "compile_strategy", + "parse_strategy", + "validate_strategy", +] diff --git a/src/multi_swarm/protocol/compiler.py b/src/multi_swarm/protocol/compiler.py index 9deefdd..7486fe0 100644 --- a/src/multi_swarm/protocol/compiler.py +++ b/src/multi_swarm/protocol/compiler.py @@ -12,9 +12,9 @@ Design notes a different concrete signature (``(df, length)`` vs ``(df, fast, slow)``); modelling that under ``mypy --strict`` would require a ``Protocol`` per arity, which is overkill for the Phase 1 indicator subset. -* Numeric leaves coming out of :mod:`sexpdata` arrive as ``int`` / ``float`` - / ``str``; we widen via :func:`_to_series` to broadcast them along the - DataFrame index for arithmetic comparisons. +* I parametri di un :class:`IndicatorNode` sono sempre ``float``; cast a + ``int`` per indicatori con argomenti tipo "length" è deferito alle helper + (``_ind_sma``, ecc.) attraverso ``int(...)``. """ from __future__ import annotations @@ -26,7 +26,14 @@ import numpy as np import pandas as pd # type: ignore[import-untyped] from ..backtest.orders import Side -from .parser import Node, Strategy +from .parser import ( + FeatureNode, + IndicatorNode, + LiteralNode, + Node, + OpNode, + Strategy, +) def _sma(s: pd.Series, length: int) -> pd.Series: @@ -61,27 +68,30 @@ def _realized_vol(s: pd.Series, window: int) -> pd.Series: return returns.rolling(window, min_periods=1).std() * np.sqrt(window) -def _ind_sma(df: pd.DataFrame, length: int) -> pd.Series: - return _sma(df["close"], length) +def _ind_sma(df: pd.DataFrame, length: float) -> pd.Series: + return _sma(df["close"], int(length)) -def _ind_rsi(df: pd.DataFrame, length: int) -> pd.Series: - return _rsi(df["close"], length) +def _ind_rsi(df: pd.DataFrame, length: float) -> pd.Series: + return _rsi(df["close"], int(length)) -def _ind_atr(df: pd.DataFrame, length: int) -> pd.Series: - return _atr(df, length) +def _ind_atr(df: pd.DataFrame, length: float) -> pd.Series: + return _atr(df, int(length)) -def _ind_realized_vol(df: pd.DataFrame, window: int) -> pd.Series: - return _realized_vol(df["close"], window) +def _ind_realized_vol(df: pd.DataFrame, window: float) -> pd.Series: + return _realized_vol(df["close"], int(window)) def _ind_macd( - df: pd.DataFrame, fast: int = 12, slow: int = 26, signal: int = 9, + df: pd.DataFrame, + fast: float = 12, + slow: float = 26, + signal: float = 9, ) -> pd.Series: - macd_line = _sma(df["close"], fast) - _sma(df["close"], slow) - signal_line = _sma(macd_line, signal) + macd_line = _sma(df["close"], int(fast)) - _sma(df["close"], int(slow)) + signal_line = _sma(macd_line, int(signal)) return macd_line - signal_line @@ -98,16 +108,9 @@ INDICATOR_FNS: dict[str, Any] = { } -def _to_series(value: object, df: pd.DataFrame) -> pd.Series: +def _to_series(value: float, df: pd.DataFrame) -> pd.Series: """Broadcast a numeric literal across the DataFrame index.""" - return pd.Series(float(value), index=df.index) # type: ignore[arg-type] - - -def _eval_arg(arg: Any, df: pd.DataFrame) -> pd.Series: - """Evaluate either a child Node or a scalar literal into a Series.""" - if isinstance(arg, Node): - return _eval_node(arg, df) - return _to_series(arg, df) + return pd.Series(float(value), index=df.index) def _compare_with_nan(result: pd.Series, a: pd.Series, b: pd.Series) -> pd.Series: @@ -124,71 +127,60 @@ def _compare_with_nan(result: pd.Series, a: pd.Series, b: pd.Series) -> pd.Serie return out -def _eval_bool_arg(arg: Any, df: pd.DataFrame) -> pd.Series: - """Evaluate either a child Node (bool series) or a literal into a bool Series.""" - if isinstance(arg, Node): - return _eval_node(arg, df).fillna(False).astype(bool) - return pd.Series(bool(arg), index=df.index) +def _eval_bool_arg(node: Node, df: pd.DataFrame) -> pd.Series: + """Evaluate a child Node into a boolean Series (NaN -> False).""" + return _eval_node(node, df).fillna(False).astype(bool) def _eval_node(node: Node, df: pd.DataFrame) -> pd.Series: - kind = node.kind + if isinstance(node, FeatureNode): + return df[node.name] - if kind == "feature": - feat = node.args[0] - feat_name = feat.kind if isinstance(feat, Node) else str(feat) - return df[feat_name] - - if kind == "indicator": - name_node = node.args[0] - ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node) - params = [a for a in node.args[1:] if not isinstance(a, Node)] - fn = INDICATOR_FNS[ind_name] - result: pd.Series = fn(df, *params) + if isinstance(node, IndicatorNode): + fn = INDICATOR_FNS[node.name] + result: pd.Series = fn(df, *node.params) return result - if kind == "gt": - a = _eval_arg(node.args[0], df) - b = _eval_arg(node.args[1], df) - return _compare_with_nan(a > b, a, b) + if isinstance(node, LiteralNode): + return _to_series(node.value, df) - if kind == "lt": - a = _eval_arg(node.args[0], df) - b = _eval_arg(node.args[1], df) - return _compare_with_nan(a < b, a, b) + if isinstance(node, OpNode): + op = node.op + if op == "gt": + a = _eval_node(node.args[0], df) + b = _eval_node(node.args[1], df) + return _compare_with_nan(a > b, a, b) + if op == "lt": + a = _eval_node(node.args[0], df) + b = _eval_node(node.args[1], df) + return _compare_with_nan(a < b, a, b) + if op == "eq": + a = _eval_node(node.args[0], df) + b = _eval_node(node.args[1], df) + return _compare_with_nan(a == b, a, b) + if op == "and": + result = pd.Series(True, index=df.index) + for a in node.args: + result &= _eval_bool_arg(a, df) + return result + if op == "or": + result = pd.Series(False, index=df.index) + for a in node.args: + result |= _eval_bool_arg(a, df) + return result + if op == "not": + return ~_eval_bool_arg(node.args[0], df) + if op == "crossover": + a = _eval_node(node.args[0], df) + b = _eval_node(node.args[1], df) + return ((a > b) & (a.shift() <= b.shift())).fillna(False).astype(bool) + if op == "crossunder": + a = _eval_node(node.args[0], df) + b = _eval_node(node.args[1], df) + return ((a < b) & (a.shift() >= b.shift())).fillna(False).astype(bool) + raise RuntimeError(f"unsupported op in compiler: {op}") - if kind == "eq": - a = _eval_arg(node.args[0], df) - b = _eval_arg(node.args[1], df) - return _compare_with_nan(a == b, a, b) - - if kind == "and": - result = pd.Series(True, index=df.index) - for a in node.args: - result &= _eval_bool_arg(a, df) - return result - - if kind == "or": - result = pd.Series(False, index=df.index) - for a in node.args: - result |= _eval_bool_arg(a, df) - return result - - if kind == "not": - s = _eval_bool_arg(node.args[0], df) - return ~s - - if kind == "crossover": - a = _eval_arg(node.args[0], df) - b = _eval_arg(node.args[1], df) - return ((a > b) & (a.shift() <= b.shift())).fillna(False).astype(bool) - - if kind == "crossunder": - a = _eval_arg(node.args[0], df) - b = _eval_arg(node.args[1], df) - return ((a < b) & (a.shift() >= b.shift())).fillna(False).astype(bool) - - raise RuntimeError(f"unsupported node in compiler: {kind}") + raise RuntimeError(f"unsupported node type in compiler: {type(node).__name__}") _ACTION_TO_SIDE: dict[str, Side] = { @@ -199,10 +191,6 @@ _ACTION_TO_SIDE: dict[str, Side] = { } -def _action_to_side(action: Node) -> Side: - return _ACTION_TO_SIDE[action.kind] - - def compile_strategy(strategy: Strategy) -> Callable[[pd.DataFrame], pd.Series]: """Compile a :class:`Strategy` AST into a ``df -> Series[Side]`` callable. @@ -218,7 +206,7 @@ def compile_strategy(strategy: Strategy) -> Callable[[pd.DataFrame], pd.Series]: any_rule_seen = pd.Series(False, index=df.index) for rule in strategy.rules: match = _eval_node(rule.condition, df) - target = _action_to_side(rule.action) + target = _ACTION_TO_SIDE[rule.action] valid = ~_isna_series(match) any_rule_seen |= valid match_bool = match.where(valid, False).astype(bool) diff --git a/src/multi_swarm/protocol/grammar.py b/src/multi_swarm/protocol/grammar.py index c3d7422..52ff752 100644 --- a/src/multi_swarm/protocol/grammar.py +++ b/src/multi_swarm/protocol/grammar.py @@ -1,26 +1,27 @@ from __future__ import annotations -VERBS: frozenset[str] = frozenset( - { - "entry-long", - "entry-short", - "exit", - "flat", - "when", - "and", - "or", - "not", - "gt", - "lt", - "eq", - "feature", - "indicator", - "crossover", - "crossunder", - } +# Grammatica JSON Schema (Phase 1, post S-expression refactor). +# +# Distinzione strutturale: +# * Nodi OPERATORE -> dict con chiave ``"op"`` (logici, comparatori, crossover) +# * Nodi LEAF -> dict con chiave ``"kind"`` (indicator, feature, literal) +# ``op`` e ``kind`` sono mutuamente esclusivi sullo stesso nodo. + +LOGICAL_OPS: frozenset[str] = frozenset({"and", "or", "not"}) +COMPARATOR_OPS: frozenset[str] = frozenset({"gt", "lt", "eq"}) +CROSSOVER_OPS: frozenset[str] = frozenset({"crossover", "crossunder"}) + +ACTION_VALUES: frozenset[str] = frozenset( + {"entry-long", "entry-short", "exit", "flat"} +) +KIND_VALUES: frozenset[str] = frozenset({"indicator", "feature", "literal"}) + +KNOWN_INDICATORS: frozenset[str] = frozenset( + {"sma", "rsi", "atr", "macd", "realized_vol"} +) +KNOWN_FEATURES: frozenset[str] = frozenset( + {"open", "high", "low", "close", "volume"} ) -ACTION_VERBS: frozenset[str] = frozenset({"entry-long", "entry-short", "exit", "flat"}) -LOGICAL_VERBS: frozenset[str] = frozenset({"and", "or", "not"}) -COMPARATOR_VERBS: frozenset[str] = frozenset({"gt", "lt", "eq"}) -DATA_VERBS: frozenset[str] = frozenset({"feature", "indicator", "crossover", "crossunder"}) +# Convenience union (utile a validator / parser). +ALL_OPS: frozenset[str] = LOGICAL_OPS | COMPARATOR_OPS | CROSSOVER_OPS diff --git a/src/multi_swarm/protocol/parser.py b/src/multi_swarm/protocol/parser.py index 0d8cb59..31eb609 100644 --- a/src/multi_swarm/protocol/parser.py +++ b/src/multi_swarm/protocol/parser.py @@ -1,96 +1,203 @@ +"""JSON-based parser per la strategia di trading (Phase 1). + +L'AST è una piccola gerarchia di dataclass: + +* :class:`Strategy` è il top-level (lista di :class:`Rule`). +* :class:`Rule` accoppia una condizione (Node) ad un'azione (str). +* :class:`Node` è un'unione: nodi operatore (:class:`OpNode`) e nodi leaf + (:class:`IndicatorNode`, :class:`FeatureNode`, :class:`LiteralNode`). + +Convenzione di shape sui dict in input: + +* Nodi operatore: ``{"op": "", "args": [, ...]}``. +* Nodi indicator: ``{"kind": "indicator", "name": "", "params": [, ...]}``. +* Nodi feature: ``{"kind": "feature", "name": ""}``. +* Nodi literal: ``{"kind": "literal", "value": }``. +""" + from __future__ import annotations +import json from dataclasses import dataclass, field from typing import Any -import sexpdata # type: ignore[import-untyped] - -from .grammar import ACTION_VERBS, VERBS +from .grammar import ( + ACTION_VALUES, + ALL_OPS, +) class ParseError(Exception): - """Raised when an S-expression strategy cannot be parsed.""" + """Raised when a JSON strategy cannot be parsed into a valid AST.""" + + +# --------------------------------------------------------------------------- +# Dataclass AST +# --------------------------------------------------------------------------- @dataclass -class Node: - kind: str - args: list[Any] = field(default_factory=list) +class OpNode: + """Operator node: logical / comparator / crossover.""" + + op: str + args: list[Node] = field(default_factory=list) + + +@dataclass +class IndicatorNode: + """Leaf: indicatore tecnico calcolato sul dataframe OHLCV.""" + + name: str + params: list[float] = field(default_factory=list) + + +@dataclass +class FeatureNode: + """Leaf: colonna OHLCV (open/high/low/close/volume).""" + + name: str + + +@dataclass +class LiteralNode: + """Leaf: costante numerica.""" + + value: float + + +Node = OpNode | IndicatorNode | FeatureNode | LiteralNode @dataclass class Rule: - kind: str # always "when" condition: Node - action: Node + action: str @dataclass class Strategy: - kind: str # always "strategy" rules: list[Rule] -def _to_node(token: Any) -> Node | float | int | str: - """Convert a sexpdata token tree into a Node (or scalar leaf).""" - if isinstance(token, sexpdata.Symbol): - name = str(token.value()) - # Bare symbols inside expressions (e.g. `rsi` in (indicator rsi 14)) - # are kept as Node-with-no-args so callers can introspect uniformly. - return Node(kind=name, args=[]) - if isinstance(token, list): - if not token: - raise ParseError("Empty s-expression") - head = token[0] - if not isinstance(head, sexpdata.Symbol): - raise ParseError(f"Non-symbol head: {head!r}") - name = str(head.value()) - if name not in VERBS: - raise ParseError(f"Unknown verb: {name}") - return Node(kind=name, args=[_to_node(arg) for arg in token[1:]]) - # numeric / string literals pass through unchanged - return token # type: ignore[no-any-return] +# --------------------------------------------------------------------------- +# Conversione dict -> Node +# --------------------------------------------------------------------------- + + +def _to_node(obj: Any) -> Node: + if not isinstance(obj, dict): + raise ParseError(f"Node must be a JSON object, got {type(obj).__name__}") + + has_op = "op" in obj + has_kind = "kind" in obj + if has_op and has_kind: + raise ParseError( + "Node cannot define both 'op' and 'kind' (mutually exclusive)" + ) + if not has_op and not has_kind: + raise ParseError("Node must define either 'op' or 'kind'") + + if has_op: + op = obj["op"] + if not isinstance(op, str): + raise ParseError(f"'op' must be a string, got {type(op).__name__}") + if op not in ALL_OPS: + raise ParseError(f"Unknown op: {op!r}") + raw_args = obj.get("args") + if not isinstance(raw_args, list): + raise ParseError(f"Operator '{op}' missing 'args' list") + args = [_to_node(a) for a in raw_args] + return OpNode(op=op, args=args) + + # leaf node + kind = obj["kind"] + if not isinstance(kind, str): + raise ParseError(f"'kind' must be a string, got {type(kind).__name__}") + + if kind == "indicator": + name = obj.get("name") + if not isinstance(name, str): + raise ParseError("indicator node requires string 'name'") + raw_params = obj.get("params", []) + if not isinstance(raw_params, list): + raise ParseError("indicator 'params' must be a list") + params: list[float] = [] + for p in raw_params: + if isinstance(p, bool) or not isinstance(p, (int, float)): + raise ParseError( + f"indicator '{name}' params accept only numbers, got {p!r}" + ) + params.append(float(p)) + return IndicatorNode(name=name, params=params) + + if kind == "feature": + name = obj.get("name") + if not isinstance(name, str): + raise ParseError("feature node requires string 'name'") + return FeatureNode(name=name) + + if kind == "literal": + if "value" not in obj: + raise ParseError("literal node requires 'value'") + value = obj["value"] + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise ParseError(f"literal value must be numeric, got {value!r}") + return LiteralNode(value=float(value)) + + raise ParseError(f"Unknown leaf kind: {kind!r}") + + +# --------------------------------------------------------------------------- +# Top-level parser +# --------------------------------------------------------------------------- def parse_strategy(src: str) -> Strategy: - """Parse an S-expression strategy string into a Strategy AST. + """Parse a JSON strategy string into a :class:`Strategy` AST. - The grammar is documented in :mod:`multi_swarm.protocol.grammar` and is - intentionally tiny (15 verbs). We delegate raw S-expr lexing to - :mod:`sexpdata`, then validate the verb set ourselves. + Lo schema atteso è:: + + { + "rules": [ + {"condition": , "action": ""}, + ... + ] + } + + Raise :class:`ParseError` su JSON malformato o struttura inattesa. """ try: - parsed = sexpdata.loads(src) - except Exception as e: # sexpdata raises various exception types - raise ParseError(f"sexp parse error: {e}") from e + parsed = json.loads(src) + except json.JSONDecodeError as e: + raise ParseError(f"invalid JSON: {e}") from e - if not isinstance(parsed, list) or not parsed: - raise ParseError("Top-level must be (strategy ...)") - head = parsed[0] - if not isinstance(head, sexpdata.Symbol) or str(head.value()) != "strategy": - raise ParseError("Top-level must start with 'strategy'") - - raw_rules = parsed[1:] + if not isinstance(parsed, dict): + raise ParseError("Top-level must be a JSON object with 'rules'") + if "rules" not in parsed: + raise ParseError("Top-level object must contain 'rules' key") + raw_rules = parsed["rules"] + if not isinstance(raw_rules, list): + raise ParseError("'rules' must be a list") if not raw_rules: raise ParseError("Strategy must contain at least one rule") rules: list[Rule] = [] for raw in raw_rules: - if not isinstance(raw, list) or len(raw) != 3: - raise ParseError(f"Rule must be (when ): {raw!r}") - head_r = raw[0] - if not isinstance(head_r, sexpdata.Symbol) or str(head_r.value()) != "when": - raise ParseError(f"Rule must start with 'when': {raw!r}") - cond = _to_node(raw[1]) - action = _to_node(raw[2]) - if not isinstance(cond, Node): - raise ParseError(f"Condition must be a node: {cond!r}") - if not isinstance(action, Node): - raise ParseError(f"Action must be a node: {action!r}") - if action.kind not in ACTION_VERBS: + if not isinstance(raw, dict): + raise ParseError(f"Rule must be a JSON object, got {raw!r}") + if "condition" not in raw or "action" not in raw: raise ParseError( - f"Action must be one of {sorted(ACTION_VERBS)}, got {action.kind!r}" + f"Rule must contain 'condition' and 'action' keys: {raw!r}" ) - rules.append(Rule(kind="when", condition=cond, action=action)) + action = raw["action"] + if not isinstance(action, str): + raise ParseError(f"action must be a string, got {action!r}") + if action not in ACTION_VALUES: + raise ParseError( + f"action must be one of {sorted(ACTION_VALUES)}, got {action!r}" + ) + cond = _to_node(raw["condition"]) + rules.append(Rule(condition=cond, action=action)) - return Strategy(kind="strategy", rules=rules) + return Strategy(rules=rules) diff --git a/src/multi_swarm/protocol/validator.py b/src/multi_swarm/protocol/validator.py index 0e7ade7..439736c 100644 --- a/src/multi_swarm/protocol/validator.py +++ b/src/multi_swarm/protocol/validator.py @@ -1,20 +1,41 @@ +"""Semantic validation for the JSON-based strategy AST. + +Il parser garantisce già shape sintattica (op vs kind, struttura args/params, +tipi base). Qui si controllano vincoli semantici di Phase 1: + +* Arity di operatori logici / comparatori / crossover. +* Whitelist indicator + arity dei params. +* Whitelist feature. +* Niente nesting di indicator (params puramente numerici, garantito già dal + parser ma ricontrollato esplicitamente per chiarezza). +""" + from __future__ import annotations -from .grammar import COMPARATOR_VERBS, LOGICAL_VERBS -from .parser import Node, Strategy - -KNOWN_INDICATORS: frozenset[str] = frozenset({"sma", "rsi", "atr", "macd", "realized_vol"}) -KNOWN_FEATURES: frozenset[str] = frozenset({"open", "high", "low", "close", "volume"}) +from .grammar import ( + COMPARATOR_OPS, + CROSSOVER_OPS, + KNOWN_FEATURES, + KNOWN_INDICATORS, + LOGICAL_OPS, +) +from .parser import ( + FeatureNode, + IndicatorNode, + LiteralNode, + Node, + OpNode, + Strategy, +) # Numero di parametri numerici accettati dopo il nome dell'indicatore. -# La tupla (min, max) include solo i numeri (gli argomenti di tipo Node sono -# proibiti dal compiler - gli indicatori non sono annidabili in Phase 1). +# (min, max) sui soli numeri. Indicatori non sono annidabili in Phase 1. INDICATOR_ARITY: dict[str, tuple[int, int]] = { "sma": (1, 1), # length "rsi": (1, 1), # length "atr": (1, 1), # length "realized_vol": (1, 1), # window - "macd": (0, 3), # fast, slow, signal (tutti opzionali con default) + "macd": (0, 3), # fast, slow, signal (tutti opzionali) } @@ -23,77 +44,66 @@ class ValidationError(Exception): def validate_strategy(strategy: Strategy) -> None: - """Check semantic constraints on a parsed Strategy AST. - - The parser already enforces verb-set membership; this pass adds: - * arity checks for logical/comparator/data verbs, - * known-indicator / known-feature whitelists. - """ + """Walk every rule of the strategy and assert semantic constraints.""" for rule in strategy.rules: - _validate_node(rule.condition, _expect_bool=True) + _validate_node(rule.condition) -def _validate_node(node: Node, _expect_bool: bool) -> None: - if node.kind in LOGICAL_VERBS: - if node.kind == "not": - if len(node.args) != 1: - raise ValidationError(f"'not' needs 1 arg, got {len(node.args)}") - arg = node.args[0] - if isinstance(arg, Node): - _validate_node(arg, _expect_bool=True) +def _validate_node(node: Node) -> None: + if isinstance(node, OpNode): + _validate_op(node) + return + if isinstance(node, IndicatorNode): + _validate_indicator(node) + return + if isinstance(node, FeatureNode): + if node.name not in KNOWN_FEATURES: + raise ValidationError(f"unknown feature: {node.name}") + return + if isinstance(node, LiteralNode): + # parser ha già validato il tipo numerico + return + raise ValidationError(f"unexpected node type: {type(node).__name__}") + + +def _validate_op(node: OpNode) -> None: + op = node.op + n = len(node.args) + + if op in LOGICAL_OPS: + if op == "not": + if n != 1: + raise ValidationError(f"'not' needs 1 arg, got {n}") else: - if len(node.args) < 2: - raise ValidationError(f"'{node.kind}' needs >=2 args") - for a in node.args: - if isinstance(a, Node): - _validate_node(a, _expect_bool=True) - return - - if node.kind in COMPARATOR_VERBS: - if len(node.args) != 2: - raise ValidationError(f"'{node.kind}' needs 2 args, got {len(node.args)}") + if n < 2: + raise ValidationError(f"'{op}' needs >=2 args, got {n}") for a in node.args: - if isinstance(a, Node): - _validate_node(a, _expect_bool=False) + _validate_node(a) return - if node.kind in {"crossover", "crossunder"}: - if len(node.args) != 2: - raise ValidationError(f"'{node.kind}' needs 2 args") + if op in COMPARATOR_OPS: + if n != 2: + raise ValidationError(f"'{op}' needs 2 args, got {n}") for a in node.args: - if isinstance(a, Node): - _validate_node(a, _expect_bool=False) + _validate_node(a) return - if node.kind == "indicator": - if len(node.args) < 1: - raise ValidationError("'indicator' needs >=1 args (name [, params...])") - name_node = node.args[0] - ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node) - if ind_name not in KNOWN_INDICATORS: - raise ValidationError(f"unknown indicator: {ind_name}") - # Gli indicatori non accettano Node come params (no-nesting in Phase 1). - for a in node.args[1:]: - if isinstance(a, Node): - raise ValidationError( - f"indicator '{ind_name}' does not accept nested expressions; " - f"only numeric literals (got node {a.kind})" - ) - n_params = len(node.args) - 1 - min_p, max_p = INDICATOR_ARITY[ind_name] - if not (min_p <= n_params <= max_p): - raise ValidationError( - f"indicator '{ind_name}' arity {n_params} out of [{min_p},{max_p}]" - ) + if op in CROSSOVER_OPS: + if n != 2: + raise ValidationError(f"'{op}' needs 2 args, got {n}") + for a in node.args: + _validate_node(a) return - if node.kind == "feature": - if len(node.args) != 1: - raise ValidationError("'feature' needs 1 arg") - feat_node = node.args[0] - feat_name = feat_node.kind if isinstance(feat_node, Node) else str(feat_node) - if feat_name not in KNOWN_FEATURES: - raise ValidationError(f"unknown feature: {feat_name}") - return + raise ValidationError(f"unexpected op in expression: {op}") - raise ValidationError(f"unexpected node kind in expression: {node.kind}") + +def _validate_indicator(node: IndicatorNode) -> None: + if node.name not in KNOWN_INDICATORS: + raise ValidationError(f"unknown indicator: {node.name}") + n_params = len(node.params) + min_p, max_p = INDICATOR_ARITY[node.name] + if not (min_p <= n_params <= max_p): + raise ValidationError( + f"indicator '{node.name}' arity {n_params} out of [{min_p},{max_p}]" + ) diff --git a/tests/integration/test_e2e_minimal_run.py b/tests/integration/test_e2e_minimal_run.py index b01aec0..bca4765 100644 --- a/tests/integration/test_e2e_minimal_run.py +++ b/tests/integration/test_e2e_minimal_run.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import numpy as np @@ -26,16 +27,40 @@ def synthetic_ohlcv(): ) +_STRATEGY_PAYLOAD = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + }, + { + "condition": { + "op": "lt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 30.0}, + ], + }, + "action": "entry-long", + }, + ] + } +) + + @pytest.fixture def fake_llm(mocker): - """LLM mock che ritorna sempre una strategia valida.""" + """LLM mock che ritorna sempre una strategia JSON valida.""" fake = mocker.MagicMock() fake.complete.return_value = CompletionResult( - text=( - "```lisp\n(strategy " - "(when (gt (indicator rsi 14) 70.0) (entry-short)) " - "(when (lt (indicator rsi 14) 30.0) (entry-long)))\n```" - ), + text="```json\n" + _STRATEGY_PAYLOAD + "\n```", input_tokens=200, output_tokens=80, tier=ModelTier.C, diff --git a/tests/unit/test_adversarial.py b/tests/unit/test_adversarial.py index feb94a0..5d7591a 100644 --- a/tests/unit/test_adversarial.py +++ b/tests/unit/test_adversarial.py @@ -1,3 +1,5 @@ +import json + import numpy as np import pandas as pd import pytest @@ -23,7 +25,22 @@ def ohlcv() -> pd.DataFrame: def test_degenerate_always_long_flagged(ohlcv: pd.DataFrame) -> None: - src = "(strategy (when (gt (feature close) -1e9) (entry-long)))" + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "feature", "name": "close"}, + {"kind": "literal", "value": -1e9}, + ], + }, + "action": "entry-long", + } + ] + } + ) ast = parse_strategy(src) agent = AdversarialAgent() report = agent.review(ast, ohlcv) @@ -32,10 +49,31 @@ def test_degenerate_always_long_flagged(ohlcv: pd.DataFrame) -> None: def test_no_findings_on_reasonable_strategy(ohlcv: pd.DataFrame) -> None: - src = ( - "(strategy " - "(when (gt (indicator rsi 14) 70.0) (entry-short)) " - "(when (lt (indicator rsi 14) 30.0) (entry-long)))" + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + }, + { + "condition": { + "op": "lt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 30.0}, + ], + }, + "action": "entry-long", + }, + ] + } ) ast = parse_strategy(src) agent = AdversarialAgent() @@ -45,7 +83,22 @@ def test_no_findings_on_reasonable_strategy(ohlcv: pd.DataFrame) -> None: def test_zero_trade_strategy_flagged(ohlcv: pd.DataFrame) -> None: - src = "(strategy (when (gt (feature close) 1e9) (entry-long)))" + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "feature", "name": "close"}, + {"kind": "literal", "value": 1e9}, + ], + }, + "action": "entry-long", + } + ] + } + ) ast = parse_strategy(src) agent = AdversarialAgent() report = agent.review(ast, ohlcv) diff --git a/tests/unit/test_falsification.py b/tests/unit/test_falsification.py index c7fb256..3d59173 100644 --- a/tests/unit/test_falsification.py +++ b/tests/unit/test_falsification.py @@ -1,3 +1,5 @@ +import json + import numpy as np import pandas as pd import pytest @@ -23,10 +25,31 @@ def trending_ohlcv() -> pd.DataFrame: def test_falsification_returns_report(trending_ohlcv: pd.DataFrame) -> None: - src = ( - "(strategy " - "(when (gt (indicator rsi 14) 70.0) (entry-short)) " - "(when (lt (indicator rsi 14) 30.0) (entry-long)))" + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + }, + { + "condition": { + "op": "lt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 30.0}, + ], + }, + "action": "entry-long", + }, + ] + } ) ast = parse_strategy(src) agent = FalsificationAgent(fees_bp=5.0, n_trials_dsr=20) @@ -40,7 +63,22 @@ def test_falsification_returns_report(trending_ohlcv: pd.DataFrame) -> None: def test_falsification_zero_trades_returns_zero_metrics(trending_ohlcv: pd.DataFrame) -> None: - src = "(strategy (when (gt (feature close) 1e9) (entry-long)))" + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "feature", "name": "close"}, + {"kind": "literal", "value": 1e9}, + ], + }, + "action": "entry-long", + } + ] + } + ) ast = parse_strategy(src) agent = FalsificationAgent(fees_bp=5.0, n_trials_dsr=20) report = agent.evaluate(ast, trending_ohlcv) diff --git a/tests/unit/test_hypothesis_agent.py b/tests/unit/test_hypothesis_agent.py index 632a316..c37dd54 100644 --- a/tests/unit/test_hypothesis_agent.py +++ b/tests/unit/test_hypothesis_agent.py @@ -1,3 +1,5 @@ +import json + from multi_swarm.agents.hypothesis import HypothesisAgent, MarketSummary from multi_swarm.genome.hypothesis import HypothesisAgentGenome, ModelTier from multi_swarm.llm.client import CompletionResult @@ -16,16 +18,26 @@ def make_summary() -> MarketSummary: ) -def test_hypothesis_agent_calls_llm_and_parses(mocker): # type: ignore[no-untyped-def] - fake_llm = mocker.MagicMock() - fake_llm.complete.return_value = CompletionResult( - text="(strategy (when (gt (indicator rsi 14) 70.0) (entry-short)))", - input_tokens=200, - output_tokens=80, - tier=ModelTier.C, - model="qwen", - ) - g = HypothesisAgentGenome( +VALID_STRATEGY_JSON = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + } + ] + } +) + + +def make_genome() -> HypothesisAgentGenome: + return HypothesisAgentGenome( system_prompt="Pensa come un fisico.", feature_access=["close"], temperature=0.9, @@ -34,10 +46,20 @@ def test_hypothesis_agent_calls_llm_and_parses(mocker): # type: ignore[no-untyp lookback_window=200, cognitive_style="physicist", ) + + +def test_hypothesis_agent_calls_llm_and_parses(mocker): # type: ignore[no-untyped-def] + fake_llm = mocker.MagicMock() + fake_llm.complete.return_value = CompletionResult( + text=VALID_STRATEGY_JSON, + input_tokens=200, + output_tokens=80, + tier=ModelTier.C, + model="qwen", + ) agent = HypothesisAgent(llm=fake_llm) - proposal = agent.propose(g, make_summary()) + proposal = agent.propose(make_genome(), make_summary()) assert proposal.strategy is not None - assert proposal.raw_text.startswith("(strategy") assert proposal.completion.input_tokens == 200 fake_llm.complete.assert_called_once() @@ -45,49 +67,64 @@ def test_hypothesis_agent_calls_llm_and_parses(mocker): # type: ignore[no-untyp def test_hypothesis_agent_returns_none_on_parse_error(mocker): # type: ignore[no-untyped-def] fake_llm = mocker.MagicMock() fake_llm.complete.return_value = CompletionResult( - text="this is not s-expression", + text="this is not JSON", input_tokens=200, output_tokens=80, tier=ModelTier.C, model="qwen", ) - g = HypothesisAgentGenome( - system_prompt="x", - feature_access=["close"], - temperature=0.9, - top_p=0.95, - model_tier=ModelTier.C, - lookback_window=200, - cognitive_style="physicist", - ) agent = HypothesisAgent(llm=fake_llm) - proposal = agent.propose(g, make_summary()) + proposal = agent.propose(make_genome(), make_summary()) assert proposal.strategy is None assert proposal.parse_error is not None -def test_hypothesis_agent_extracts_sexp_from_markdown_fence(mocker): # type: ignore[no-untyped-def] +def test_hypothesis_agent_extracts_json_from_markdown_fence(mocker): # type: ignore[no-untyped-def] + fenced = ( + "Ecco la strategia:\n```json\n" + + VALID_STRATEGY_JSON + + "\n```\nFatta." + ) fake_llm = mocker.MagicMock() fake_llm.complete.return_value = CompletionResult( - text=( - "Ecco la strategia:\n```lisp\n" - "(strategy (when (lt (indicator rsi 14) 30.0) (entry-long)))\n" - "```\nFatta." - ), + text=fenced, input_tokens=200, output_tokens=80, tier=ModelTier.C, model="qwen", ) - g = HypothesisAgentGenome( - system_prompt="x", - feature_access=["close"], - temperature=0.9, - top_p=0.95, - model_tier=ModelTier.C, - lookback_window=200, - cognitive_style="physicist", + agent = HypothesisAgent(llm=fake_llm) + proposal = agent.propose(make_genome(), make_summary()) + assert proposal.strategy is not None + + +def test_hypothesis_agent_returns_error_on_invalid_strategy(mocker): # type: ignore[no-untyped-def] + bad = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "wibble", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + } + ] + } + ) + fake_llm = mocker.MagicMock() + fake_llm.complete.return_value = CompletionResult( + text=bad, + input_tokens=200, + output_tokens=80, + tier=ModelTier.C, + model="qwen", ) agent = HypothesisAgent(llm=fake_llm) - proposal = agent.propose(g, make_summary()) - assert proposal.strategy is not None + proposal = agent.propose(make_genome(), make_summary()) + assert proposal.strategy is None + assert proposal.parse_error is not None + assert "wibble" in proposal.parse_error or "unknown" in proposal.parse_error diff --git a/tests/unit/test_protocol_compiler.py b/tests/unit/test_protocol_compiler.py index c244e7d..80726f3 100644 --- a/tests/unit/test_protocol_compiler.py +++ b/tests/unit/test_protocol_compiler.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + import numpy as np import pandas as pd import pytest @@ -26,7 +28,22 @@ def ohlcv() -> pd.DataFrame: def test_compile_simple_long(ohlcv: pd.DataFrame) -> None: - src = "(strategy (when (lt (indicator rsi 14) 100.0) (entry-long)))" + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "lt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 100.0}, + ], + }, + "action": "entry-long", + } + ] + } + ) ast = parse_strategy(src) fn = compile_strategy(ast) signals = fn(ohlcv) @@ -35,7 +52,22 @@ def test_compile_simple_long(ohlcv: pd.DataFrame) -> None: def test_compile_no_match_is_flat(ohlcv: pd.DataFrame) -> None: - src = "(strategy (when (gt (indicator rsi 14) 1000.0) (entry-long)))" + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 1000.0}, + ], + }, + "action": "entry-long", + } + ] + } + ) ast = parse_strategy(src) fn = compile_strategy(ast) signals = fn(ohlcv) @@ -43,11 +75,32 @@ def test_compile_no_match_is_flat(ohlcv: pd.DataFrame) -> None: def test_compile_two_rules_priority(ohlcv: pd.DataFrame) -> None: - src = """ - (strategy - (when (gt (feature close) 110.0) (entry-long)) - (when (lt (feature close) 105.0) (entry-short))) - """ + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "feature", "name": "close"}, + {"kind": "literal", "value": 110.0}, + ], + }, + "action": "entry-long", + }, + { + "condition": { + "op": "lt", + "args": [ + {"kind": "feature", "name": "close"}, + {"kind": "literal", "value": 105.0}, + ], + }, + "action": "entry-short", + }, + ] + } + ) ast = parse_strategy(src) fn = compile_strategy(ast) signals = fn(ohlcv) diff --git a/tests/unit/test_protocol_parser.py b/tests/unit/test_protocol_parser.py index 2313b91..1005013 100644 --- a/tests/unit/test_protocol_parser.py +++ b/tests/unit/test_protocol_parser.py @@ -1,47 +1,198 @@ +import json + import pytest -from multi_swarm.protocol.grammar import VERBS -from multi_swarm.protocol.parser import ParseError, parse_strategy +from multi_swarm.protocol.grammar import ( + ACTION_VALUES, + ALL_OPS, + COMPARATOR_OPS, + CROSSOVER_OPS, + KIND_VALUES, + LOGICAL_OPS, +) +from multi_swarm.protocol.parser import ( + FeatureNode, + IndicatorNode, + LiteralNode, + OpNode, + ParseError, + parse_strategy, +) -def test_grammar_has_15_verbs(): - assert len(VERBS) == 15 +def test_grammar_constant_sets() -> None: + assert LOGICAL_OPS == {"and", "or", "not"} + assert COMPARATOR_OPS == {"gt", "lt", "eq"} + assert CROSSOVER_OPS == {"crossover", "crossunder"} + assert KIND_VALUES == {"indicator", "feature", "literal"} + assert ACTION_VALUES == {"entry-long", "entry-short", "exit", "flat"} + assert ALL_OPS == LOGICAL_OPS | COMPARATOR_OPS | CROSSOVER_OPS -def test_parse_simple_strategy(): - src = "(strategy (when (gt (indicator rsi 14) 70.0) (entry-short)))" +def test_parse_simple_strategy() -> None: + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + } + ] + } + ) ast = parse_strategy(src) - assert ast.kind == "strategy" assert len(ast.rules) == 1 rule = ast.rules[0] - assert rule.kind == "when" - assert rule.condition.kind == "gt" - assert rule.action.kind == "entry-short" + assert rule.action == "entry-short" + assert isinstance(rule.condition, OpNode) + assert rule.condition.op == "gt" + assert isinstance(rule.condition.args[0], IndicatorNode) + assert rule.condition.args[0].name == "rsi" + assert rule.condition.args[0].params == [14.0] + assert isinstance(rule.condition.args[1], LiteralNode) + assert rule.condition.args[1].value == 70.0 -def test_parse_multiple_rules(): - src = """ - (strategy - (when (gt (indicator rsi 14) 70.0) (entry-short)) - (when (lt (indicator rsi 14) 30.0) (entry-long))) - """ +def test_parse_multiple_rules() -> None: + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + "action": "entry-short", + }, + { + "condition": { + "op": "lt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 30.0}, + ], + }, + "action": "entry-long", + }, + ] + } + ) ast = parse_strategy(src) assert len(ast.rules) == 2 -def test_parse_unknown_verb_raises(): - src = "(strategy (when (frobnicate 1 2) (entry-long)))" - with pytest.raises(ParseError): +def test_parse_feature_leaf() -> None: + src = json.dumps( + { + "rules": [ + { + "condition": { + "op": "crossover", + "args": [ + {"kind": "feature", "name": "close"}, + {"kind": "indicator", "name": "sma", "params": [50]}, + ], + }, + "action": "entry-long", + } + ] + } + ) + ast = parse_strategy(src) + cond = ast.rules[0].condition + assert isinstance(cond, OpNode) and cond.op == "crossover" + assert isinstance(cond.args[0], FeatureNode) + assert cond.args[0].name == "close" + + +def test_parse_unknown_op_raises() -> None: + src = json.dumps( + { + "rules": [ + { + "condition": {"op": "frobnicate", "args": [1, 2]}, + "action": "entry-long", + } + ] + } + ) + with pytest.raises(ParseError, match="Unknown op"): parse_strategy(src) -def test_parse_malformed_raises(): - src = "(strategy (when" - with pytest.raises(ParseError): +def test_parse_invalid_action_raises() -> None: + src = json.dumps( + { + "rules": [ + { + "condition": {"kind": "literal", "value": 1.0}, + "action": "buy-now", + } + ] + } + ) + with pytest.raises(ParseError, match="action"): parse_strategy(src) -def test_parse_empty_strategy_raises(): - src = "(strategy)" - with pytest.raises(ParseError): +def test_parse_malformed_json_raises() -> None: + with pytest.raises(ParseError, match="invalid JSON"): + parse_strategy("{this is not json") + + +def test_parse_top_level_array_raises() -> None: + with pytest.raises(ParseError, match="JSON object"): + parse_strategy("[1, 2, 3]") + + +def test_parse_missing_rules_key_raises() -> None: + with pytest.raises(ParseError, match="rules"): + parse_strategy(json.dumps({"foo": "bar"})) + + +def test_parse_empty_rules_raises() -> None: + with pytest.raises(ParseError, match="at least one"): + parse_strategy(json.dumps({"rules": []})) + + +def test_parse_node_with_both_op_and_kind_raises() -> None: + src = json.dumps( + { + "rules": [ + { + "condition": {"op": "gt", "kind": "indicator", "args": []}, + "action": "flat", + } + ] + } + ) + with pytest.raises(ParseError, match="mutually exclusive"): + parse_strategy(src) + + +def test_parse_indicator_with_nested_node_raises() -> None: + src = json.dumps( + { + "rules": [ + { + "condition": { + "kind": "indicator", + "name": "sma", + "params": [{"kind": "literal", "value": 14}], + }, + "action": "flat", + } + ] + } + ) + with pytest.raises(ParseError, match="params"): parse_strategy(src) diff --git a/tests/unit/test_protocol_validator.py b/tests/unit/test_protocol_validator.py index 2aa6b4f..472b885 100644 --- a/tests/unit/test_protocol_validator.py +++ b/tests/unit/test_protocol_validator.py @@ -1,38 +1,153 @@ +import json + import pytest from multi_swarm.protocol.parser import parse_strategy from multi_swarm.protocol.validator import ValidationError, validate_strategy +def _wrap(condition: dict, action: str = "entry-long") -> str: + return json.dumps({"rules": [{"condition": condition, "action": action}]}) + + def test_valid_strategy_passes() -> None: - src = "(strategy (when (gt (indicator rsi 14) 70.0) (entry-short)))" + src = _wrap( + { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + }, + action="entry-short", + ) ast = parse_strategy(src) validate_strategy(ast) # no exception def test_indicator_unknown_name_fails() -> None: - src = "(strategy (when (gt (indicator wibble 14) 70.0) (entry-short)))" + src = _wrap( + { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "wibble", "params": [14]}, + {"kind": "literal", "value": 70.0}, + ], + } + ) ast = parse_strategy(src) with pytest.raises(ValidationError, match="unknown indicator"): validate_strategy(ast) -def test_indicator_wrong_arity_fails() -> None: - src = "(strategy (when (gt (indicator rsi) 70.0) (entry-short)))" +def test_indicator_arity_too_few_fails() -> None: + src = _wrap( + { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": []}, + {"kind": "literal", "value": 70.0}, + ], + } + ) ast = parse_strategy(src) - with pytest.raises(ValidationError): + with pytest.raises(ValidationError, match="arity"): + validate_strategy(ast) + + +def test_indicator_arity_too_many_fails() -> None: + src = _wrap( + { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "rsi", "params": [14, 28]}, + {"kind": "literal", "value": 70.0}, + ], + } + ) + ast = parse_strategy(src) + with pytest.raises(ValidationError, match="arity"): + validate_strategy(ast) + + +def test_macd_arity_zero_to_three_ok() -> None: + for params in [[], [12], [12, 26], [12, 26, 9]]: + src = _wrap( + { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "macd", "params": params}, + {"kind": "literal", "value": 0.0}, + ], + } + ) + ast = parse_strategy(src) + validate_strategy(ast) + + +def test_macd_arity_four_fails() -> None: + src = _wrap( + { + "op": "gt", + "args": [ + {"kind": "indicator", "name": "macd", "params": [1, 2, 3, 4]}, + {"kind": "literal", "value": 0.0}, + ], + } + ) + ast = parse_strategy(src) + with pytest.raises(ValidationError, match="arity"): validate_strategy(ast) def test_comparator_wrong_arity_fails() -> None: - src = "(strategy (when (gt 1.0) (entry-long)))" + src = _wrap({"op": "gt", "args": [{"kind": "literal", "value": 1.0}]}) ast = parse_strategy(src) - with pytest.raises(ValidationError): + with pytest.raises(ValidationError, match="needs 2 args"): + validate_strategy(ast) + + +def test_logical_not_arity_fails() -> None: + src = _wrap( + { + "op": "not", + "args": [ + {"kind": "literal", "value": 1.0}, + {"kind": "literal", "value": 2.0}, + ], + } + ) + ast = parse_strategy(src) + with pytest.raises(ValidationError, match="'not' needs 1"): + validate_strategy(ast) + + +def test_logical_and_arity_fails() -> None: + src = _wrap({"op": "and", "args": [{"kind": "literal", "value": 1.0}]}) + ast = parse_strategy(src) + with pytest.raises(ValidationError, match="and"): + validate_strategy(ast) + + +def test_crossover_wrong_arity_fails() -> None: + src = _wrap( + {"op": "crossover", "args": [{"kind": "literal", "value": 1.0}]} + ) + ast = parse_strategy(src) + with pytest.raises(ValidationError, match="crossover"): validate_strategy(ast) def test_feature_unknown_column_fails() -> None: - src = "(strategy (when (gt (feature wibble) 100.0) (entry-long)))" + src = _wrap( + { + "op": "gt", + "args": [ + {"kind": "feature", "name": "wibble"}, + {"kind": "literal", "value": 100.0}, + ], + } + ) ast = parse_strategy(src) with pytest.raises(ValidationError, match="unknown feature"): validate_strategy(ast) diff --git a/uv.lock b/uv.lock index 1a1306f..f6955e2 100644 --- a/uv.lock +++ b/uv.lock @@ -560,7 +560,6 @@ dependencies = [ { name = "pyyaml" }, { name = "requests" }, { name = "scipy" }, - { name = "sexpdata" }, { name = "sqlmodel" }, { name = "streamlit" }, { name = "tenacity" }, @@ -590,7 +589,6 @@ requires-dist = [ { name = "pyyaml", specifier = ">=6.0" }, { name = "requests", specifier = ">=2.32" }, { name = "scipy", specifier = ">=1.14" }, - { name = "sexpdata", specifier = ">=1.0.2" }, { name = "sqlmodel", specifier = ">=0.0.22" }, { name = "streamlit", specifier = ">=1.40" }, { name = "tenacity", specifier = ">=9.0" }, @@ -1321,15 +1319,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/39/338d9219c4e87f3e708f18857ecd24d22a0c3094752393319553096b98af/scipy-1.17.1-cp314-cp314t-win_arm64.whl", hash = "sha256:200e1050faffacc162be6a486a984a0497866ec54149a01270adc8a59b7c7d21", size = 25489165, upload-time = "2026-02-23T00:22:29.563Z" }, ] -[[package]] -name = "sexpdata" -version = "1.0.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a7/7f/369a478863a39351be75e0a12602bc29196b31f87bf3432bed2be6379f8e/sexpdata-1.0.2.tar.gz", hash = "sha256:92b67b0361f6766f8f9e44b9519cf3fbcfafa755db85bbf893c3e1cf4ddac109", size = 8906, upload-time = "2024-01-09T07:09:59.096Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f1/f3/ec9f8cc20dc1f34c926f0ec3f43b73fa2da59cf08e432fb8ae5b666b2027/sexpdata-1.0.2-py3-none-any.whl", hash = "sha256:b39c918f055a85c5c35c1d4f7930aabb176bd29016e5ba5692e7e849914b2a1a", size = 10337, upload-time = "2024-01-09T07:09:57.185Z" }, -] - [[package]] name = "six" version = "1.17.0"