abetlen/llama-cpp-python

[Feature request] High-level API support for DRY and XTC samplers

Opened this issue · 3 comments

ddh0 commented

Is your feature request related to a problem? Please describe.

Recently llama.cpp added support for the DRY and XTC samplers which can help reduce repetition and increase creativity without losing coherence. It would be wonderful if users of llama-cpp-python could take advantage of these advanced samplers.

Describe the solution you'd like

Ideally the high-level API would expose parameters so that the end user / developer may use XTC and DRY, in the same way that we can currently use temperature, top-p, min-p, etc. Functions like Llama.create_completion() would be updated with these new parameters.

Additional context

I would be happy to help in any way I can with the implementation of these samplers, but I'm not sure where to start. @abetlen If there is anything I can do to help get this supported as quickly as possible, please point me in the right direction.

Thank you!

zpin commented

This patch adds the DRY and XTC samplers:

diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 0aff348..fb78d31 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -800,6 +800,22 @@ class LlamaSampler:
         sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
         self._add_sampler(sampler)
 
+    def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
+        sampler = llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
+        self._add_sampler(sampler)
+
+    def add_dry(self, model: LlamaModel, multiplier: float, base: float,
+                allowed_length: int, penalty_last_n: int, seq_breakers: list[str] = []):
+
+        # Convert Python strings to bytes
+        seq_breakers_bytes = [s.encode('utf-8') for s in seq_breakers]
+        # Create array of char*
+        arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
+        sampler = llama_cpp.llama_sampler_init_dry(model.model, multiplier, base,
+                                                allowed_length, penalty_last_n,
+                                                arr, len(seq_breakers))
+        self._add_sampler(sampler)
+
     def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
         sampler = llama_cpp.llama_sampler_init_grammar(
             model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index babb30c..42febcd 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -677,6 +677,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -744,6 +751,7 @@ class Llama:
             else:
                 n_probs = 0
                 min_keep = max(1, n_probs)
+                sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
                 sampler.add_top_k(top_k)
                 sampler.add_tail_free(tfs_z, min_keep)
                 sampler.add_typical(typical_p, min_keep)
@@ -751,6 +759,7 @@ class Llama:
                 sampler.add_min_p(min_p, min_keep)
                 sampler.add_temp(temp)
                 sampler.add_dist(self._seed)
+                sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
         return sampler
 
     def sample(
@@ -767,6 +776,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -802,6 +818,13 @@ class Llama:
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                xtc_probability=xtc_probability,
+                xtc_threshold=xtc_threshold,
+                dry_multiplier=dry_multiplier,
+                dry_allowed_length=dry_allowed_length,
+                dry_base=dry_base,
+                dry_range=dry_range,
+                dry_seq_breakers=dry_seq_breakers,
                 penalize_nl=penalize_nl,
                 logits_processor=logits_processor,
                 grammar=grammar,
@@ -831,6 +854,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -870,6 +900,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -922,6 +959,13 @@ class Llama:
                     mirostat_mode=mirostat_mode,
                     mirostat_tau=mirostat_tau,
                     mirostat_eta=mirostat_eta,
+                    xtc_probability=xtc_probability,
+                    xtc_threshold=xtc_threshold,
+                    dry_multiplier=dry_multiplier,
+                    dry_allowed_length=dry_allowed_length,
+                    dry_base=dry_base,
+                    dry_range=dry_range,
+                    dry_seq_breakers=dry_seq_breakers,
                     logits_processor=logits_processor,
                     grammar=grammar,
                     penalize_nl=penalize_nl,
@@ -1138,6 +1182,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1326,6 +1377,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -1758,6 +1816,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1821,6 +1886,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1855,6 +1927,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1918,6 +1997,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1949,6 +2035,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -2022,6 +2115,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 66feed8..dbc9cea 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -3244,6 +3244,38 @@ def llama_sampler_init_xtc(
 ) -> llama_sampler_p:
     ...
 
+#    LLAMA_API struct llama_sampler *    llama_sampler_init_dry(
+#            const struct llama_model *  model,
+#                               float    dry_multiplier,
+#                               float    dry_base,
+#                             int32_t    dry_allowed_length,
+#                             int32_t    dry_penalty_last_n,
+#                          const char ** seq_breakers,
+#                              size_t    num_breakers);
+@ctypes_function(
+"llama_sampler_init_dry",
+    [
+        llama_model_p_ctypes,
+        ctypes.c_float,
+        ctypes.c_float,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_char_p),
+        ctypes.c_size_t
+    ],
+    llama_sampler_p_ctypes,
+)
+def llama_sampler_init_dry(
+    model: llama_model_p,
+    dry_multiplier: float,
+    dry_base: float,
+    dry_allowed_length: int,
+    dry_penalty_last_n: int,
+    seq_breakers: list[str],
+    num_breakers: int,
+) -> llama_sampler_p:
+    ...
+
 
 # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
 # /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.

@zpin feel free to open a PR !

zpin commented

Done: #1843