Commit fc877d3

allow shared prefix question and system prompt variance and calculate… (#301)
This pull request introduces several enhancements and new features to the inference performance benchmarking and reporting framework. The main focus is on supporting Service Level Objective (SLO) tracking for latency metrics (TTFT and TPOT), making prompt and output length distributions more flexible, and improving metric calculation and reporting. The changes touch data models, configuration, data generation, metric collection, and reporting.

**Key changes include:**

### SLO Tracking and Metric Enhancements

* Added new fields to `RequestLifecycleMetric` (`ttft`, `tpot`, `ttft_slo`, `tpot_slo`, `ttft_slo_met`, `tpot_slo_met`, `ntpot`) to track time-to-first-token, time-per-output-token, their SLO thresholds, and attainment status. (`inference_perf/apis/base.py`)
* Extended `APIConfig` to allow configuration of SLO thresholds and header names for TTFT and TPOT, and updated the OpenAI client to calculate these metrics and evaluate SLO attainment for each request. (`inference_perf/config.py`, `inference_perf/client/modelserver/openai_client.py`) [[1]](diffhunk://#diff-b20b7de6376037a1e80b0a93291951ae95cfa9893a3bf5fb2530c08a68304596R35-R37) [[2]](diffhunk://#diff-205d24014798b80a3f0ec5bca09dd11a20da8cf3edb8c6279aac366cc62f9313L203-R252)
* Introduced a `calculate_slo_metrics` function to aggregate SLO attainment statistics and goodput, and integrated these metrics into the summary reporting. (`inference_perf/reportgen/base.py`)

### Flexible Prompt and Output Length Distribution

* Added support for specifying standard deviation, min, and max for both question and output lengths in the `SharedPrefix` config, and updated the data generator to use these parameters for more realistic prompt and output length distributions. (`inference_perf/config.py`, `inference_perf/datagen/shared_prefix_datagen.py`)
* Ensured that prompt and user session shuffling is handled correctly to avoid ordering effects in data generation. (`inference_perf/datagen/shared_prefix_datagen.py`)

### Streaming API and Payload Improvements

* Updated `to_payload` methods for the chat and completion APIs to include `stream_options` when streaming, and fixed a parameter name for clarity in the user session completion API data. (`inference_perf/apis/chat.py`, `inference_perf/apis/completion.py`, `inference_perf/apis/user_session.py`)

### Test Updates

* Updated streaming API tests to account for the new `stream_options` field in the payload. (`tests/apis/test_completion.py`)

Example added in the stage_x_lifecycle_metric.json:

    "slo_metrics": {
        "ttft_slo": {
            "attainment_pct": 83,
            "requests_met": 166,
            "requests_failed": 34,
            "total_requests": 200,
            "slo": 2
        },
        "tpot_slo": {
            "attainment_pct": 100,
            "requests_met": 200,
            "requests_failed": 0,
            "total_requests": 200,
            "slo": 0.2
        },
        "combined_slo": {
            "attainment_pct": 83,
            "requests_met": 166,
            "requests_failed": 34,
            "total_requests": 200,
            "ttft_slo": 2,
            "tpot_slo": 0.2,
            "goodput_rate": 23397.1484983487
        }
    }
1 parent e3e690b commit fc877d3
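The commit message above references `calculate_slo_metrics` in `inference_perf/reportgen/base.py`, whose body is not shown on this page. The sketch below is a minimal standalone illustration of how the `slo_metrics` block in the example could be aggregated; the `RequestSample` container and the goodput definition (output tokens from requests meeting both SLOs, divided by stage duration) are assumptions for the sketch, not the actual implementation.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RequestSample:
    """Per-request values as a reporting stage might see them (illustrative container)."""
    ttft: float                    # measured time to first token, seconds
    tpot: float                    # measured time per output token, seconds
    output_tokens: int
    ttft_slo_sec: Optional[float]  # threshold recorded on the lifecycle metric, if any
    tpot_slo_sec: Optional[float]


def summarize_slo(samples: List[RequestSample], duration_sec: float) -> dict:
    """Aggregate SLO attainment roughly the way the stage report's 'slo_metrics' block suggests."""
    def met_ttft(s: RequestSample) -> bool:
        return s.ttft_slo_sec is None or s.ttft <= s.ttft_slo_sec

    def met_tpot(s: RequestSample) -> bool:
        return s.tpot_slo_sec is None or s.tpot <= s.tpot_slo_sec

    def attainment(met: int, total: int) -> dict:
        return {
            "attainment_pct": round(100 * met / total) if total else 0,
            "requests_met": met,
            "requests_failed": total - met,
            "total_requests": total,
        }

    both = [s for s in samples if met_ttft(s) and met_tpot(s)]
    # Assumed goodput definition: output tokens from SLO-compliant requests per second of stage time.
    goodput = sum(s.output_tokens for s in both) / duration_sec if duration_sec else 0.0

    return {
        "ttft_slo": attainment(sum(met_ttft(s) for s in samples), len(samples)),
        "tpot_slo": attainment(sum(met_tpot(s) for s in samples), len(samples)),
        "combined_slo": {**attainment(len(both), len(samples)), "goodput_rate": goodput},
    }
```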

11 files changed

Lines changed: 537 additions & 79 deletions


docs/config.md

Lines changed: 25 additions & 10 deletions
@@ -22,15 +22,20 @@ This document provides complete documentation for all configuration options avai
 
 ### API Configuration
 
-Controls the API interaction behavior:
+Controls the API interaction behavior. If SLO headers are present, each request is evaluated for SLO compliance and SLO-related metrics are reported:
 
 ```yaml
 api:
-  type: completion # API type (completion|chat) (default: completion), completion is the default since the chat API is not typically enabled on model servers such as vLLM by default without additional configuration.
-  streaming: false # Enable/disable streaming (default: false), needs to be enabled for metrics like TTFT, ITL and TPOT to be measured
-  headers: # Add custom http headers to the request sent to the inference server
+  type: completion # API type (completion|chat). completion is default since chat may require extra server config
+  streaming: true # Enable streaming for TTFT, ITL, and TPOT metrics
+  headers: # Optional custom HTTP headers
     x-inference-model: llama
     x-routing-strategy: round-robin
+    x-slo-tpot-ms: "2"
+    x-slo-ttft-ms: "1000"
+  slo_unit: "ms" # Optional SLO unit (e.g., ms, s), default is ms
+  slo_tpot_header: "x-slo-tpot-ms" # Optional header name for TPOT SLO Header, default is x-slo-tpot-ms
+  slo_ttft_header: "x-slo-ttft-ms" # Optional header name for TTFT SLO Header, default is x-slo-ttft-ms
 ```
 
 ### Data Generation
@@ -53,12 +58,22 @@ data:
     mean: 50
     std_dev: 10
     total_count: 100
-  shared_prefix:
-    num_unique_system_prompts: 10 # Number of distinct shared prefixes (formerly num_groups)
-    num_users_per_system_prompt: 10 # Number of unique questions per shared prefix (formerly num_prompts_per_group)
-    system_prompt_len: 100 # Length of the shared prefix (in tokens)
-    question_len: 50 # Length of the unique question part (in tokens)
-    output_len: 50 # Target length for the model's generated output (in tokens)
+  shared_prefix: # For shared_prefix type
+    num_groups: 10 # Number of shared prefix groups
+    num_prompts_per_group: 10 # Unique questions per group
+    system_prompt_len: 100 # Shared prefix length (tokens)
+    question_len: 50 # Default question length (tokens), used when question_distribution is absent
+    output_len: 50 # Default output length (tokens), used when output_distribution is absent
+    question_distribution: # Optional: distribution for question lengths (overrides question_len)
+      min: 10
+      max: 1024
+      mean: 50
+      std_dev: 5
+    output_distribution: # Optional: distribution for output lengths (overrides output_len)
+      min: 10
+      max: 1024
+      mean: 50
+      std_dev: 5
 ```
 
 ### Load Configuration
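As a quick check of the documented `api` block, the new fields map one-to-one onto `APIConfig`; the snippet below simply constructs the model with the values from the example (it assumes the `inference_perf` package is importable and uses only the fields added in this commit).

```python
from inference_perf.config import APIConfig

# Values taken from the YAML example above; `type` is omitted and defaults to completion.
api_cfg = APIConfig(
    streaming=True,
    headers={
        "x-inference-model": "llama",
        "x-routing-strategy": "round-robin",
        "x-slo-tpot-ms": "2",
        "x-slo-ttft-ms": "1000",
    },
    slo_unit="ms",
    slo_tpot_header="x-slo-tpot-ms",
    slo_ttft_header="x-slo-ttft-ms",
)
print(api_cfg)
```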

inference_perf/apis/base.py

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,10 @@ class RequestLifecycleMetric(BaseModel):
     info: InferenceInfo
     error: Optional[ErrorResponseInfo]
 
+    ttft_slo_sec: Optional[float] = None
+    tpot_slo_sec: Optional[float] = None
+
+
 
 class InferenceAPIData(BaseModel):
     # loadgen should assign this request to prefered worker if possible
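The two new fields only carry thresholds (converted to seconds by the client); attainment is decided downstream by comparing them against the measured TTFT and TPOT. The helper below is a hypothetical illustration of that comparison, assuming TTFT is the first token timestamp minus request start and TPOT is the mean inter-token gap after the first token; it is not code from this commit.

```python
from typing import List, Optional


def evaluate_request(
    request_start: float,
    output_token_times: List[float],
    ttft_slo_sec: Optional[float],
    tpot_slo_sec: Optional[float],
) -> dict:
    """Derive TTFT/TPOT from per-token timestamps and compare them against SLO thresholds."""
    ttft = output_token_times[0] - request_start
    # Mean gap over the remaining tokens; guard against single-token responses.
    if len(output_token_times) > 1:
        tpot = (output_token_times[-1] - output_token_times[0]) / (len(output_token_times) - 1)
    else:
        tpot = 0.0
    return {
        "ttft": ttft,
        "tpot": tpot,
        "ttft_slo_met": ttft_slo_sec is None or ttft <= ttft_slo_sec,
        "tpot_slo_met": tpot_slo_sec is None or tpot <= tpot_slo_sec,
    }


print(evaluate_request(0.0, [0.8, 0.85, 0.9, 0.95], ttft_slo_sec=1.0, tpot_slo_sec=0.2))
# TTFT 0.8 s and TPOT ~0.05 s, so both SLOs are met in this made-up case.
```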

inference_perf/apis/chat.py

Lines changed: 8 additions & 6 deletions
@@ -44,12 +44,14 @@ async def to_payload(
         if self.max_tokens == 0:
             self.max_tokens = max_tokens
         return {
-            "model": effective_model_name,
-            "messages": [{"role": m.role, "content": m.content} for m in self.messages],
-            "max_tokens": self.max_tokens,
-            "ignore_eos": ignore_eos,
-            "stream": streaming,
-        }
+            "model": effective_model_name,
+            "messages": [{"role": m.role, "content": m.content} for m in self.messages],
+            "max_tokens": self.max_tokens,
+            "ignore_eos": ignore_eos,
+            "stream": streaming,
+            **({"stream_options": {"include_usage": "true"}} if streaming else {}),
+        }
+
 
     async def process_response(
         self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer, lora_adapter: Optional[str] = None

inference_perf/apis/completion.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ async def to_payload(
             "max_tokens": self.max_tokens,
             "ignore_eos": ignore_eos,
             "stream": streaming,
+            **({"stream_options": {"include_usage": "true"}} if streaming else {}),
         }
 
     async def process_response(
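For reference, this is the payload shape the completion `to_payload` now produces (the chat variant differs only in sending `messages` instead of `prompt`); the model and prompt values are made up, and `stream_options` is included only when streaming is enabled.

```python
streaming = True
payload = {
    "model": "llama",
    "prompt": "What is an SLO?",
    "max_tokens": 50,
    "ignore_eos": True,
    "stream": streaming,
    # Only present when streaming; asks the server to attach usage (token counts) to the stream.
    **({"stream_options": {"include_usage": "true"}} if streaming else {}),
}
print(payload)  # with streaming=False the "stream_options" key is omitted entirely
```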

inference_perf/apis/user_session.py

Lines changed: 2 additions & 2 deletions
@@ -47,14 +47,14 @@ class UserSessionCompletionAPIData(CompletionAPIData):
     user_session: LocalUserSession = Field(exclude=True)
     target_round: int
 
-    async def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
+    async def to_payload(self, effective_model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
         self._session_context = await self.user_session.get_context(self.target_round)
         # TODO: Currently, only prompt style (concat messages) support. Adding support for messages style payload.
         self.prompt = self._session_context + " " + self.prompt
         # TODO: The combined prompt (session context + current prompt) might exceed the model's
         # maximum sequence length. Implement truncation logic/strategy to prevent
         # errors/failures from the inference server.
-        return await super().to_payload(model_name, max_tokens, ignore_eos, streaming)
+        return await super().to_payload(effective_model_name, max_tokens, ignore_eos, streaming)
 
     def update_inference_info(self, inference_info: InferenceInfo) -> None:
         inference_info.extra_info["user_session"] = self.user_session.user_session_id

inference_perf/client/modelserver/openai_client.py

Lines changed: 30 additions & 6 deletions
@@ -28,8 +28,7 @@
 import ssl
 
 logger = logging.getLogger(__name__)
-
-
+
 class openAIModelServerClient(ModelServerClient):
     _session: "openAIModelServerClientSession | None" = None
     _session_lock = asyncio.Lock()
@@ -189,8 +188,7 @@ async def process_request(
                     error_type=f"{response.status} {response.reason}",
                 )
 
-            self.client.metrics_collector.record_metric(
-                RequestLifecycleMetric(
+            metric = RequestLifecycleMetric(
                     stage_id=stage_id,
                     request_data=request_data,
                     response_data=response_content,
@@ -199,8 +197,34 @@ async def process_request(
                     start_time=start,
                     end_time=end_time,
                     scheduled_time=scheduled_time,
-                )
-            )
+            )
+
+            # Grab TTFT and TPOT thresholds from request headers if available for streaming requests with token-level timestamps
+            if response_info.output_token_times:
+                ttft_threshold = None
+                tpot_threshold = None
+                slo_unit = getattr(self.client.api_config, "slo_unit", None) or "ms"
+
+                default_ttft_header = f"x-slo-ttft-{slo_unit}"
+                default_tpot_header = f"x-slo-tpot-{slo_unit}"
+                ttft_header = getattr(self.client.api_config, "slo_ttft_header", None) or default_ttft_header
+                tpot_header = getattr(self.client.api_config, "slo_tpot_header", None) or default_tpot_header
+                if self.client.api_config.headers:
+                    ttft_threshold = self.client.api_config.headers.get(ttft_header)
+                    tpot_threshold = self.client.api_config.headers.get(tpot_header)
+
+                unit = slo_unit.lower()
+                unit_to_s = {"s": 1.0, "ms": 0.001, "us": 0.000001}
+                factor = unit_to_s.get(unit, 1.0)
+
+                if ttft_threshold is not None:
+                    metric.ttft_slo_sec = float(ttft_threshold) * factor
+
+                if tpot_threshold is not None:
+                    metric.tpot_slo_sec = float(tpot_threshold) * factor
+            # Record the metric
+            self.client.metrics_collector.record_metric(metric)
+
         except Exception as e:
             if isinstance(e, asyncio.exceptions.TimeoutError):
                 logger.error("request timed out:", exc_info=True)
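The threshold handling above can be read as two steps: resolve the header names (explicit `slo_ttft_header`/`slo_tpot_header`, else `x-slo-ttft-<unit>`/`x-slo-tpot-<unit>`), then convert the header values to seconds. The standalone function below restates that logic for illustration only; in the commit it lives inline in `process_request`.

```python
from typing import Optional, Tuple


def slo_thresholds_sec(
    headers: Optional[dict],
    slo_unit: Optional[str] = None,
    slo_ttft_header: Optional[str] = None,
    slo_tpot_header: Optional[str] = None,
) -> Tuple[Optional[float], Optional[float]]:
    """Resolve SLO header names and convert their values into seconds."""
    unit = slo_unit or "ms"
    ttft_header = slo_ttft_header or f"x-slo-ttft-{unit}"
    tpot_header = slo_tpot_header or f"x-slo-tpot-{unit}"
    factor = {"s": 1.0, "ms": 0.001, "us": 0.000001}.get(unit.lower(), 1.0)
    ttft_raw = headers.get(ttft_header) if headers else None
    tpot_raw = headers.get(tpot_header) if headers else None
    return (
        float(ttft_raw) * factor if ttft_raw is not None else None,
        float(tpot_raw) * factor if tpot_raw is not None else None,
    )


# With the headers from docs/config.md: 1000 ms TTFT -> 1.0 s, 2 ms TPOT -> 0.002 s.
print(slo_thresholds_sec({"x-slo-ttft-ms": "1000", "x-slo-tpot-ms": "2"}))
```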

inference_perf/config.py

Lines changed: 5 additions & 1 deletion
@@ -32,6 +32,9 @@ class APIConfig(BaseModel):
     type: APIType = APIType.Completion
     streaming: bool = False
     headers: Optional[dict[str, str]] = None
+    slo_unit: Optional[str] = None
+    slo_tpot_header: Optional[str] = None
+    slo_ttft_header: Optional[str] = None
 
 
 class TraceFormat(Enum):
@@ -82,6 +85,8 @@ class SharedPrefix(BaseModel):
     system_prompt_len: int = 100
     question_len: int = 50
     output_len: int = 50
+    question_distribution: Optional[Distribution] = None
+    output_distribution: Optional[Distribution] = None
     enable_multi_turn_chat: bool = False
 
 
@@ -99,7 +104,6 @@ class DataConfig(BaseModel):
     # Trace file is only supported for random dataset at this moment
     trace: Optional[TraceConfig] = None
 
-
 class ModelServerType(Enum):
     VLLM = "vllm"
     SGLANG = "sglang"
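`Distribution` is referenced by the new `SharedPrefix` fields but its definition sits elsewhere in `inference_perf/config.py` and is not part of this diff. A minimal pydantic sketch with the fields the YAML example and the data generator rely on is shown below; the real model may carry defaults or validators.

```python
from pydantic import BaseModel


class Distribution(BaseModel):
    # Field names inferred from the docs/config.md example and the datagen call sites;
    # this is a sketch, not the definition shipped in inference_perf/config.py.
    min: int
    max: int
    mean: float
    std_dev: float


# Fallback the data generator builds when no distribution block is configured:
q_len = 50
question_dist = Distribution(min=q_len, max=q_len, mean=q_len, std_dev=0)
print(question_dist)
```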

inference_perf/datagen/shared_prefix_datagen.py

Lines changed: 64 additions & 9 deletions
@@ -1,11 +1,12 @@
 import random
 from typing import Generator, List, Optional
+from inference_perf.utils.distribution import generate_distribution
 import numpy as np
 
 from inference_perf.apis.base import InferenceAPIData, LazyLoadInferenceAPIData
 from inference_perf.apis.completion import CompletionAPIData
 from inference_perf.apis.user_session import LocalUserSession, UserSessionCompletionAPIData
-from inference_perf.config import APIConfig, APIType, DataConfig
+from inference_perf.config import APIConfig, APIType, DataConfig, Distribution
 from inference_perf.utils.custom_tokenizer import CustomTokenizer
 from .base import DataGenerator, LazyLoadDataMixin
 
@@ -42,12 +43,43 @@ def __init__(self, api_config: APIConfig, config: DataConfig, tokenizer: Optiona
         self.num_groups: int = self.shared_prefix.num_groups
         self.num_prompts_per_group: int = self.shared_prefix.num_prompts_per_group
         self.system_prompt_len: int = self.shared_prefix.system_prompt_len
-        self.question_len: int = self.shared_prefix.question_len
-        self.output_len: int = self.shared_prefix.output_len
         self.enable_multi_turn_chat: bool = self.shared_prefix.enable_multi_turn_chat
+
+        # Use distribution configs, or fall back to question_len/output_len with std_dev=0
+        q_len = self.shared_prefix.question_len
+        o_len = self.shared_prefix.output_len
+        question_dist = self.shared_prefix.question_distribution or Distribution(min=q_len, max=q_len, mean=q_len, std_dev=0)
+        output_dist = self.shared_prefix.output_distribution or Distribution(min=o_len, max=o_len, mean=o_len, std_dev=0)
+
+        # Generate separate distributions for each group
+        self.question_len_list_per_group: List[List[int]] = []
+        self.output_len_list_per_group: List[List[int]] = []
+
+        for _ in range(self.num_groups):
+            question_lens = generate_distribution(
+                question_dist.min,
+                question_dist.max,
+                question_dist.mean,
+                question_dist.std_dev,
+                self.shared_prefix.num_prompts_per_group,
+            )
+            self.question_len_list_per_group.append(question_lens.tolist())
+
+            output_lens = generate_distribution(
+                output_dist.min,
+                output_dist.max,
+                output_dist.mean,
+                output_dist.std_dev,
+                self.shared_prefix.num_prompts_per_group,
+            )
+            self.output_len_list_per_group.append(output_lens.tolist())
+
+
+
 
         self.prompts: List[str] = []
         self.user_sessions: List[LocalUserSession] = []
+        self.flat_output_lens: List[int] = []
         self._generate_prompts()
 
     def get_supported_apis(self) -> List[APIType]:
@@ -64,17 +96,19 @@ def is_prefered_worker_requested(self) -> bool:
 
     def load_lazy_data(self, data: LazyLoadInferenceAPIData) -> InferenceAPIData:
         i = data.data_index % len(self.prompts)
+        output_len = self.flat_output_lens[i]
+
         if self.enable_multi_turn_chat:
             user_id = data.data_index % len(self.user_sessions)
             round = data.data_index // len(self.user_sessions)
             return UserSessionCompletionAPIData(
                 prompt=self.prompts[i],
-                max_tokens=self.output_len,
+                max_tokens=output_len,
                 user_session=self.user_sessions[user_id],
                 target_round=round,
             )
         else:
-            return CompletionAPIData(prompt=self.prompts[i], max_tokens=self.output_len)
+            return CompletionAPIData(prompt=self.prompts[i], max_tokens=output_len)
 
     def get_data(self) -> Generator[InferenceAPIData, None, None]:
         if not self.prompts:
@@ -99,17 +133,27 @@ def _generate_prompts(self) -> None:
             # This check is defensive; __init__ should have already validated this.
             raise ValueError("Tokenizer is not available for generating prompts.")
 
+        if self.shared_prefix is None:
+            raise ValueError("Shared prefix is not available for generating prompts.")
+
         hf_tokenizer = self.tokenizer.get_tokenizer()
 
         for group_id in range(self.num_groups):
            # Generate a shared prefix (system prompt)
            shared_prefix_token_ids = self._generate_random_token_ids(self.system_prompt_len)
            shared_prefix_text = hf_tokenizer.decode(shared_prefix_token_ids, skip_special_tokens=True)
 
+            # Batch generate all question token IDs for this group
+            all_question_token_ids = [
+                self._generate_random_token_ids(self.question_len_list_per_group[group_id][prompt_id])
+                for prompt_id in range(self.num_prompts_per_group)
+            ]
+
+            # Batch decode all questions at once (much faster than individual decode calls)
+            all_question_texts = hf_tokenizer.batch_decode(all_question_token_ids, skip_special_tokens=True)
+
             for prompt_id in range(self.num_prompts_per_group):
-                # Generate a unique question
-                question_token_ids = self._generate_random_token_ids(self.question_len)
-                question_text = hf_tokenizer.decode(question_token_ids, skip_special_tokens=True)
+                question_text = all_question_texts[prompt_id]
 
                 if self.enable_multi_turn_chat:
                     # multi turn chat, create user to keep conversation
@@ -125,9 +169,20 @@ def _generate_prompts(self) -> None:
 
                 self.prompts.append(question_text)
 
+        # Flatten output lengths to match prompts ordering
+        self.flat_output_lens = [
+            self.output_len_list_per_group[g][p]
+            for g in range(self.num_groups)
+            for p in range(self.num_prompts_per_group)
+        ]
+
         # Shuffle the generated prompts to ensure randomness if served sequentially by different workers
         if self.enable_multi_turn_chat:
             # no need to sync shuffles - multi-round initial prompt does not include system prompt
             random.shuffle(self.user_sessions)
         else:
-            random.shuffle(self.prompts)
+            # Shuffle prompts and output lengths in sync
+            combined = list(zip(self.prompts, self.flat_output_lens, strict=True))
+            random.shuffle(combined)
+            self.prompts, self.flat_output_lens = [list(t) for t in zip(*combined, strict=True)]
+
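`generate_distribution` comes from `inference_perf.utils.distribution` and is not part of this commit. Judging from the call sites (`min`, `max`, `mean`, `std_dev`, count) and the `.tolist()` calls on its result, it plausibly samples lengths from a normal distribution clipped to `[min, max]` and returns a NumPy integer array; the stand-in below assumes exactly that and is not the library function.

```python
import numpy as np


def generate_distribution(min_len: int, max_len: int, mean: float, std_dev: float, count: int) -> np.ndarray:
    """Stand-in with assumed behavior: draw `count` lengths from N(mean, std_dev),
    round them, and clip into [min_len, max_len]."""
    if std_dev == 0:
        return np.full(count, int(mean), dtype=int)
    samples = np.random.normal(loc=mean, scale=std_dev, size=count)
    return np.clip(np.rint(samples), min_len, max_len).astype(int)


# Example: ten question lengths matching the question_distribution block from docs/config.md.
print(generate_distribution(10, 1024, 50, 5, count=10).tolist())
```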
