fix(moa): count reference (advisor) fan-out token usage + cost (#56087)
MoA ran the reference models before the aggregator but returned only the aggregator's usage to the loop — _run_reference discarded each advisor response's .usage entirely. Session accounting (state.db, /insights, cost) therefore undercounted every MoA turn by the whole reference fan-out, which is usually the bulk of the spend and scales with advisor count.

- _run_reference normalizes each advisor's usage with ITS OWN resolved provider/api_mode and prices it at ITS OWN model rate (correct cache-read/cache-write split), returning a _RefAccounting(usage, cost).
- create() sums advisor usage + cost once per turn (cache MISS only, so a repeat tool-iteration reusing cached advice does not double-charge) and exposes it via MoAClient.consume_reference_usage().
- conversation_loop folds advisor tokens into the reported/persisted token counts and adds advisor cost (priced per-advisor) on top of the aggregator cost, in both the in-memory session totals and the state.db per-call delta. Aggregator cost is still priced on aggregator-only usage so advisor tokens are never repriced at the aggregator rate.
- CanonicalUsage gains __add__ for per-bucket summing.

Tests: advisor usage/cost capture, per-turn sum + consume-clears + cache-hit no-double-charge, CanonicalUsage.__add__.
This commit is contained in:
parent
44ddc552f5
commit
3bdb23de10
4 changed files with 339 additions and 19 deletions
|
|
@ -1922,6 +1922,25 @@ def run_conversation(
|
|||
provider=agent.provider,
|
||||
api_mode=agent.api_mode,
|
||||
)
|
||||
# Aggregator-only usage is retained for cost pricing: MoA
|
||||
# advisor tokens must be priced at each advisor's OWN model
|
||||
# rate, not the aggregator's, so they are added as dollars
|
||||
# (below) rather than folded into the priced usage.
|
||||
aggregator_usage = canonical_usage
|
||||
# MoA: fold the reference (advisor) fan-out's token usage
|
||||
# into this turn's REPORTED token counts. MoA runs advisors
|
||||
# before the aggregator and returns only the aggregator's
|
||||
# usage, so without this the entire advisor spend — usually
|
||||
# the bulk of a MoA turn — is invisible in token counts.
|
||||
_moa_ref_cost = None
|
||||
_moa_client = getattr(agent, "client", None)
|
||||
if _moa_client is not None and hasattr(_moa_client, "consume_reference_usage"):
|
||||
try:
|
||||
_ref_usage, _moa_ref_cost = _moa_client.consume_reference_usage()
|
||||
if _ref_usage is not None:
|
||||
canonical_usage = canonical_usage + _ref_usage
|
||||
except Exception as _moa_acct_exc: # pragma: no cover - defensive
|
||||
logger.debug("MoA reference usage accounting failed: %s", _moa_acct_exc)
|
||||
prompt_tokens = canonical_usage.prompt_tokens
|
||||
completion_tokens = canonical_usage.output_tokens
|
||||
total_tokens = canonical_usage.total_tokens
|
||||
|
|
@ -1975,13 +1994,20 @@ def run_conversation(
|
|||
|
||||
cost_result = estimate_usage_cost(
|
||||
agent.model,
|
||||
canonical_usage,
|
||||
aggregator_usage,
|
||||
provider=agent.provider,
|
||||
base_url=agent.base_url,
|
||||
api_key=getattr(agent, "api_key", ""),
|
||||
)
|
||||
if cost_result.amount_usd is not None:
|
||||
agent.session_estimated_cost_usd += float(cost_result.amount_usd)
|
||||
# Add MoA advisor cost (already priced per-advisor at each
|
||||
# advisor's own model rate) on top of the aggregator cost.
|
||||
if _moa_ref_cost is not None:
|
||||
try:
|
||||
agent.session_estimated_cost_usd += float(_moa_ref_cost)
|
||||
except (TypeError, ValueError): # pragma: no cover - defensive
|
||||
pass
|
||||
agent.session_cost_status = cost_result.status
|
||||
agent.session_cost_source = cost_result.source
|
||||
|
||||
|
|
@ -2002,6 +2028,18 @@ def run_conversation(
|
|||
# affects 0 rows without error).
|
||||
if not agent._session_db_created:
|
||||
agent._ensure_db_session()
|
||||
# Per-call cost delta = aggregator cost + MoA
|
||||
# advisor cost (each priced at its own rate). Folded
|
||||
# here so state.db's estimated_cost_usd includes the
|
||||
# full MoA spend, matching the folded token counts.
|
||||
_cost_delta = None
|
||||
if cost_result.amount_usd is not None:
|
||||
_cost_delta = float(cost_result.amount_usd)
|
||||
if _moa_ref_cost is not None:
|
||||
try:
|
||||
_cost_delta = (_cost_delta or 0.0) + float(_moa_ref_cost)
|
||||
except (TypeError, ValueError): # pragma: no cover
|
||||
pass
|
||||
agent._session_db.update_token_counts(
|
||||
agent.session_id,
|
||||
input_tokens=canonical_usage.input_tokens,
|
||||
|
|
@ -2009,8 +2047,7 @@ def run_conversation(
|
|||
cache_read_tokens=canonical_usage.cache_read_tokens,
|
||||
cache_write_tokens=canonical_usage.cache_write_tokens,
|
||||
reasoning_tokens=canonical_usage.reasoning_tokens,
|
||||
estimated_cost_usd=float(cost_result.amount_usd)
|
||||
if cost_result.amount_usd is not None else None,
|
||||
estimated_cost_usd=_cost_delta,
|
||||
cost_status=cost_result.status,
|
||||
cost_source=cost_result.source,
|
||||
billing_provider=agent.provider,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,27 @@ logger = logging.getLogger(__name__)
|
|||
# opening dozens of sockets at once.
|
||||
_MAX_REFERENCE_WORKERS = 8
|
||||
|
||||
|
||||
class _RefAccounting:
|
||||
"""Per-reference token usage + estimated cost, carried as the third slot
|
||||
of a reference-output tuple.
|
||||
|
||||
Kept as a tiny object (not a bare CanonicalUsage) because an advisor may
|
||||
run on a different model/provider than the aggregator, so its cost MUST be
|
||||
priced at its OWN model's rate — folding advisor tokens into the
|
||||
aggregator's usage and pricing the sum at the aggregator's rate would
|
||||
misprice every advisor. ``usage`` feeds accurate token counts;
|
||||
``cost_usd`` feeds accurate cost.
|
||||
"""
|
||||
|
||||
__slots__ = ("usage", "cost_usd", "cost_status", "cost_source")
|
||||
|
||||
def __init__(self, usage: Any, cost_usd: Any = None, cost_status: str | None = None, cost_source: str | None = None):
|
||||
self.usage = usage
|
||||
self.cost_usd = cost_usd
|
||||
self.cost_status = cost_status
|
||||
self.cost_source = cost_source
|
||||
|
||||
# Per-tool-result character budget for the advisory reference view. Tool
|
||||
# results can be huge (a full diff, a 5000-line file dump); replaying them
|
||||
# verbatim per reference per tool-loop step would blow the reference model's
|
||||
|
|
@ -125,8 +146,8 @@ def _run_reference(
|
|||
*,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Call one reference model and return ``(label, text)``.
|
||||
) -> tuple[str, str, Any]:
|
||||
"""Call one reference model and return ``(label, text, usage)``.
|
||||
|
||||
The slot is resolved to its provider's real runtime (via ``_slot_runtime``)
|
||||
and called through the same ``call_llm`` request-building path any model
|
||||
|
|
@ -137,12 +158,23 @@ def _run_reference(
|
|||
real maximum); ``temperature`` is only the user's configured preset value,
|
||||
which call_llm may still override per model.
|
||||
|
||||
The reference's token usage is normalized with the slot's OWN resolved
|
||||
provider/api_mode (advisors may run on a different provider than the
|
||||
aggregator, with different usage wire shapes) and returned as a
|
||||
``CanonicalUsage`` so the caller can fold advisor spend into session
|
||||
accounting. Without this, the entire reference fan-out — often the bulk of
|
||||
a MoA turn's token spend — is invisible to cost tracking, which only ever
|
||||
saw the aggregator's usage.
|
||||
|
||||
Never raises: a failed reference becomes a labelled note so the aggregator
|
||||
can still act with partial context. Designed to run inside a thread pool —
|
||||
``call_llm`` is synchronous/blocking, so threads (not asyncio) are the right
|
||||
concurrency primitive, mirroring ``delegate_task``'s batch fan-out.
|
||||
"""
|
||||
from agent.usage_pricing import CanonicalUsage, estimate_usage_cost, normalize_usage
|
||||
|
||||
label = _slot_label(slot)
|
||||
runtime = _slot_runtime(slot)
|
||||
try:
|
||||
# Prepend the advisory-role system prompt so the reference understands
|
||||
# it is analyzing state for an aggregator, not acting on the task. The
|
||||
|
|
@ -154,12 +186,44 @@ def _run_reference(
|
|||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**_slot_runtime(slot),
|
||||
**runtime,
|
||||
)
|
||||
return label, _extract_text(response) or "(empty response)"
|
||||
usage = CanonicalUsage()
|
||||
raw_usage = getattr(response, "usage", None)
|
||||
if raw_usage:
|
||||
try:
|
||||
usage = normalize_usage(
|
||||
raw_usage,
|
||||
provider=runtime.get("provider"),
|
||||
api_mode=runtime.get("api_mode"),
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
usage = CanonicalUsage()
|
||||
# Price this advisor at ITS OWN model/provider rate (with correct
|
||||
# cache-read/cache-write split), not the aggregator's. This is why
|
||||
# advisor cost is summed as dollars rather than by folding tokens into
|
||||
# the aggregator's usage.
|
||||
cost_usd = None
|
||||
cost_status = None
|
||||
cost_source = None
|
||||
try:
|
||||
cost = estimate_usage_cost(
|
||||
slot.get("model") or "",
|
||||
usage,
|
||||
provider=runtime.get("provider"),
|
||||
base_url=runtime.get("base_url"),
|
||||
api_key=runtime.get("api_key"),
|
||||
)
|
||||
cost_usd = cost.amount_usd
|
||||
cost_status = cost.status
|
||||
cost_source = cost.source
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
acct = _RefAccounting(usage, cost_usd, cost_status, cost_source)
|
||||
return label, _extract_text(response) or "(empty response)", acct
|
||||
except Exception as exc:
|
||||
logger.warning("MoA reference model %s failed: %s", label, exc)
|
||||
return label, f"[failed: {exc}]"
|
||||
return label, f"[failed: {exc}]", _RefAccounting(CanonicalUsage())
|
||||
|
||||
|
||||
def _run_references_parallel(
|
||||
|
|
@ -168,7 +232,7 @@ def _run_references_parallel(
|
|||
*,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
) -> list[tuple[str, str, Any]]:
|
||||
"""Fan out all reference models in parallel, returning outputs in order.
|
||||
|
||||
Like ``delegate_task``'s batch mode, every reference is dispatched at once
|
||||
|
|
@ -176,11 +240,16 @@ def _run_references_parallel(
|
|||
the aggregator. Output order matches ``reference_models`` so the
|
||||
``Reference {idx}`` labelling stays stable. MoA presets that reference
|
||||
another MoA preset are skipped here (recursion guard) with a labelled note.
|
||||
|
||||
Each element is ``(label, text, usage)`` where usage is a
|
||||
``CanonicalUsage`` (zeroed for skipped/failed references).
|
||||
"""
|
||||
from agent.usage_pricing import CanonicalUsage
|
||||
|
||||
if not reference_models:
|
||||
return []
|
||||
|
||||
results: list[tuple[str, str] | None] = [None] * len(reference_models)
|
||||
results: list[tuple[str, str, Any] | None] = [None] * len(reference_models)
|
||||
futures = {}
|
||||
workers = min(_MAX_REFERENCE_WORKERS, len(reference_models))
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
|
|
@ -189,6 +258,7 @@ def _run_references_parallel(
|
|||
results[idx] = (
|
||||
_slot_label(slot),
|
||||
"[skipped: MoA presets cannot recursively reference MoA]",
|
||||
_RefAccounting(CanonicalUsage()),
|
||||
)
|
||||
continue
|
||||
futures[
|
||||
|
|
@ -390,7 +460,7 @@ def aggregate_moa_context(
|
|||
sidesteps providers that reject ``max_tokens`` outright. A hardcoded cap
|
||||
here previously truncated long aggregator syntheses.
|
||||
"""
|
||||
reference_outputs: list[tuple[str, str]] = []
|
||||
reference_outputs: list[tuple[str, str, Any]] = []
|
||||
ref_messages = _reference_messages(api_messages)
|
||||
reference_outputs = _run_references_parallel(
|
||||
reference_models,
|
||||
|
|
@ -401,7 +471,7 @@ def aggregate_moa_context(
|
|||
|
||||
joined = "\n\n".join(
|
||||
f"Reference {idx} — {label}:\n{text}"
|
||||
for idx, (label, text) in enumerate(reference_outputs, start=1)
|
||||
for idx, (label, text, _usage) in enumerate(reference_outputs, start=1)
|
||||
)
|
||||
synth_prompt = (
|
||||
"You are the aggregator in a Mixture of Agents process. Synthesize the "
|
||||
|
|
@ -465,7 +535,33 @@ class MoAChatCompletions:
|
|||
# re-run, no re-emit). This gives "fire on every user/tool response"
|
||||
# for free, without re-firing on a pure no-op re-call.
|
||||
self._ref_cache_key: tuple | None = None
|
||||
self._ref_cache_outputs: list[tuple[str, str]] = []
|
||||
self._ref_cache_outputs: list[tuple[str, str, Any]] = []
|
||||
# Token usage + estimated cost of the reference fan-out from the most
|
||||
# recent cache-MISS create() call, awaiting consumption by session
|
||||
# accounting. Set on every create() (zeroed on a cache HIT so per-turn
|
||||
# advisor spend is counted exactly once). Consumed via
|
||||
# ``consume_reference_usage``.
|
||||
from agent.usage_pricing import CanonicalUsage
|
||||
|
||||
self._pending_reference_usage: Any = CanonicalUsage()
|
||||
self._pending_reference_cost: Any = None
|
||||
|
||||
def consume_reference_usage(self) -> tuple[Any, Any]:
    """Hand over the pending reference fan-out accounting and reset it.

    Returns a ``(usage, cost)`` pair for the most recent ``create()`` call.
    ``usage`` is always a ``CanonicalUsage`` (a zeroed one when nothing is
    pending); ``cost`` is the summed dollar figure, or ``None`` when no
    advisor could be priced. Both pending slots are cleared before
    returning, so a second read without a new ``create()`` (e.g. a streaming
    retry re-entering accounting) yields empty values instead of
    double-counting.
    """
    from agent.usage_pricing import CanonicalUsage

    # Snapshot first, then reset, so the caller gets exactly one delivery.
    pending_usage = self._pending_reference_usage
    pending_cost = self._pending_reference_cost
    self._pending_reference_usage, self._pending_reference_cost = CanonicalUsage(), None

    if not pending_usage:
        pending_usage = CanonicalUsage()
    return pending_usage, pending_cost
|
||||
|
||||
def _emit(self, event: str, **kwargs: Any) -> None:
|
||||
cb = self.reference_callback
|
||||
|
|
@ -497,7 +593,9 @@ class MoAChatCompletions:
|
|||
if not preset.get("enabled", True):
|
||||
reference_models = []
|
||||
|
||||
reference_outputs: list[tuple[str, str]] = []
|
||||
from agent.usage_pricing import CanonicalUsage
|
||||
|
||||
reference_outputs: list[tuple[str, str, Any]] = []
|
||||
ref_messages = _reference_messages(messages)
|
||||
|
||||
# Turn-scoped cache: only run + display references when the advisory
|
||||
|
|
@ -514,6 +612,12 @@ class MoAChatCompletions:
|
|||
|
||||
if _refs_from_cache:
|
||||
reference_outputs = list(self._ref_cache_outputs)
|
||||
# References already ran (and were accounted) earlier this turn;
|
||||
# this create() is a repeat tool-iteration reusing the cached
|
||||
# advice. Charging their tokens/cost again here would multiply
|
||||
# advisor spend by the tool-iteration count, so pending is zero.
|
||||
self._pending_reference_usage = CanonicalUsage()
|
||||
self._pending_reference_cost = None
|
||||
else:
|
||||
reference_outputs = _run_references_parallel(
|
||||
reference_models,
|
||||
|
|
@ -523,6 +627,24 @@ class MoAChatCompletions:
|
|||
)
|
||||
self._ref_cache_key = _cache_key
|
||||
self._ref_cache_outputs = list(reference_outputs)
|
||||
# Sum the advisor fan-out's token usage AND cost so the caller can
|
||||
# fold advisor spend into session accounting exactly once per turn.
|
||||
# Only the freshly run references (cache MISS) contribute; a cache
|
||||
# HIT above zeroes this. Token counts sum directly (each already
|
||||
# normalized per-advisor provider/api_mode); cost sums in dollars
|
||||
# because each advisor was priced at its OWN model rate — advisors
|
||||
# may be cheaper/pricier than the aggregator, so their tokens must
|
||||
# NOT be repriced at the aggregator's rate.
|
||||
_ref_usage = CanonicalUsage()
|
||||
_ref_cost: Any = None
|
||||
for _lbl, _txt, _acct in reference_outputs:
|
||||
if isinstance(_acct, _RefAccounting):
|
||||
if isinstance(_acct.usage, CanonicalUsage):
|
||||
_ref_usage = _ref_usage + _acct.usage
|
||||
if _acct.cost_usd is not None:
|
||||
_ref_cost = (_ref_cost or 0) + _acct.cost_usd
|
||||
self._pending_reference_usage = _ref_usage
|
||||
self._pending_reference_cost = _ref_cost
|
||||
|
||||
# Surface each reference model's answer to the display BEFORE the
|
||||
# aggregator acts — once per turn (only on the iteration that
|
||||
|
|
@ -531,7 +653,7 @@ class MoAChatCompletions:
|
|||
# visible rather than a silent pause. Best-effort: never blocks the
|
||||
# turn.
|
||||
_ref_count = len(reference_outputs)
|
||||
for _idx, (_label, _text) in enumerate(reference_outputs, start=1):
|
||||
for _idx, (_label, _text, _usage) in enumerate(reference_outputs, start=1):
|
||||
self._emit(
|
||||
"moa.reference",
|
||||
index=_idx,
|
||||
|
|
@ -550,13 +672,13 @@ class MoAChatCompletions:
|
|||
if reference_outputs:
|
||||
joined = "\n\n".join(
|
||||
f"Reference {idx} — {label}:\n{text}"
|
||||
for idx, (label, text) in enumerate(reference_outputs, start=1)
|
||||
for idx, (label, text, _usage) in enumerate(reference_outputs, start=1)
|
||||
)
|
||||
guidance = (
|
||||
"[Mixture of Agents reference context]\n"
|
||||
f"Preset: {self.preset_name}\n"
|
||||
f"Aggregator/acting model: {_slot_label(aggregator)}\n"
|
||||
f"References: {', '.join(label for label, _ in reference_outputs)}\n\n"
|
||||
f"References: {', '.join(label for label, _, _ in reference_outputs)}\n\n"
|
||||
"Use the reference responses below as private context. You are the aggregator and acting model: "
|
||||
"answer the user directly or call tools as needed.\n\n"
|
||||
f"{joined}"
|
||||
|
|
@ -614,3 +736,11 @@ class MoAClient:
|
|||
def __init__(self, preset_name: str, reference_callback: Any = None):
    """Build the ``client.chat.completions`` attribute chain for a MoA preset.

    ``chat`` is an instance of an ad-hoc empty class that exists only to
    carry the ``completions`` attribute, which is the ``MoAChatCompletions``
    facade for ``preset_name``.
    """
    chat_namespace = type("_MoAChat", (), {})()
    self.chat = chat_namespace
    chat_namespace.completions = MoAChatCompletions(
        preset_name,
        reference_callback=reference_callback,
    )
|
||||
|
||||
def consume_reference_usage(self) -> tuple[Any, Any]:
    """Pop the pending reference fan-out usage and cost from the completions facade.

    Lets session accounting fold the MoA advisor tokens and cost into the
    turn's accounting without reaching into ``.chat.completions`` internals.

    Returns:
        Whatever ``self.chat.completions.consume_reference_usage()`` returns,
        unchanged: a ``(usage, cost)`` pair that the caller unpacks into two
        values.
    """
    return self.chat.completions.consume_reference_usage()
|
||||
|
|
|
|||
|
|
@ -45,6 +45,25 @@ class CanonicalUsage:
|
|||
def total_tokens(self) -> int:
|
||||
return self.prompt_tokens + self.output_tokens
|
||||
|
||||
def __add__(self, other: "CanonicalUsage") -> "CanonicalUsage":
    """Return the bucket-wise sum of two usages (e.g. advisors + aggregator).

    Every token bucket and ``request_count`` are added, so the combined
    figure also records how many API calls it covers. ``raw_usage`` is set
    to ``None`` on the result because it describes one API response and has
    no meaningful merged form. Operands that are not ``CanonicalUsage``
    yield ``NotImplemented``.
    """
    if isinstance(other, CanonicalUsage):
        additive_fields = (
            "input_tokens",
            "output_tokens",
            "cache_read_tokens",
            "cache_write_tokens",
            "reasoning_tokens",
            "request_count",
        )
        totals = {
            field: getattr(self, field) + getattr(other, field)
            for field in additive_fields
        }
        return CanonicalUsage(raw_usage=None, **totals)
    return NotImplemented
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BillingRoute:
|
||||
|
|
|
|||
|
|
@ -410,7 +410,7 @@ def test_run_reference_prepends_advisory_system_prompt(monkeypatch):
|
|||
|
||||
monkeypatch.setattr("agent.moa_loop.call_llm", fake_call_llm)
|
||||
|
||||
label, text = _run_reference(
|
||||
label, text, _acct = _run_reference(
|
||||
{"provider": "openai-codex", "model": "gpt-5.5"},
|
||||
[{"role": "user", "content": "review this PR"}],
|
||||
)
|
||||
|
|
@ -568,7 +568,7 @@ def test_references_run_in_parallel(monkeypatch):
|
|||
# Two 0.5s sleeps run concurrently → well under the 1.0s serial floor.
|
||||
assert elapsed < 0.9, f"references did not run in parallel (took {elapsed:.2f}s)"
|
||||
# Output order matches input order (stable Reference N labelling).
|
||||
assert [label for label, _ in out] == ["p1:ok", "moa:preset", "p2:boom", "p3:ok"]
|
||||
assert [label for label, _, _ in out] == ["p1:ok", "moa:preset", "p2:boom", "p3:ok"]
|
||||
assert "recursively reference MoA" in out[1][1]
|
||||
assert out[2][1].startswith("[failed:")
|
||||
assert out[0][1] == "resp-p1"
|
||||
|
|
@ -750,3 +750,137 @@ def test_slot_runtime_anthropic_oauth_routes_through_provider_branch(monkeypatch
|
|||
assert other_rt["model"] == "some-model"
|
||||
assert other_rt["base_url"] == "https://resolved.example/v1"
|
||||
assert other_rt["api_key"] == "resolved-key"
|
||||
|
||||
|
||||
def _response_with_usage(content="advice", *, prompt=100, completion=50, cached=0):
    """Build a fake chat response whose ``usage`` has the OpenAI-style shape.

    The response has a single choice carrying ``content`` (no tool calls,
    finish reason ``"stop"``) and a usage object with prompt/completion
    counts plus prompt-token details holding ``cached`` cached tokens.
    """
    return SimpleNamespace(
        choices=[
            SimpleNamespace(
                message=SimpleNamespace(content=content, tool_calls=[]),
                finish_reason="stop",
            )
        ],
        usage=SimpleNamespace(
            prompt_tokens=prompt,
            completion_tokens=completion,
            prompt_tokens_details=SimpleNamespace(
                cached_tokens=cached,
                cache_write_tokens=0,
            ),
            output_tokens_details=None,
        ),
        model="fake-model",
    )
|
||||
|
||||
|
||||
def test_run_reference_captures_usage_and_cost(monkeypatch):
    """One reference call yields its own CanonicalUsage and priced cost.

    Guards against the earlier behaviour where ``_run_reference`` dropped
    ``response.usage``, leaving the advisor fan-out out of cost tracking.
    """
    from agent.moa_loop import _RefAccounting, _run_reference
    from agent.usage_pricing import CanonicalUsage

    def fake_call_llm(**kw):
        return _response_with_usage(prompt=1000, completion=200, cached=400)

    def fake_slot_runtime(slot):
        return {"provider": "openrouter", "model": slot.get("model")}

    def fake_estimate(*a, **k):
        return SimpleNamespace(amount_usd=0.0123, status="estimated", source="table")

    # Stub the LLM call, runtime resolution and pricing so the outcome is
    # fully deterministic.
    stubs = {
        "agent.moa_loop.call_llm": fake_call_llm,
        "agent.moa_loop._slot_runtime": fake_slot_runtime,
        "agent.usage_pricing.estimate_usage_cost": fake_estimate,
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)

    slot = {"provider": "openrouter", "model": "vendor/adv-model"}
    conversation = [{"role": "user", "content": "state?"}]
    label, text, acct = _run_reference(slot, conversation)

    assert text == "advice"
    assert isinstance(acct, _RefAccounting)
    recorded = acct.usage
    assert isinstance(recorded, CanonicalUsage)
    # 1000 prompt tokens of which 400 were cached: 600 fresh input plus
    # 400 cache reads.
    assert recorded.input_tokens == 600
    assert recorded.cache_read_tokens == 400
    assert recorded.output_tokens == 200
    assert acct.cost_usd == 0.0123
|
||||
|
||||
|
||||
def test_references_parallel_sum_and_consume(monkeypatch, tmp_path):
    """Advisor usage + cost are summed once per turn and cleared on consume.

    A repeat tool-iteration in the same turn hits the reference cache and
    must add ZERO advisor spend; otherwise advisor cost would scale with the
    number of iterations.
    """
    # Preset with two advisors and one aggregator, written as the config
    # file the facade loads from HERMES_HOME.
    config_text = "\n".join(
        (
            "moa:",
            "  default_preset: review",
            "  presets:",
            "    review:",
            "      reference_models:",
            "        - provider: openrouter",
            "          model: adv-a",
            "        - provider: openrouter",
            "          model: adv-b",
            "      aggregator:",
            "        provider: openrouter",
            "        model: anthropic/claude-opus-4.8",
        )
    )
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    (hermes_home / "config.yaml").write_text(config_text, encoding="utf-8")
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    def fake_call_llm(**kwargs):
        # Only advisor calls carry usage; the aggregator gets a plain reply.
        if kwargs["task"] != "moa_reference":
            return _response("aggregator acted")
        return _response_with_usage(prompt=1000, completion=100, cached=0)

    def fake_slot_runtime(slot):
        return {"provider": "openrouter", "model": slot.get("model")}

    def fake_estimate(*a, **k):
        return SimpleNamespace(amount_usd=0.01, status="estimated", source="table")

    monkeypatch.setattr("agent.moa_loop.call_llm", fake_call_llm)
    monkeypatch.setattr("agent.moa_loop._slot_runtime", fake_slot_runtime)
    monkeypatch.setattr("agent.usage_pricing.estimate_usage_cost", fake_estimate)

    from agent.moa_loop import MoAChatCompletions

    def turn_one_messages():
        # A fresh list per call, exactly as a real caller would pass.
        return [{"role": "user", "content": "turn one"}]

    completions = MoAChatCompletions("review")
    completions.create(messages=turn_one_messages(), tools=[])

    first_usage, first_cost = completions.consume_reference_usage()
    # Two advisors, each (1000 input, 100 output): 2000 input, 200 output.
    assert first_usage.input_tokens == 2000
    assert first_usage.output_tokens == 200
    # Two advisors at $0.01 each: $0.02.
    assert first_cost == pytest.approx(0.02)

    # Consuming clears the pending values: nothing left without a new create().
    drained_usage, drained_cost = completions.consume_reference_usage()
    assert drained_usage.input_tokens == 0
    assert drained_cost is None

    # Same advisory view again is a cache HIT: advisors are not re-run, so
    # no advisor spend is pending (no double-charge).
    completions.create(messages=turn_one_messages(), tools=[])
    cached_usage, cached_cost = completions.consume_reference_usage()
    assert cached_usage.input_tokens == 0
    assert cached_cost is None
|
||||
|
||||
|
||||
def test_canonical_usage_add():
    """Adding two CanonicalUsage values sums each bucket independently."""
    from agent.usage_pricing import CanonicalUsage

    first = CanonicalUsage(input_tokens=100, output_tokens=20, cache_read_tokens=5)
    second = CanonicalUsage(input_tokens=50, output_tokens=10, cache_write_tokens=3)

    combined = first + second

    expected = {
        "input_tokens": 150,
        "output_tokens": 30,
        "cache_read_tokens": 5,
        "cache_write_tokens": 3,
        "request_count": 2,
    }
    for bucket, value in expected.items():
        assert getattr(combined, bucket) == value
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue