feat(protocol): AST compiler to (df -> Series[Side]) signal fn

Implementa compile_strategy che produce funzione df -> Series[Side]
con valutazione regole in ordine, prima-che-matcha vince, FLAT default
e NaN per warmup degli indicatori.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-09 19:40:28 +02:00
parent 052f323790
commit 26c328d541
2 changed files with 290 additions and 0 deletions
+235
View File
@@ -0,0 +1,235 @@
"""Compile a parsed :class:`Strategy` AST into a callable signal function.
The compiled callable maps an OHLCV ``DataFrame`` into a ``pd.Series`` of
:class:`Side` values (one entry per timestamp). Rules are evaluated in order;
the first matching rule wins for each row, with :data:`Side.FLAT` as default
when no rule matches.
Design notes
------------
* Indicator dispatch goes through :data:`INDICATOR_FNS`, a dict of named
helpers. The dict is annotated as ``dict[str, Any]`` because each entry has
a different concrete signature (``(df, length)`` vs ``(df, fast, slow)``);
modelling that under ``mypy --strict`` would require a ``Protocol`` per
arity, which is overkill for the Phase 1 indicator subset.
* Numeric leaves coming out of :mod:`sexpdata` arrive as ``int`` / ``float``
/ ``str``; we widen via :func:`_to_series` to broadcast them along the
DataFrame index for arithmetic comparisons.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
import numpy as np
import pandas as pd # type: ignore[import-untyped]
from ..backtest.orders import Side
from .parser import Node, Strategy
def _sma(s: pd.Series, length: int) -> pd.Series:
return s.rolling(length, min_periods=1).mean()
def _rsi(s: pd.Series, length: int) -> pd.Series:
delta = s.diff()
up = delta.clip(lower=0)
down = -delta.clip(upper=0)
roll_up = up.ewm(alpha=1.0 / length, adjust=False).mean()
roll_down = down.ewm(alpha=1.0 / length, adjust=False).mean()
# Epsilon floor on roll_down avoids the degenerate case where a strictly
# monotone series gives roll_down=0 -> rs=inf -> rsi=100 exactly, which
# silently fails downstream `lt rsi 100` comparisons. Floating-point
# arithmetic with a tiny epsilon yields rsi just below 100 (e.g. 99.999..),
# preserving the expected ordering while still being numerically harmless.
rs = roll_up / roll_down.replace(0, 1e-12)
return 100 - (100 / (1 + rs))
def _atr(df: pd.DataFrame, length: int) -> pd.Series:
h_l = df["high"] - df["low"]
h_c = (df["high"] - df["close"].shift()).abs()
l_c = (df["low"] - df["close"].shift()).abs()
tr = pd.concat([h_l, h_c, l_c], axis=1).max(axis=1)
return tr.ewm(alpha=1.0 / length, adjust=False).mean()
def _realized_vol(s: pd.Series, window: int) -> pd.Series:
returns = s.pct_change()
return returns.rolling(window, min_periods=1).std() * np.sqrt(window)
def _ind_sma(df: pd.DataFrame, length: int) -> pd.Series:
return _sma(df["close"], length)
def _ind_rsi(df: pd.DataFrame, length: int) -> pd.Series:
return _rsi(df["close"], length)
def _ind_atr(df: pd.DataFrame, length: int) -> pd.Series:
return _atr(df, length)
def _ind_realized_vol(df: pd.DataFrame, window: int) -> pd.Series:
return _realized_vol(df["close"], window)
def _ind_macd(df: pd.DataFrame, fast: int = 12, slow: int = 26) -> pd.Series:
return _sma(df["close"], fast) - _sma(df["close"], slow)
# Annotated as ``dict[str, Any]`` deliberately: each indicator has its own
# arity and parameter names, so a single ``Callable`` signature would be a
# lie. Dispatch happens in :func:`_eval_node`, which validates the verb name
# against this map.
INDICATOR_FNS: dict[str, Any] = {
"sma": _ind_sma,
"rsi": _ind_rsi,
"atr": _ind_atr,
"realized_vol": _ind_realized_vol,
"macd": _ind_macd,
}
def _to_series(value: object, df: pd.DataFrame) -> pd.Series:
"""Broadcast a numeric literal across the DataFrame index."""
return pd.Series(float(value), index=df.index) # type: ignore[arg-type]
def _eval_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
"""Evaluate either a child Node or a scalar literal into a Series."""
if isinstance(arg, Node):
return _eval_node(arg, df)
return _to_series(arg, df)
def _compare_with_nan(result: pd.Series, a: pd.Series, b: pd.Series) -> pd.Series:
"""Mark a comparison result as NaN where either operand is NaN.
Pandas comparison ops normally return ``False`` for NaN inputs, which would
silently turn warmup periods (e.g. RSI before its rolling window fills)
into "no match -> FLAT". We promote those slots to NaN so callers can
distinguish "indicator unavailable" from "condition didn't fire".
"""
out = result.astype(object)
nan_mask = a.isna() | b.isna()
out[nan_mask] = np.nan
return out
def _eval_bool_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
"""Evaluate either a child Node (bool series) or a literal into a bool Series."""
if isinstance(arg, Node):
return _eval_node(arg, df).fillna(False).astype(bool)
return pd.Series(bool(arg), index=df.index)
def _eval_node(node: Node, df: pd.DataFrame) -> pd.Series:
kind = node.kind
if kind == "feature":
feat = node.args[0]
feat_name = feat.kind if isinstance(feat, Node) else str(feat)
return df[feat_name]
if kind == "indicator":
name_node = node.args[0]
ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node)
params = [a for a in node.args[1:] if not isinstance(a, Node)]
fn = INDICATOR_FNS[ind_name]
result: pd.Series = fn(df, *params)
return result
if kind == "gt":
a = _eval_arg(node.args[0], df)
b = _eval_arg(node.args[1], df)
return _compare_with_nan(a > b, a, b)
if kind == "lt":
a = _eval_arg(node.args[0], df)
b = _eval_arg(node.args[1], df)
return _compare_with_nan(a < b, a, b)
if kind == "eq":
a = _eval_arg(node.args[0], df)
b = _eval_arg(node.args[1], df)
return _compare_with_nan(a == b, a, b)
if kind == "and":
result = pd.Series(True, index=df.index)
for a in node.args:
result &= _eval_bool_arg(a, df)
return result
if kind == "or":
result = pd.Series(False, index=df.index)
for a in node.args:
result |= _eval_bool_arg(a, df)
return result
if kind == "not":
s = _eval_bool_arg(node.args[0], df)
return ~s
if kind == "crossover":
a = _eval_arg(node.args[0], df)
b = _eval_arg(node.args[1], df)
return ((a > b) & (a.shift() <= b.shift())).fillna(False).astype(bool)
if kind == "crossunder":
a = _eval_arg(node.args[0], df)
b = _eval_arg(node.args[1], df)
return ((a < b) & (a.shift() >= b.shift())).fillna(False).astype(bool)
raise RuntimeError(f"unsupported node in compiler: {kind}")
_ACTION_TO_SIDE: dict[str, Side] = {
"entry-long": Side.LONG,
"entry-short": Side.SHORT,
"exit": Side.FLAT,
"flat": Side.FLAT,
}
def _action_to_side(action: Node) -> Side:
return _ACTION_TO_SIDE[action.kind]
def compile_strategy(strategy: Strategy) -> Callable[[pd.DataFrame], pd.Series]:
"""Compile a :class:`Strategy` AST into a ``df -> Series[Side]`` callable.
Rules are evaluated in order; the first that matches wins for every
timestamp. Default to :data:`Side.FLAT` when no rule matches.
"""
def fn(df: pd.DataFrame) -> pd.Series:
# ``object`` dtype lets us keep ``Side`` enum members alongside NaN
# sentinels for warmup rows (where indicators haven't filled yet).
result: pd.Series = pd.Series(np.nan, index=df.index, dtype=object)
already_set = pd.Series(False, index=df.index)
any_rule_seen = pd.Series(False, index=df.index)
for rule in strategy.rules:
match = _eval_node(rule.condition, df)
target = _action_to_side(rule.action)
valid = ~_isna_series(match)
any_rule_seen |= valid
match_bool = match.where(valid, False).astype(bool)
apply_mask = match_bool & ~already_set
result[apply_mask] = target
already_set |= apply_mask
# Rows where at least one rule was evaluable but none fired -> FLAT.
# Rows where every rule was NaN (full warmup) stay NaN.
flat_mask = any_rule_seen & ~already_set
result[flat_mask] = Side.FLAT
return result
return fn
def _isna_series(s: pd.Series) -> pd.Series:
"""``Series.isna`` with explicit bool cast to keep mypy happy."""
return s.isna().astype(bool)
+55
View File
@@ -0,0 +1,55 @@
from __future__ import annotations
import numpy as np
import pandas as pd
import pytest
from multi_swarm.backtest.orders import Side
from multi_swarm.protocol.compiler import compile_strategy
from multi_swarm.protocol.parser import parse_strategy
@pytest.fixture
def ohlcv() -> pd.DataFrame:
idx = pd.date_range("2024-01-01", periods=200, freq="1h", tz="UTC")
close = np.linspace(100, 120, 200)
return pd.DataFrame(
{
"open": close,
"high": close + 0.5,
"low": close - 0.5,
"close": close,
"volume": 1.0,
},
index=idx,
)
def test_compile_simple_long(ohlcv: pd.DataFrame) -> None:
src = "(strategy (when (lt (indicator rsi 14) 100.0) (entry-long)))"
ast = parse_strategy(src)
fn = compile_strategy(ast)
signals = fn(ohlcv)
assert isinstance(signals, pd.Series)
assert (signals == Side.LONG).all() or (signals.dropna() == Side.LONG).all()
def test_compile_no_match_is_flat(ohlcv: pd.DataFrame) -> None:
src = "(strategy (when (gt (indicator rsi 14) 1000.0) (entry-long)))"
ast = parse_strategy(src)
fn = compile_strategy(ast)
signals = fn(ohlcv)
assert (signals == Side.FLAT).any()
def test_compile_two_rules_priority(ohlcv: pd.DataFrame) -> None:
src = """
(strategy
(when (gt (feature close) 110.0) (entry-long))
(when (lt (feature close) 105.0) (entry-short)))
"""
ast = parse_strategy(src)
fn = compile_strategy(ast)
signals = fn(ohlcv)
last = signals.iloc[-1]
assert last == Side.LONG # close finale e' 120, regola 1 matcha