@@ -446,7 +446,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
             llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
         }
 
-        const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+        const llama_token id = llama_sample_token_binding(ctx, ctx_guidance, grammar, params_p, last_tokens, candidates);
+        // const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
 
         last_tokens.erase(last_tokens.begin());
         last_tokens.push_back(id);
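
The only functional change in this hunk is the call site: the stock llama_sample_token takes gpt_params by value, which is awkward to call from the binding, where the params live behind a void*. A hedged sketch of the call-site difference; params_p here is assumed to be the gpt_params* that llama_predict recovers from params_ptr:

    // binding side: params arrive as an opaque pointer
    gpt_params* params_p = (gpt_params*) params_ptr;

    // stock helper: wants the struct itself, forcing a copy at every call site
    // const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, *params_p, last_tokens, candidates);

    // binding helper: the pointer is handed through unchanged; the single copy
    // happens inside llama_sample_token_binding (defined at the bottom of this patch)
    const llama_token id = llama_sample_token_binding(ctx, ctx_guidance, grammar, params_p, last_tokens, candidates);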
@@ -645,7 +646,9 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model
     int i_dft = 0;
     while (true) {
         // sample from the target model
-        const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+
+        // const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+        const llama_token id = llama_sample_token_binding(ctx_tgt, NULL, grammar_tgt, params_p, last_tokens, candidates, i_dft);
         // remember which tokens were sampled - used for repetition penalties during sampling
         last_tokens.erase(last_tokens.begin());
         last_tokens.push_back(id);
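
The extra idx argument is what makes the binding variant usable here: after the target model evaluates a whole draft batch in one call, the logits for position i sit at llama_get_logits(ctx_tgt) + i * n_vocab, and idx selects that row (see the idx * n_vocab offset in the function body below). A simplified sketch of the accept step; drafted is a hypothetical vector holding the draft model's proposed tokens, and the real loop above advances i_dft only while the draft agrees:

    // sample the target model at draft position i_dft without re-evaluating
    const llama_token id = llama_sample_token_binding(
        ctx_tgt, NULL, grammar_tgt, params_p, last_tokens, candidates, i_dft);

    if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
        i_dft++;   // draft token accepted, move on to the next position
    } else {
        // mismatch: keep the target's token and discard the rest of the draft
    }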
@@ -965,6 +968,15 @@ struct llama_binding_state {
 
 void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity);
 
+llama_token llama_sample_token_binding(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params * g_params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx = 0);
+
 common.cpp:
 
 gpt_params* create_gpt_params(const std::string& fname, const std::string& lora, const std::string& lora_base) {
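
Since idx defaults to 0 in the declaration, ordinary (non-speculative) callers can omit it and sample from the last evaluated position. A minimal sketch of the intended call pattern from the binding side, assuming ctx has already evaluated a prompt; create_gpt_params is the helper defined just below:

    gpt_params* params_p = create_gpt_params(fname, lora, lora_base);

    std::vector<llama_token> last_tokens(llama_n_ctx(ctx), 0);
    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(ctx));

    // idx omitted: defaults to 0, i.e. the logits of the last evaluated token
    const llama_token id = llama_sample_token_binding(ctx, nullptr, nullptr, params_p, last_tokens, candidates);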
@@ -1060,4 +1072,127 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
     state->model = model;
     return state;
 }
+
+// NOTE: the only difference from llama.cpp's llama_sample_token is that params
+// is passed as a pointer, avoiding a copy-by-value of gpt_params at the call site.
+// We keep it as a separate function so the rest of the llama.cpp code needs no patching.
+// The function must live in the common.o object file: defining it in the binding itself has no effect.
+llama_token llama_sample_token_binding(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params * g_params, // NOTE: this is our patch
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx) {
+
+    struct gpt_params params = *g_params; // NOTE: this is our patch
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
 */
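
Putting the pieces together, here is a hedged sketch of a complete generation loop built on the binding helper, written against the llama.cpp API of this vintage (llama_eval plus the ctx-taking token helpers). embd_inp (the tokenized prompt) and n_predict (the generation budget) are assumed inputs, not names from this patch:

    #include "common.h"
    #include "llama.h"

    void generate(llama_context * ctx, gpt_params * params_p,
                  const std::vector<llama_token> & embd_inp, int n_predict) {
        std::vector<llama_token> last_tokens(llama_n_ctx(ctx), 0);
        std::vector<llama_token_data> candidates;
        candidates.reserve(llama_n_vocab(ctx));

        // evaluate the prompt once
        int n_past = 0;
        llama_eval(ctx, embd_inp.data(), (int) embd_inp.size(), n_past, params_p->n_threads);
        n_past += (int) embd_inp.size();

        for (int i = 0; i < n_predict; i++) {
            // sample through the pointer-taking binding variant (idx defaults to 0)
            const llama_token id = llama_sample_token_binding(
                ctx, nullptr, nullptr, params_p, last_tokens, candidates);
            if (id == llama_token_eos(ctx)) {
                break;
            }

            // keep the sliding window used by the repetition penalties
            last_tokens.erase(last_tokens.begin());
            last_tokens.push_back(id);

            // feed the sampled token back in for the next step
            llama_eval(ctx, &id, 1, n_past, params_p->n_threads);
            n_past += 1;
        }
    }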