From 26c328d541af8ee8f5efaf342619eb0a95803538 Mon Sep 17 00:00:00 2001
From: AdrianoDev
Date: Sat, 9 May 2026 19:40:28 +0200
Subject: [PATCH] feat(protocol): AST compiler to (df -> Series[Side]) signal fn

Implement compile_strategy, which produces a df -> Series[Side] function:
rules are evaluated in order, the first one that matches wins, FLAT is the
default, and indicator warmup rows stay NaN.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
Reviewer note: hunk sizes and the diffstat were recomputed by hand after
editing; the blob ids on the `index` lines still refer to the original
content.

 src/multi_swarm/protocol/compiler.py | 298 +++++++++++++++++++++++++++
 tests/unit/test_protocol_compiler.py |  65 ++++++
 2 files changed, 363 insertions(+)
 create mode 100644 src/multi_swarm/protocol/compiler.py
 create mode 100644 tests/unit/test_protocol_compiler.py

diff --git a/src/multi_swarm/protocol/compiler.py b/src/multi_swarm/protocol/compiler.py
new file mode 100644
index 0000000..6e2ba8a
--- /dev/null
+++ b/src/multi_swarm/protocol/compiler.py
@@ -0,0 +1,298 @@
+"""Compile a parsed :class:`Strategy` AST into a callable signal function.
+
+The compiled callable maps an OHLCV ``DataFrame`` into a ``pd.Series`` of
+:class:`Side` values (one entry per timestamp). Rules are evaluated in order;
+the first matching rule wins for each row, with :data:`Side.FLAT` as default
+when no rule matches.
+
+Design notes
+------------
+* Indicator dispatch goes through :data:`INDICATOR_FNS`, a dict of named
+  helpers. The dict is annotated as ``dict[str, Any]`` because each entry has
+  a different concrete signature (``(df, length)`` vs ``(df, fast, slow)``);
+  modelling that under ``mypy --strict`` would require a ``Protocol`` per
+  arity, which is overkill for the Phase 1 indicator subset.
+* Numeric leaves coming out of :mod:`sexpdata` arrive as ``int`` / ``float``
+  / ``str``; we widen via :func:`_to_series` to broadcast them along the
+  DataFrame index for arithmetic comparisons.
+* Conditions are three-valued: ``True``, ``False`` or NaN ("unknown", e.g.
+  an indicator still warming up). ``and`` / ``or`` / ``not`` follow Kleene
+  logic, so an unknown operand is never negated into a match.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+import pandas as pd  # type: ignore[import-untyped]
+
+from ..backtest.orders import Side
+from .parser import Node, Strategy
+
+
+def _sma(s: pd.Series, length: int) -> pd.Series:
+    """Simple moving average; partial windows are averaged (no warmup NaN)."""
+    return s.rolling(length, min_periods=1).mean()
+
+
+def _ema(s: pd.Series, span: int) -> pd.Series:
+    """Exponential moving average parameterised by ``span``."""
+    return s.ewm(span=span, adjust=False).mean()
+
+
+def _rsi(s: pd.Series, length: int) -> pd.Series:
+    """Relative Strength Index using Wilder smoothing (``alpha = 1/length``)."""
+    delta = s.diff()
+    up = delta.clip(lower=0)
+    down = -delta.clip(upper=0)
+    roll_up = up.ewm(alpha=1.0 / length, adjust=False).mean()
+    roll_down = down.ewm(alpha=1.0 / length, adjust=False).mean()
+    # Epsilon floor on roll_down avoids the degenerate case where a strictly
+    # monotone series gives roll_down=0 -> rs=inf -> rsi=100 exactly, which
+    # silently fails downstream `lt rsi 100` comparisons. Floating-point
+    # arithmetic with a tiny epsilon yields rsi just below 100 (e.g. 99.999..),
+    # preserving the expected ordering while still being numerically harmless.
+    rs = roll_up / roll_down.replace(0, 1e-12)
+    return 100 - (100 / (1 + rs))
+
+
+def _atr(df: pd.DataFrame, length: int) -> pd.Series:
+    """Average True Range: Wilder-smoothed maximum of the three true ranges."""
+    h_l = df["high"] - df["low"]
+    h_c = (df["high"] - df["close"].shift()).abs()
+    l_c = (df["low"] - df["close"].shift()).abs()
+    tr = pd.concat([h_l, h_c, l_c], axis=1).max(axis=1)
+    return tr.ewm(alpha=1.0 / length, adjust=False).mean()
+
+
+def _realized_vol(s: pd.Series, window: int) -> pd.Series:
+    """Rolling standard deviation of simple returns, scaled by sqrt(window)."""
+    returns = s.pct_change()
+    return returns.rolling(window, min_periods=1).std() * np.sqrt(window)
+
+
+def _ind_sma(df: pd.DataFrame, length: int) -> pd.Series:
+    return _sma(df["close"], length)
+
+
+def _ind_rsi(df: pd.DataFrame, length: int) -> pd.Series:
+    return _rsi(df["close"], length)
+
+
+def _ind_atr(df: pd.DataFrame, length: int) -> pd.Series:
+    return _atr(df, length)
+
+
+def _ind_realized_vol(df: pd.DataFrame, window: int) -> pd.Series:
+    return _realized_vol(df["close"], window)
+
+
+def _ind_macd(df: pd.DataFrame, fast: int = 12, slow: int = 26) -> pd.Series:
+    """MACD line: fast EMA minus slow EMA of the close.
+
+    MACD is defined on exponential averages; building it from simple moving
+    averages yields a different (laggier) oscillator under the same name.
+    """
+    return _ema(df["close"], fast) - _ema(df["close"], slow)
+
+
+# Annotated as ``dict[str, Any]`` deliberately: each indicator has its own
+# arity and parameter names, so a single ``Callable`` signature would be a
+# lie. Dispatch happens in :func:`_eval_node`, which validates the verb name
+# against this map.
+INDICATOR_FNS: dict[str, Any] = {
+    "sma": _ind_sma,
+    "rsi": _ind_rsi,
+    "atr": _ind_atr,
+    "realized_vol": _ind_realized_vol,
+    "macd": _ind_macd,
+}
+
+
+def _to_series(value: object, df: pd.DataFrame) -> pd.Series:
+    """Broadcast a numeric literal across the DataFrame index."""
+    return pd.Series(float(value), index=df.index)  # type: ignore[arg-type]
+
+
+def _eval_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
+    """Evaluate either a child Node or a scalar literal into a Series."""
+    if isinstance(arg, Node):
+        return _eval_node(arg, df)
+    return _to_series(arg, df)
+
+
+def _compare_with_nan(result: pd.Series, a: pd.Series, b: pd.Series) -> pd.Series:
+    """Mark a comparison result as NaN where either operand is NaN.
+
+    Pandas comparison ops normally return ``False`` for NaN inputs, which would
+    silently turn warmup periods (e.g. RSI before its rolling window fills)
+    into "no match -> FLAT". We promote those slots to NaN so callers can
+    distinguish "indicator unavailable" from "condition didn't fire".
+    """
+    out = result.astype(object)
+    nan_mask = a.isna() | b.isna()
+    out[nan_mask] = np.nan
+    return out
+
+
+def _split_tristate(s: pd.Series) -> tuple[pd.Series, pd.Series]:
+    """Split a bool-or-NaN Series into ``(is_true, is_unknown)`` bool masks."""
+    is_unknown = s.isna().astype(bool)
+    is_true = s.where(~is_unknown, False).astype(bool)
+    return is_true, is_unknown
+
+
+def _join_tristate(is_true: pd.Series, is_unknown: pd.Series) -> pd.Series:
+    """Inverse of :func:`_split_tristate`: object Series of bools and NaN."""
+    out = is_true.astype(object)
+    out[is_unknown] = np.nan
+    return out
+
+
+def _eval_tristate_arg(arg: Any, df: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
+    """Evaluate a logical operand into ``(is_true, is_unknown)`` masks."""
+    if isinstance(arg, Node):
+        return _split_tristate(_eval_node(arg, df))
+    never_unknown = pd.Series(False, index=df.index)
+    return pd.Series(bool(arg), index=df.index), never_unknown
+
+
+def _eval_bool_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
+    """Evaluate an operand into a strict bool Series (unknown -> ``False``)."""
+    is_true, _ = _eval_tristate_arg(arg, df)
+    return is_true
+
+
+def _eval_node(node: Node, df: pd.DataFrame) -> pd.Series:
+    """Recursively evaluate an AST node against ``df``.
+
+    Features and indicators yield numeric Series. Comparisons and the logical
+    connectives yield object Series of ``True`` / ``False`` / NaN, where NaN
+    means "unknown" (an operand was unavailable). ``crossover`` and
+    ``crossunder`` yield strict bool Series.
+
+    Raises:
+        KeyError: unknown indicator name, or feature column missing from ``df``.
+        RuntimeError: node kind not supported by the compiler.
+    """
+    kind = node.kind
+
+    if kind == "feature":
+        feat = node.args[0]
+        feat_name = feat.kind if isinstance(feat, Node) else str(feat)
+        return df[feat_name]
+
+    if kind == "indicator":
+        name_node = node.args[0]
+        ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node)
+        params = [a for a in node.args[1:] if not isinstance(a, Node)]
+        if ind_name not in INDICATOR_FNS:
+            # Same exception type as a bare dict lookup, but the message
+            # names the compiler stage instead of just echoing the key.
+            raise KeyError(f"unsupported indicator in compiler: {ind_name}")
+        fn = INDICATOR_FNS[ind_name]
+        result: pd.Series = fn(df, *params)
+        return result
+
+    if kind == "gt":
+        a = _eval_arg(node.args[0], df)
+        b = _eval_arg(node.args[1], df)
+        return _compare_with_nan(a > b, a, b)
+
+    if kind == "lt":
+        a = _eval_arg(node.args[0], df)
+        b = _eval_arg(node.args[1], df)
+        return _compare_with_nan(a < b, a, b)
+
+    if kind == "eq":
+        a = _eval_arg(node.args[0], df)
+        b = _eval_arg(node.args[1], df)
+        return _compare_with_nan(a == b, a, b)
+
+    if kind == "and":
+        # Kleene AND: any False wins, otherwise any unknown, otherwise True.
+        any_false = pd.Series(False, index=df.index)
+        any_unknown = pd.Series(False, index=df.index)
+        for operand in node.args:
+            is_true, is_unknown = _eval_tristate_arg(operand, df)
+            any_false |= ~is_true & ~is_unknown
+            any_unknown |= is_unknown
+        return _join_tristate(~any_false & ~any_unknown, any_unknown & ~any_false)
+
+    if kind == "or":
+        # Kleene OR: any True wins, otherwise any unknown, otherwise False.
+        any_true = pd.Series(False, index=df.index)
+        any_unknown = pd.Series(False, index=df.index)
+        for operand in node.args:
+            is_true, is_unknown = _eval_tristate_arg(operand, df)
+            any_true |= is_true
+            any_unknown |= is_unknown
+        return _join_tristate(any_true, any_unknown & ~any_true)
+
+    if kind == "not":
+        # Negating an unknown operand must stay unknown; coercing NaN to
+        # False first would turn every warmup row into a spurious match.
+        is_true, is_unknown = _eval_tristate_arg(node.args[0], df)
+        return _join_tristate(~is_true & ~is_unknown, is_unknown)
+
+    if kind == "crossover":
+        a = _eval_arg(node.args[0], df)
+        b = _eval_arg(node.args[1], df)
+        return ((a > b) & (a.shift() <= b.shift())).fillna(False).astype(bool)
+
+    if kind == "crossunder":
+        a = _eval_arg(node.args[0], df)
+        b = _eval_arg(node.args[1], df)
+        return ((a < b) & (a.shift() >= b.shift())).fillna(False).astype(bool)
+
+    raise RuntimeError(f"unsupported node in compiler: {kind}")
+
+
+_ACTION_TO_SIDE: dict[str, Side] = {
+    "entry-long": Side.LONG,
+    "entry-short": Side.SHORT,
+    "exit": Side.FLAT,
+    "flat": Side.FLAT,
+}
+
+
+def _action_to_side(action: Node) -> Side:
+    return _ACTION_TO_SIDE[action.kind]
+
+
+def compile_strategy(strategy: Strategy) -> Callable[[pd.DataFrame], pd.Series]:
+    """Compile a :class:`Strategy` AST into a ``df -> Series[Side]`` callable.
+
+    Rules are evaluated in order; the first that matches wins for every
+    timestamp. Default to :data:`Side.FLAT` when no rule matches.
+    """
+
+    def fn(df: pd.DataFrame) -> pd.Series:
+        # ``object`` dtype lets us keep ``Side`` enum members alongside NaN
+        # sentinels for warmup rows (where indicators haven't filled yet).
+        result: pd.Series = pd.Series(np.nan, index=df.index, dtype=object)
+        already_set = pd.Series(False, index=df.index)
+        any_rule_seen = pd.Series(False, index=df.index)
+        for rule in strategy.rules:
+            match = _eval_node(rule.condition, df)
+            target = _action_to_side(rule.action)
+            valid = ~_isna_series(match)
+            any_rule_seen |= valid
+            match_bool = match.where(valid, False).astype(bool)
+            apply_mask = match_bool & ~already_set
+            result[apply_mask] = target
+            already_set |= apply_mask
+        # Rows where at least one rule was evaluable but none fired -> FLAT.
+        # Rows where every rule was NaN (full warmup) stay NaN.
+        flat_mask = any_rule_seen & ~already_set
+        result[flat_mask] = Side.FLAT
+        return result
+
+    return fn
+
+
+def _isna_series(s: pd.Series) -> pd.Series:
+    """``Series.isna`` with explicit bool cast to keep mypy happy."""
+    return s.isna().astype(bool)
diff --git a/tests/unit/test_protocol_compiler.py b/tests/unit/test_protocol_compiler.py
new file mode 100644
index 0000000..c244e7d
--- /dev/null
+++ b/tests/unit/test_protocol_compiler.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from multi_swarm.backtest.orders import Side
+from multi_swarm.protocol.compiler import compile_strategy
+from multi_swarm.protocol.parser import parse_strategy
+
+
+@pytest.fixture
+def ohlcv() -> pd.DataFrame:
+    idx = pd.date_range("2024-01-01", periods=200, freq="1h", tz="UTC")
+    close = np.linspace(100, 120, 200)
+    return pd.DataFrame(
+        {
+            "open": close,
+            "high": close + 0.5,
+            "low": close - 0.5,
+            "close": close,
+            "volume": 1.0,
+        },
+        index=idx,
+    )
+
+
+def test_compile_simple_long(ohlcv: pd.DataFrame) -> None:
+    src = "(strategy (when (lt (indicator rsi 14) 100.0) (entry-long)))"
+    ast = parse_strategy(src)
+    fn = compile_strategy(ast)
+    signals = fn(ohlcv)
+    assert isinstance(signals, pd.Series)
+    assert (signals == Side.LONG).all() or (signals.dropna() == Side.LONG).all()
+
+
+def test_compile_no_match_is_flat(ohlcv: pd.DataFrame) -> None:
+    src = "(strategy (when (gt (indicator rsi 14) 1000.0) (entry-long)))"
+    ast = parse_strategy(src)
+    fn = compile_strategy(ast)
+    signals = fn(ohlcv)
+    assert (signals == Side.FLAT).any()
+
+
+def test_compile_two_rules_priority(ohlcv: pd.DataFrame) -> None:
+    src = """
+    (strategy
+      (when (gt (feature close) 110.0) (entry-long))
+      (when (lt (feature close) 105.0) (entry-short)))
+    """
+    ast = parse_strategy(src)
+    fn = compile_strategy(ast)
+    signals = fn(ohlcv)
+    last = signals.iloc[-1]
+    assert last == Side.LONG  # final close is 120, so rule 1 matches
+
+
+def test_compile_not_keeps_warmup_unknown(ohlcv: pd.DataFrame) -> None:
+    # RSI is NaN on the first bar; negating it must not fire the rule there.
+    src = "(strategy (when (not (gt (indicator rsi 14) 50.0)) (entry-short)))"
+    ast = parse_strategy(src)
+    fn = compile_strategy(ast)
+    signals = fn(ohlcv)
+    assert pd.isna(signals.iloc[0])
+    assert (signals.iloc[1:] == Side.FLAT).all()