
Commit 9262fc0

refactor(LlamaBatch): replace set_batch with granular add_token + vectorized add_sequence
- Introduce high-performance add_token() for single-token append in generation loop
- Add flexible add_sequence() with per-token pos/seq_ids/logits arrays
- Remove old set_batch() that assumed single-seq + forced last logit
- Better support for multi-sequence and precise logit control
1 parent 781790f commit 9262fc0

3 files changed: 82 additions & 25 deletions
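
At a glance, call sites migrate as follows (a hedged sketch, not part of the diff; `batch` is an initialized LlamaBatch and `tokens` a prompt's token IDs):

    # Before: single sequence assumed, last logit forced on
    # batch.set_batch(tokens, n_past=0, logits_all=False)

    # After: positions, sequence membership, and logit flags are explicit
    batch.add_sequence(
        token_array=tokens,
        pos_array=list(range(len(tokens))),
        seq_ids=[0],
        logits_array=[False] * (len(tokens) - 1) + [True],  # logits for the last token only
    )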


llama_cpp/_internals.py (68 additions & 23 deletions)
@@ -675,37 +675,82 @@ def reset(self):
         if self.batch is not None:
             self.batch.n_tokens = 0
 
-    def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_all: bool):
-        if len(batch) > self.n_tokens_capacity:
-            raise IndexError(f"Input batch size {len(batch)} exceeds capacity {self.n_tokens_capacity}")
+    def add_token(self, token: int, pos: int, seq_ids: Sequence[int], logits: bool):
+        """
+        Adds a single token to the batch.
+        This is a high-performance method for appending a single token during the
+        generation loop, avoiding the overhead of creating the temporary lists that
+        add_sequence requires.
 
-        n_tokens = len(batch)
-        self.batch.n_tokens = n_tokens
-        for i in range(n_tokens):
-            self.batch.token[i] = batch[i]
-            self.batch.pos[i] = n_past + i
-            self.batch.seq_id[i][0] = 0
-            self.batch.n_seq_id[i] = 1
-            self.batch.logits[i] = logits_all
-        self.batch.logits[n_tokens - 1] = True
-
-    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
-        n_tokens = len(batch)
+        Args:
+            token: The integer ID of the token to add.
+            pos: The logical sequence position (n_past) of this token.
+            seq_ids: A sequence of sequence IDs this token belongs to (e.g., [0] for a
+                standard single chat). A single token can be part of multiple sequences
+                simultaneously.
+            logits: A boolean flag indicating whether the backend should compute logits
+                for this token.
+        """
+        idx = self.batch.n_tokens
+        if idx >= self.n_tokens_capacity:
+            raise IndexError(f"LlamaBatch overflow[add_token]: Cannot add token. Capacity {self.n_tokens_capacity} reached.")
+
+        self.batch.token[idx] = token
+        self.batch.pos[idx] = pos
+
+        n_seq_id = len(seq_ids)
+        if n_seq_id > self.n_seq_max:
+            raise ValueError(f"LlamaBatch Error[add_token]: Token belongs to {n_seq_id} sequences, "
+                             f"but n_seq_max was initialized to {self.n_seq_max}.")
+        self.batch.n_seq_id[idx] = n_seq_id
+
+        for i, seq_id in enumerate(seq_ids):
+            self.batch.seq_id[idx][i] = seq_id
+        self.batch.logits[idx] = logits
+
+        self.batch.n_tokens += 1
+
+    def add_sequence(
+        self,
+        token_array: Sequence[int],
+        pos_array: Sequence[int],
+        seq_ids: Sequence[int],
+        logits_array: Sequence[bool]
+    ):
+        """
+        Adds a sequence of tokens to the batch in a vectorized manner.
+        Maps the provided arrays directly onto the underlying llama_batch structure,
+        overriding nothing (unlike the removed set_batch, which forced the last
+        token's logits flag to True).
+
+        Args:
+            token_array: A sequence of token IDs to be evaluated.
+            pos_array: A sequence of logical positions corresponding to each token.
+            seq_ids: The sequence IDs that every token in this call belongs to
+                (e.g., [0] for tokens belonging only to sequence 0).
+            logits_array: A sequence of boolean flags indicating whether to compute
+                logits for each token.
+        """
+        n_tokens = len(token_array)
         current_count = self.batch.n_tokens
+
         if current_count + n_tokens > self.n_tokens_capacity:
             raise IndexError(
-                f"LlamaBatch overflow: Cannot add {n_tokens} tokens. "
+                f"LlamaBatch overflow[add_sequence]: Cannot add {n_tokens} tokens. "
                 f"Space left: {self.n_tokens_capacity - current_count}"
             )
-        self.batch.n_tokens += n_tokens
+
+        n_seq_id = len(seq_ids)
+        if n_seq_id > self.n_seq_max:
+            raise ValueError(f"LlamaBatch Error[add_sequence]: Token belongs to {n_seq_id} sequences, "
+                             f"but n_seq_max was initialized to {self.n_seq_max}.")
+
         for i in range(n_tokens):
             j = current_count + i
-            self.batch.token[j] = batch[i]
-            self.batch.pos[j] = i
-            self.batch.seq_id[j][0] = seq_id
-            self.batch.n_seq_id[j] = 1
-            self.batch.logits[j] = logits_all
-        self.batch.logits[current_count + n_tokens - 1] = True
+            self.batch.token[j] = token_array[i]
+            self.batch.pos[j] = pos_array[i]
+
+            self.batch.n_seq_id[j] = n_seq_id
+            for k, seq_id in enumerate(seq_ids):
+                self.batch.seq_id[j][k] = seq_id
+
+            self.batch.logits[j] = logits_array[i]
+
+        self.batch.n_tokens += n_tokens
 
 
     # Embedding functions
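
In the decode loop, add_token avoids per-step list construction. A minimal usage sketch (hedged: `context`, `prompt_tokens`, `max_new_tokens`, and the `sample()` helper are illustrative names, not part of this diff):

    n_past = len(prompt_tokens)
    token = last_sampled_token          # assumed: last token sampled after prompt eval
    for _ in range(max_new_tokens):
        batch.reset()
        # One token per step; request logits so the next token can be sampled.
        batch.add_token(token, pos=n_past, seq_ids=[0], logits=True)
        context.decode(batch)
        token = sample(context)         # hypothetical sampling helper
        n_past += 1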

llama_cpp/llama_embedding.py (13 additions & 1 deletion)
@@ -251,8 +251,20 @@ def _decode_batch():
             _decode_batch()
             idx_in_batch = 0
 
+            pos_array = list(range(n_tokens))
+
+            if is_none:
+                logits_array = [True] * n_tokens
+            else:
+                logits_array = [False] * (n_tokens - 1) + [True]
+
             # Add to Batch
-            self._batch.add_sequence(tokens, idx_in_batch, logits_all=logits_all)
+            self._batch.add_sequence(
+                token_array=tokens,
+                pos_array=pos_array,
+                seq_ids=[idx_in_batch],
+                logits_array=logits_array
+            )
             batch_seq_lens.append(n_tokens)
             idx_in_batch += 1
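
The per-token flags matter for pooling: with no pooling (the `is_none` branch above) every token's output is kept, otherwise only the last token's. A hedged sketch of batching two inputs under distinct sequence IDs (`tokens_a`/`tokens_b` are illustrative):

    for seq, tokens in enumerate([tokens_a, tokens_b]):
        batch.add_sequence(
            token_array=tokens,
            pos_array=list(range(len(tokens))),
            seq_ids=[seq],                       # each input gets its own sequence ID
            logits_array=[True] * len(tokens),   # no pooling: keep outputs for every token
        )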

tests/test_llama.py (1 addition & 1 deletion)
@@ -124,7 +124,7 @@ def test_real_model(llama_cpp_model_path):
 
     for _ in range(4):
         # Prepare batch with current tokens
-        batch.set_batch(curr_tokens, n_past=n_eval, logits_all=False)
+        batch.add_token(curr_tokens[0], pos=n_eval, seq_ids=[0], logits=True)
 
         # Decode (run inference)
         context.decode(batch)
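
Note that add_token takes a single token ID, and set_batch always computed the final token's logits, hence the indexing and logits=True above. If `curr_tokens` ever carries more than one token (e.g., an initial prompt), the granular API covers that with a loop; a hedged sketch mirroring the old set_batch semantics:

    for i, tok in enumerate(curr_tokens):
        # Only the final token needs logits, matching set_batch's forced last logit.
        batch.add_token(tok, pos=n_eval + i, seq_ids=[0], logits=(i == len(curr_tokens) - 1))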
