@@ -675,37 +675,82 @@ def reset(self):
675675 if self .batch is not None :
676676 self .batch .n_tokens = 0
677677
678- def set_batch (self , batch : Sequence [int ], n_past : llama_cpp .llama_pos , logits_all : bool ):
679- if len (batch ) > self .n_tokens_capacity :
680- raise IndexError (f"Input batch size { len (batch )} exceeds capacity { self .n_tokens_capacity } " )
def add_token(self, token: int, pos: int, seq_ids: Sequence[int], logits: bool):
    """
    Append a single token to the batch.

    Intended for the generation loop, where one token is added at a time and
    building the per-token lists required by add_sequence would be wasteful.

    All validation happens before anything is written, so a rejected call
    leaves the batch buffers and ``n_tokens`` untouched.

    Args:
        token: The integer ID of the token to add.
        pos: The logical sequence position (n_past) of this token.
        seq_ids: The sequence IDs this token belongs to (e.g. [0] for a
            standard single chat). One token may belong to several sequences.
        logits: Whether logits should be computed for this token.

    Raises:
        IndexError: If the batch already holds ``n_tokens_capacity`` tokens.
        ValueError: If ``seq_ids`` has more entries than ``n_seq_max``.
    """
    batch = self.batch
    idx = batch.n_tokens
    if idx >= self.n_tokens_capacity:
        raise IndexError(
            f"LlamaBatch overflow[add_token]: Cannot add token. "
            f"Capacity {self.n_tokens_capacity} reached."
        )

    # Check the sequence count up front: writing token/pos first would leave
    # stale data in the slot when this check fails.
    n_seq_id = len(seq_ids)
    if n_seq_id > self.n_seq_max:
        raise ValueError(
            f"LlamaBatch Error[add_token]: Token belongs to {n_seq_id} sequences, "
            f"but n_seq_max was initialized to {self.n_seq_max}."
        )

    batch.token[idx] = token
    batch.pos[idx] = pos
    batch.n_seq_id[idx] = n_seq_id
    slot_seq_ids = batch.seq_id[idx]
    for i, seq_id in enumerate(seq_ids):
        slot_seq_ids[i] = seq_id
    batch.logits[idx] = logits

    # Publish the new token only after every field of the slot is filled in.
    batch.n_tokens = idx + 1
709+
def add_sequence(
    self,
    token_array: Sequence[int],
    pos_array: Sequence[int],
    seq_ids: Sequence[Sequence[int]],
    logits_array: Sequence[bool],
):
    """
    Append a run of tokens to the batch.

    The four arguments are parallel arrays: entry ``i`` of each describes
    token ``i``. Values are copied into the batch as given, with no implicit
    overriding (for example, the last token's logits flag is not forced on).

    All validation happens before anything is written, so a rejected call
    leaves the batch buffers and ``n_tokens`` untouched.

    Args:
        token_array: Token IDs to append.
        pos_array: Logical position of each token.
        seq_ids: For each token, the sequence IDs it belongs to
            (e.g. [[0], [0], [0]] for 3 tokens belonging to sequence 0).
        logits_array: For each token, whether logits should be computed.

    Raises:
        IndexError: If the tokens do not fit in the remaining capacity.
        ValueError: If the parallel arrays differ in length, or if any token
            lists more sequence IDs than ``n_seq_max``.
    """
    n_tokens = len(token_array)
    batch = self.batch
    current_count = batch.n_tokens

    if current_count + n_tokens > self.n_tokens_capacity:
        raise IndexError(
            f"LlamaBatch overflow[add_sequence]: Cannot add {n_tokens} tokens. "
            f"Space left: {self.n_tokens_capacity - current_count}"
        )

    if not (len(pos_array) == len(seq_ids) == len(logits_array) == n_tokens):
        raise ValueError(
            f"LlamaBatch Error[add_sequence]: Mismatched input lengths: "
            f"{n_tokens} tokens, {len(pos_array)} positions, "
            f"{len(seq_ids)} seq_id lists, {len(logits_array)} logits flags."
        )

    # The sequence count is a per-token property: check each token's own
    # list, not the length of the outer container (which is the token count).
    for token_seq_ids in seq_ids:
        n_seq_id = len(token_seq_ids)
        if n_seq_id > self.n_seq_max:
            raise ValueError(
                f"LlamaBatch Error[add_sequence]: Token belongs to {n_seq_id} sequences, "
                f"but n_seq_max was initialized to {self.n_seq_max}."
            )

    rows = zip(token_array, pos_array, seq_ids, logits_array)
    for j, (token, pos, token_seq_ids, logits) in enumerate(rows, start=current_count):
        batch.token[j] = token
        batch.pos[j] = pos

        batch.n_seq_id[j] = len(token_seq_ids)
        slot_seq_ids = batch.seq_id[j]
        for k, seq_id in enumerate(token_seq_ids):
            slot_seq_ids[k] = seq_id

        batch.logits[j] = logits

    # Publish the new tokens only after every slot is fully written.
    batch.n_tokens = current_count + n_tokens
709754
710755
711756# Embedding functions
0 commit comments