feat(protocol): AST compiler to (df -> Series[Side]) signal fn
Implementa compile_strategy che produce funzione df -> Series[Side] con valutazione regole in ordine, prima-che-matcha vince, FLAT default e NaN per warmup degli indicatori. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,235 @@
|
|||||||
|
"""Compile a parsed :class:`Strategy` AST into a callable signal function.
|
||||||
|
|
||||||
|
The compiled callable maps an OHLCV ``DataFrame`` into a ``pd.Series`` of
|
||||||
|
:class:`Side` values (one entry per timestamp). Rules are evaluated in order;
|
||||||
|
the first matching rule wins for each row, with :data:`Side.FLAT` as default
|
||||||
|
when no rule matches.
|
||||||
|
|
||||||
|
Design notes
|
||||||
|
------------
|
||||||
|
* Indicator dispatch goes through :data:`INDICATOR_FNS`, a dict of named
|
||||||
|
helpers. The dict is annotated as ``dict[str, Any]`` because each entry has
|
||||||
|
a different concrete signature (``(df, length)`` vs ``(df, fast, slow)``);
|
||||||
|
modelling that under ``mypy --strict`` would require a ``Protocol`` per
|
||||||
|
arity, which is overkill for the Phase 1 indicator subset.
|
||||||
|
* Numeric leaves coming out of :mod:`sexpdata` arrive as ``int`` / ``float``
|
||||||
|
/ ``str``; we widen via :func:`_to_series` to broadcast them along the
|
||||||
|
DataFrame index for arithmetic comparisons.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
from ..backtest.orders import Side
|
||||||
|
from .parser import Node, Strategy
|
||||||
|
|
||||||
|
|
||||||
|
def _sma(s: pd.Series, length: int) -> pd.Series:
|
||||||
|
return s.rolling(length, min_periods=1).mean()
|
||||||
|
|
||||||
|
|
||||||
|
def _rsi(s: pd.Series, length: int) -> pd.Series:
|
||||||
|
delta = s.diff()
|
||||||
|
up = delta.clip(lower=0)
|
||||||
|
down = -delta.clip(upper=0)
|
||||||
|
roll_up = up.ewm(alpha=1.0 / length, adjust=False).mean()
|
||||||
|
roll_down = down.ewm(alpha=1.0 / length, adjust=False).mean()
|
||||||
|
# Epsilon floor on roll_down avoids the degenerate case where a strictly
|
||||||
|
# monotone series gives roll_down=0 -> rs=inf -> rsi=100 exactly, which
|
||||||
|
# silently fails downstream `lt rsi 100` comparisons. Floating-point
|
||||||
|
# arithmetic with a tiny epsilon yields rsi just below 100 (e.g. 99.999..),
|
||||||
|
# preserving the expected ordering while still being numerically harmless.
|
||||||
|
rs = roll_up / roll_down.replace(0, 1e-12)
|
||||||
|
return 100 - (100 / (1 + rs))
|
||||||
|
|
||||||
|
|
||||||
|
def _atr(df: pd.DataFrame, length: int) -> pd.Series:
|
||||||
|
h_l = df["high"] - df["low"]
|
||||||
|
h_c = (df["high"] - df["close"].shift()).abs()
|
||||||
|
l_c = (df["low"] - df["close"].shift()).abs()
|
||||||
|
tr = pd.concat([h_l, h_c, l_c], axis=1).max(axis=1)
|
||||||
|
return tr.ewm(alpha=1.0 / length, adjust=False).mean()
|
||||||
|
|
||||||
|
|
||||||
|
def _realized_vol(s: pd.Series, window: int) -> pd.Series:
|
||||||
|
returns = s.pct_change()
|
||||||
|
return returns.rolling(window, min_periods=1).std() * np.sqrt(window)
|
||||||
|
|
||||||
|
|
||||||
|
def _ind_sma(df: pd.DataFrame, length: int) -> pd.Series:
|
||||||
|
return _sma(df["close"], length)
|
||||||
|
|
||||||
|
|
||||||
|
def _ind_rsi(df: pd.DataFrame, length: int) -> pd.Series:
|
||||||
|
return _rsi(df["close"], length)
|
||||||
|
|
||||||
|
|
||||||
|
def _ind_atr(df: pd.DataFrame, length: int) -> pd.Series:
|
||||||
|
return _atr(df, length)
|
||||||
|
|
||||||
|
|
||||||
|
def _ind_realized_vol(df: pd.DataFrame, window: int) -> pd.Series:
|
||||||
|
return _realized_vol(df["close"], window)
|
||||||
|
|
||||||
|
|
||||||
|
def _ind_macd(df: pd.DataFrame, fast: int = 12, slow: int = 26) -> pd.Series:
|
||||||
|
return _sma(df["close"], fast) - _sma(df["close"], slow)
|
||||||
|
|
||||||
|
|
||||||
|
# Annotated as ``dict[str, Any]`` deliberately: each indicator has its own
|
||||||
|
# arity and parameter names, so a single ``Callable`` signature would be a
|
||||||
|
# lie. Dispatch happens in :func:`_eval_node`, which validates the verb name
|
||||||
|
# against this map.
|
||||||
|
INDICATOR_FNS: dict[str, Any] = {
|
||||||
|
"sma": _ind_sma,
|
||||||
|
"rsi": _ind_rsi,
|
||||||
|
"atr": _ind_atr,
|
||||||
|
"realized_vol": _ind_realized_vol,
|
||||||
|
"macd": _ind_macd,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _to_series(value: object, df: pd.DataFrame) -> pd.Series:
|
||||||
|
"""Broadcast a numeric literal across the DataFrame index."""
|
||||||
|
return pd.Series(float(value), index=df.index) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
|
||||||
|
"""Evaluate either a child Node or a scalar literal into a Series."""
|
||||||
|
if isinstance(arg, Node):
|
||||||
|
return _eval_node(arg, df)
|
||||||
|
return _to_series(arg, df)
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_with_nan(result: pd.Series, a: pd.Series, b: pd.Series) -> pd.Series:
|
||||||
|
"""Mark a comparison result as NaN where either operand is NaN.
|
||||||
|
|
||||||
|
Pandas comparison ops normally return ``False`` for NaN inputs, which would
|
||||||
|
silently turn warmup periods (e.g. RSI before its rolling window fills)
|
||||||
|
into "no match -> FLAT". We promote those slots to NaN so callers can
|
||||||
|
distinguish "indicator unavailable" from "condition didn't fire".
|
||||||
|
"""
|
||||||
|
out = result.astype(object)
|
||||||
|
nan_mask = a.isna() | b.isna()
|
||||||
|
out[nan_mask] = np.nan
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_bool_arg(arg: Any, df: pd.DataFrame) -> pd.Series:
|
||||||
|
"""Evaluate either a child Node (bool series) or a literal into a bool Series."""
|
||||||
|
if isinstance(arg, Node):
|
||||||
|
return _eval_node(arg, df).fillna(False).astype(bool)
|
||||||
|
return pd.Series(bool(arg), index=df.index)
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_node(node: Node, df: pd.DataFrame) -> pd.Series:
|
||||||
|
kind = node.kind
|
||||||
|
|
||||||
|
if kind == "feature":
|
||||||
|
feat = node.args[0]
|
||||||
|
feat_name = feat.kind if isinstance(feat, Node) else str(feat)
|
||||||
|
return df[feat_name]
|
||||||
|
|
||||||
|
if kind == "indicator":
|
||||||
|
name_node = node.args[0]
|
||||||
|
ind_name = name_node.kind if isinstance(name_node, Node) else str(name_node)
|
||||||
|
params = [a for a in node.args[1:] if not isinstance(a, Node)]
|
||||||
|
fn = INDICATOR_FNS[ind_name]
|
||||||
|
result: pd.Series = fn(df, *params)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if kind == "gt":
|
||||||
|
a = _eval_arg(node.args[0], df)
|
||||||
|
b = _eval_arg(node.args[1], df)
|
||||||
|
return _compare_with_nan(a > b, a, b)
|
||||||
|
|
||||||
|
if kind == "lt":
|
||||||
|
a = _eval_arg(node.args[0], df)
|
||||||
|
b = _eval_arg(node.args[1], df)
|
||||||
|
return _compare_with_nan(a < b, a, b)
|
||||||
|
|
||||||
|
if kind == "eq":
|
||||||
|
a = _eval_arg(node.args[0], df)
|
||||||
|
b = _eval_arg(node.args[1], df)
|
||||||
|
return _compare_with_nan(a == b, a, b)
|
||||||
|
|
||||||
|
if kind == "and":
|
||||||
|
result = pd.Series(True, index=df.index)
|
||||||
|
for a in node.args:
|
||||||
|
result &= _eval_bool_arg(a, df)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if kind == "or":
|
||||||
|
result = pd.Series(False, index=df.index)
|
||||||
|
for a in node.args:
|
||||||
|
result |= _eval_bool_arg(a, df)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if kind == "not":
|
||||||
|
s = _eval_bool_arg(node.args[0], df)
|
||||||
|
return ~s
|
||||||
|
|
||||||
|
if kind == "crossover":
|
||||||
|
a = _eval_arg(node.args[0], df)
|
||||||
|
b = _eval_arg(node.args[1], df)
|
||||||
|
return ((a > b) & (a.shift() <= b.shift())).fillna(False).astype(bool)
|
||||||
|
|
||||||
|
if kind == "crossunder":
|
||||||
|
a = _eval_arg(node.args[0], df)
|
||||||
|
b = _eval_arg(node.args[1], df)
|
||||||
|
return ((a < b) & (a.shift() >= b.shift())).fillna(False).astype(bool)
|
||||||
|
|
||||||
|
raise RuntimeError(f"unsupported node in compiler: {kind}")
|
||||||
|
|
||||||
|
|
||||||
|
_ACTION_TO_SIDE: dict[str, Side] = {
|
||||||
|
"entry-long": Side.LONG,
|
||||||
|
"entry-short": Side.SHORT,
|
||||||
|
"exit": Side.FLAT,
|
||||||
|
"flat": Side.FLAT,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _action_to_side(action: Node) -> Side:
|
||||||
|
return _ACTION_TO_SIDE[action.kind]
|
||||||
|
|
||||||
|
|
||||||
|
def compile_strategy(strategy: Strategy) -> Callable[[pd.DataFrame], pd.Series]:
|
||||||
|
"""Compile a :class:`Strategy` AST into a ``df -> Series[Side]`` callable.
|
||||||
|
|
||||||
|
Rules are evaluated in order; the first that matches wins for every
|
||||||
|
timestamp. Default to :data:`Side.FLAT` when no rule matches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fn(df: pd.DataFrame) -> pd.Series:
|
||||||
|
# ``object`` dtype lets us keep ``Side`` enum members alongside NaN
|
||||||
|
# sentinels for warmup rows (where indicators haven't filled yet).
|
||||||
|
result: pd.Series = pd.Series(np.nan, index=df.index, dtype=object)
|
||||||
|
already_set = pd.Series(False, index=df.index)
|
||||||
|
any_rule_seen = pd.Series(False, index=df.index)
|
||||||
|
for rule in strategy.rules:
|
||||||
|
match = _eval_node(rule.condition, df)
|
||||||
|
target = _action_to_side(rule.action)
|
||||||
|
valid = ~_isna_series(match)
|
||||||
|
any_rule_seen |= valid
|
||||||
|
match_bool = match.where(valid, False).astype(bool)
|
||||||
|
apply_mask = match_bool & ~already_set
|
||||||
|
result[apply_mask] = target
|
||||||
|
already_set |= apply_mask
|
||||||
|
# Rows where at least one rule was evaluable but none fired -> FLAT.
|
||||||
|
# Rows where every rule was NaN (full warmup) stay NaN.
|
||||||
|
flat_mask = any_rule_seen & ~already_set
|
||||||
|
result[flat_mask] = Side.FLAT
|
||||||
|
return result
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def _isna_series(s: pd.Series) -> pd.Series:
|
||||||
|
"""``Series.isna`` with explicit bool cast to keep mypy happy."""
|
||||||
|
return s.isna().astype(bool)
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from multi_swarm.backtest.orders import Side
|
||||||
|
from multi_swarm.protocol.compiler import compile_strategy
|
||||||
|
from multi_swarm.protocol.parser import parse_strategy
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ohlcv() -> pd.DataFrame:
|
||||||
|
idx = pd.date_range("2024-01-01", periods=200, freq="1h", tz="UTC")
|
||||||
|
close = np.linspace(100, 120, 200)
|
||||||
|
return pd.DataFrame(
|
||||||
|
{
|
||||||
|
"open": close,
|
||||||
|
"high": close + 0.5,
|
||||||
|
"low": close - 0.5,
|
||||||
|
"close": close,
|
||||||
|
"volume": 1.0,
|
||||||
|
},
|
||||||
|
index=idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile_simple_long(ohlcv: pd.DataFrame) -> None:
|
||||||
|
src = "(strategy (when (lt (indicator rsi 14) 100.0) (entry-long)))"
|
||||||
|
ast = parse_strategy(src)
|
||||||
|
fn = compile_strategy(ast)
|
||||||
|
signals = fn(ohlcv)
|
||||||
|
assert isinstance(signals, pd.Series)
|
||||||
|
assert (signals == Side.LONG).all() or (signals.dropna() == Side.LONG).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile_no_match_is_flat(ohlcv: pd.DataFrame) -> None:
|
||||||
|
src = "(strategy (when (gt (indicator rsi 14) 1000.0) (entry-long)))"
|
||||||
|
ast = parse_strategy(src)
|
||||||
|
fn = compile_strategy(ast)
|
||||||
|
signals = fn(ohlcv)
|
||||||
|
assert (signals == Side.FLAT).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile_two_rules_priority(ohlcv: pd.DataFrame) -> None:
|
||||||
|
src = """
|
||||||
|
(strategy
|
||||||
|
(when (gt (feature close) 110.0) (entry-long))
|
||||||
|
(when (lt (feature close) 105.0) (entry-short)))
|
||||||
|
"""
|
||||||
|
ast = parse_strategy(src)
|
||||||
|
fn = compile_strategy(ast)
|
||||||
|
signals = fn(ohlcv)
|
||||||
|
last = signals.iloc[-1]
|
||||||
|
assert last == Side.LONG # close finale e' 120, regola 1 matcha
|
||||||
Reference in New Issue
Block a user