289 changes: 242 additions & 47 deletions src/llama-context.cpp
@@ -818,7 +818,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
float * llama_context::get_sampled_logits_ith(int32_t idx) {
output_reorder();

if (sampling.logits == nullptr) {
if (!has_sampled) {
return nullptr;
}

@@ -873,7 +873,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
size_t llama_context::get_sampled_logits_count(int32_t idx) {
output_reorder();

if (sampling.logits == nullptr) {
if (!has_sampled) {
return model.vocab.n_tokens();
}

@@ -1746,11 +1746,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
has_embd = true;
}

// Check which sampling modes are needed for the current batch.
// TODO: avoid this branching by working with the worst-case
bool has_sampling = false;
bool cpu_logits = false;

has_sampled = false;
bool cpu_logits = false;
if (batch.logits) {
for (int32_t i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
@@ -1759,7 +1756,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
llama_seq_id seq_id = batch.seq_id[i][j];
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
has_sampling = true;
has_sampled = true;
} else {
cpu_logits = true;
}
@@ -1778,21 +1775,13 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;

// TODO: avoid this branching by working with the worst-case
if (!has_sampling) {
sampling.logits_size = 0;
sampling.probs_size = 0;
sampling.sampled_size = 0;
sampling.candidates_size = 0;
} else {
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
sampling.candidates_size = n_vocab*n_outputs_max;
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
sampling.candidates_size = n_vocab*n_outputs_max;

backend_float_count = sampling.logits_size + sampling.probs_size;
backend_token_count = sampling.sampled_size + sampling.candidates_size;
}
backend_float_count = sampling.logits_size + sampling.probs_size;
backend_token_count = sampling.sampled_size + sampling.candidates_size;
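// (illustrative note, not part of the original patch: with n_vocab = 32000 and
//  n_outputs_max = 8, the backend-sampling reservation is 2*32000*8 = 512000 floats
//  for logits + probs and 8 + 32000*8 = 256008 tokens for sampled ids + candidates)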

if (output_ids.empty()) {
// init, never resized afterwards
@@ -1848,37 +1837,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
embd = has_embd ? (float *) (base + offset) : nullptr;
offset += embd_size * sizeof(float);

sampling.logits = nullptr;
sampling.probs = nullptr;
sampling.sampled = nullptr;
sampling.candidates = nullptr;

if (has_sampling) {
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);

sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);
sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);

sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);
sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);

sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);
sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);

// The count vectors keep track of the actual number of logits/probs/candidates
// copied from the backend for each output row.
// The count vectors keep track of the actual number of logits/probs/candidates
// copied from the backend for each output row.

sampling.logits_count.resize(n_outputs_max);
sampling.probs_count.resize(n_outputs_max);
sampling.candidates_count.resize(n_outputs_max);
sampling.logits_count.resize(n_outputs_max);
sampling.probs_count.resize(n_outputs_max);
sampling.candidates_count.resize(n_outputs_max);

std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);

std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
}
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);

// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1908,7 +1890,7 @@ void llama_context::output_reorder() {
}
}

if (sampling.logits && sampling.logits_size > 0) {
if (has_sampled && sampling.logits_size > 0) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
}
@@ -2501,6 +2483,100 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {

// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004
// Note: sampling.samplers map is not saved. Backend samplers are expected
// to be re-configured by the application with the same settings when loading state.

// write backend sampled logits
{
LLAMA_LOG_DEBUG("%s: - writing backend sampled logits\n", __func__);

const uint64_t logits_size = this->sampling.logits_size;
io.write(&logits_size, sizeof(logits_size));

if (logits_size) {
io.write(sampling.logits, logits_size * sizeof(float));
}
}

// write backend sampled tokens
{
LLAMA_LOG_DEBUG("%s: - writing backend sampled tokens\n", __func__);

const uint64_t sampled_size = this->sampling.sampled_size;
io.write(&sampled_size, sizeof(sampled_size));

if (sampled_size) {
io.write(sampling.sampled, sampled_size * sizeof(llama_token));
}
}

// write backend sampled probs
{
LLAMA_LOG_DEBUG("%s: - writing backend sampled probs\n", __func__);

const uint64_t probs_size = this->sampling.probs_size;
io.write(&probs_size, sizeof(probs_size));

if (probs_size) {
io.write(sampling.probs, probs_size * sizeof(float));
}
}

// write backend sampled candidates
{
LLAMA_LOG_DEBUG("%s: - writing backend sampled candidates\n", __func__);

const uint64_t candidates_size = this->sampling.candidates_size;
io.write(&candidates_size, sizeof(candidates_size));

if (candidates_size) {
io.write(sampling.candidates, candidates_size * sizeof(llama_token));
}
}

// write backend logits count
{
LLAMA_LOG_DEBUG("%s: - writing backend logits count\n", __func__);

const uint64_t count_size = this->sampling.logits_count.size();
io.write(&count_size, sizeof(count_size));

if (count_size) {
io.write(sampling.logits_count.data(), count_size * sizeof(uint32_t));
}
}

// write backend probs count
{
LLAMA_LOG_DEBUG("%s: - writing backend probs count\n", __func__);

const uint64_t count_size = this->sampling.probs_count.size();
io.write(&count_size, sizeof(count_size));

if (count_size) {
io.write(sampling.probs_count.data(), count_size * sizeof(uint32_t));
}
}

// write backend candidates count
{
LLAMA_LOG_DEBUG("%s: - writing backend candidates count\n", __func__);

const uint64_t count_size = this->sampling.candidates_count.size();
io.write(&count_size, sizeof(count_size));

if (count_size) {
io.write(sampling.candidates_count.data(), count_size * sizeof(uint32_t));
}
}

// write backend has_sampled flag
{
LLAMA_LOG_DEBUG("%s: - writing backend has_sampled flag\n", __func__);

io.write(&this->has_sampled, sizeof(bool));
}


if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
@@ -2594,6 +2670,125 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004

// read sampled logits
{
LLAMA_LOG_DEBUG("%s: - reading backend sampled logits\n", __func__);

uint64_t logits_size;
io.read_to(&logits_size, sizeof(logits_size));

if (this->sampling.logits_size < logits_size) {
throw std::runtime_error("logits buffer too small");
}

if (logits_size) {
io.read_to(this->sampling.logits, logits_size * sizeof(float));
}
}

// read sampled tokens
{
LLAMA_LOG_DEBUG("%s: - reading backend sampled tokens\n", __func__);

uint64_t sampled_size;
io.read_to(&sampled_size, sizeof(sampled_size));

if (this->sampling.sampled_size < sampled_size) {
throw std::runtime_error("sampled buffer too small");
}

if (sampled_size) {
io.read_to(this->sampling.sampled, sampled_size * sizeof(llama_token));
}
}

// read sampled probs
{
LLAMA_LOG_DEBUG("%s: - reading backend sampled probs\n", __func__);

uint64_t probs_size;
io.read_to(&probs_size, sizeof(probs_size));

if (this->sampling.probs_size < probs_size) {
throw std::runtime_error("probs buffer too small");
}

if (probs_size) {
io.read_to(this->sampling.probs, probs_size * sizeof(float));
}
}

// read sampled candidates
{
LLAMA_LOG_DEBUG("%s: - reading backend sampled candidates\n", __func__);

uint64_t candidates_size;
io.read_to(&candidates_size, sizeof(candidates_size));

if (this->sampling.candidates_size < candidates_size) {
throw std::runtime_error("candidates buffer too small");
}

if (candidates_size) {
io.read_to(this->sampling.candidates, candidates_size * sizeof(llama_token));
}
}

// read sampled logits count
{
LLAMA_LOG_DEBUG("%s: - reading backend sampled logits count\n", __func__);

uint64_t count_size;
io.read_to(&count_size, sizeof(count_size));

if (this->sampling.logits_count.size() < count_size) {
throw std::runtime_error("logits count buffer too small");
}

if (count_size) {
io.read_to(this->sampling.logits_count.data(), count_size * sizeof(uint32_t));
}
}

// read sampled probs count
{
LLAMA_LOG_DEBUG("%s: - reading backend sampled probs count\n", __func__);

uint64_t count_size;
io.read_to(&count_size, sizeof(count_size));

if (this->sampling.probs_count.size() < count_size) {
throw std::runtime_error("probs count buffer too small");
}

if (count_size) {
io.read_to(this->sampling.probs_count.data(), count_size * sizeof(uint32_t));
}
}

// read sampled candidates count
{
LLAMA_LOG_DEBUG("%s: - reading backend sampled candidates count\n", __func__);

uint64_t count_size;
io.read_to(&count_size, sizeof(count_size));

if (this->sampling.candidates_count.size() < count_size) {
throw std::runtime_error("candidates count buffer too small");
}

if (count_size) {
io.read_to(this->sampling.candidates_count.data(), count_size * sizeof(uint32_t));
}
}

// read has_sampled flag
{
LLAMA_LOG_DEBUG("%s: - reading backend has_sampled flag\n", __func__);

io.read_to(&has_sampled, sizeof(bool));
}

if (memory) {
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

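Note on usage: as the comment in state_write_data says, the sampling.samplers map itself is not serialized, so an application that restores a saved session must re-configure its backend samplers with the same settings before decoding again. The sketch below shows one way that flow could look. llama_state_load_file and the llama_sampler_chain_* / llama_sampler_init_* helpers are existing public API; the final attach call is a hypothetical placeholder, since the actual binding function comes from the backend-sampling work in PR #17004 and is not part of this diff.

#include "llama.h"

#include <vector>

// sketch: restore a saved session, then rebuild the backend sampler chain
// (the attach call at the end is a placeholder, not a real llama.cpp function)
static bool restore_session(llama_context * ctx, const char * path) {
    std::vector<llama_token> tokens(4096);
    size_t n_tokens = 0;

    // with this change, this also restores the sampling buffers and the has_sampled flag
    if (!llama_state_load_file(ctx, path, tokens.data(), tokens.size(), &n_tokens)) {
        return false;
    }

    // the samplers map is not restored: rebuild the same chain used before saving
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // hypothetical: bind the chain to sequence 0 for backend sampling
    // llama_set_backend_sampler(ctx, /*seq_id=*/0, smpl);
    (void) smpl;

    return true;
}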
1 change: 1 addition & 0 deletions src/llama-context.h
@@ -293,6 +293,7 @@ struct llama_context {
};

sampling_info sampling;
bool has_sampled = false;

// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
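For reference, the sampling state appended to the session data by this change is laid out as follows (as written by state_write_data and validated symmetrically in state_read_data); each buffer is prefixed by its element count as a uint64_t:

// serialized sampling state (sketch of the section order):
//   sampled logits      : uint64_t count, then count * sizeof(float)
//   sampled tokens      : uint64_t count, then count * sizeof(llama_token)
//   sampled probs       : uint64_t count, then count * sizeof(float)
//   sampled candidates  : uint64_t count, then count * sizeof(llama_token)
//   logits_count        : uint64_t count, then count * sizeof(uint32_t)
//   probs_count         : uint64_t count, then count * sizeof(uint32_t)
//   candidates_count    : uint64_t count, then count * sizeof(uint32_t)
//   has_sampled         : sizeof(bool)
// the read path rejects any section whose stored count exceeds the currently reserved buffer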