[Feature request] High-level API support for DRY and XTC samplers
Opened this issue · 3 comments
Is your feature request related to a problem? Please describe.
Recently, llama.cpp added support for the DRY and XTC samplers, which can help reduce repetition and increase creativity without losing coherence. It would be wonderful if users of llama-cpp-python could take advantage of these advanced samplers.
Describe the solution you'd like
Ideally, the high-level API would expose parameters so that the end user / developer can use XTC and DRY in the same way that we can currently use temperature, top-p, min-p, etc. Functions like Llama.create_completion() would be updated with these new parameters.
Additional context
I would be happy to help in any way I can with the implementation of these samplers, but I'm not sure where to start. @abetlen If there is anything I can do to help get this supported as quickly as possible, please point me in the right direction.
Thank you!
This patch adds the DRY and XTC samplers:
diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 0aff348..fb78d31 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -800,6 +800,22 @@ class LlamaSampler:
sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
self._add_sampler(sampler)
+    def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
+        """Build an XTC sampler from the given settings and register it via ``_add_sampler``."""
+        self._add_sampler(
+            llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
+        )
+
+    def add_dry(self, model: LlamaModel, multiplier: float, base: float,
+                allowed_length: int, penalty_last_n: int,
+                seq_breakers: "list[str] | None" = None):
+        """Build a DRY sampler and register it via ``_add_sampler``.
+
+        ``seq_breakers`` are UTF-8 encoded and handed to
+        ``llama_sampler_init_dry`` as a C array of ``char *`` together with
+        their count. ``None`` (the default) is treated as an empty list.
+        """
+        # Use None rather than a shared mutable ``[]`` default.
+        breakers = [] if seq_breakers is None else seq_breakers
+        # Convert Python strings to bytes
+        seq_breakers_bytes = [s.encode('utf-8') for s in breakers]
+        # Create array of char*
+        arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
+        # Pass the length of the list that actually backs ``arr``.
+        sampler = llama_cpp.llama_sampler_init_dry(model.model, multiplier, base,
+                                                   allowed_length, penalty_last_n,
+                                                   arr, len(seq_breakers_bytes))
+        self._add_sampler(sampler)
+
def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
sampler = llama_cpp.llama_sampler_init_grammar(
model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index babb30c..42febcd 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -677,6 +677,13 @@ class Llama:
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
+ xtc_probability: float = 0.0,
+ xtc_threshold: float = 0.1,
+ dry_multiplier: float = 0.0,
+ dry_allowed_length: int = 2,
+ dry_base: float = 1.75,
+ dry_range: int = 0,
+ dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
@@ -744,6 +751,7 @@ class Llama:
else:
n_probs = 0
min_keep = max(1, n_probs)
+ sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
sampler.add_top_k(top_k)
sampler.add_tail_free(tfs_z, min_keep)
sampler.add_typical(typical_p, min_keep)
@@ -751,6 +759,7 @@ class Llama:
sampler.add_min_p(min_p, min_keep)
+ sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)  # XTC must filter candidates before add_dist selects a token
sampler.add_temp(temp)
sampler.add_dist(self._seed)
return sampler
def sample(
@@ -767,6 +776,13 @@ class Llama:
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
+ xtc_probability: float = 0.0,
+ xtc_threshold: float = 0.1,
+ dry_multiplier: float = 0.0,
+ dry_allowed_length: int = 2,
+ dry_base: float = 1.75,
+ dry_range: int = 0,
+ dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
@@ -802,6 +818,13 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
+ xtc_probability=xtc_probability,
+ xtc_threshold=xtc_threshold,
+ dry_multiplier=dry_multiplier,
+ dry_allowed_length=dry_allowed_length,
+ dry_base=dry_base,
+ dry_range=dry_range,
+ dry_seq_breakers=dry_seq_breakers,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
@@ -831,6 +854,13 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
+ xtc_probability: float = 0.0,
+ xtc_threshold: float = 0.1,
+ dry_multiplier: float = 0.0,
+ dry_allowed_length: int = 2,
+ dry_base: float = 1.75,
+ dry_range: int = 0,
+ dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -870,6 +900,13 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
+ xtc_probability=xtc_probability,
+ xtc_threshold=xtc_threshold,
+ dry_multiplier=dry_multiplier,
+ dry_allowed_length=dry_allowed_length,
+ dry_base=dry_base,
+ dry_range=dry_range,
+ dry_seq_breakers=dry_seq_breakers,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
@@ -922,6 +959,13 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
+ xtc_probability=xtc_probability,
+ xtc_threshold=xtc_threshold,
+ dry_multiplier=dry_multiplier,
+ dry_allowed_length=dry_allowed_length,
+ dry_base=dry_base,
+ dry_range=dry_range,
+ dry_seq_breakers=dry_seq_breakers,
logits_processor=logits_processor,
grammar=grammar,
penalize_nl=penalize_nl,
@@ -1138,6 +1182,13 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
+ xtc_probability: float = 0.0,
+ xtc_threshold: float = 0.1,
+ dry_multiplier: float = 0.0,
+ dry_allowed_length: int = 2,
+ dry_base: float = 1.75,
+ dry_range: int = 0,
+ dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1326,6 +1377,13 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
+ xtc_probability=xtc_probability,
+ xtc_threshold=xtc_threshold,
+ dry_multiplier=dry_multiplier,
+ dry_allowed_length=dry_allowed_length,
+ dry_base=dry_base,
+ dry_range=dry_range,
+ dry_seq_breakers=dry_seq_breakers,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty,
@@ -1758,6 +1816,13 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
+ xtc_probability: float = 0.0,
+ xtc_threshold: float = 0.1,
+ dry_multiplier: float = 0.0,
+ dry_allowed_length: int = 2,
+ dry_base: float = 1.75,
+ dry_range: int = 0,
+ dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1821,6 +1886,13 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
+ xtc_probability=xtc_probability,
+ xtc_threshold=xtc_threshold,
+ dry_multiplier=dry_multiplier,
+ dry_allowed_length=dry_allowed_length,
+ dry_base=dry_base,
+ dry_range=dry_range,
+ dry_seq_breakers=dry_seq_breakers,
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
@@ -1855,6 +1927,13 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
+ xtc_probability: float = 0.0,
+ xtc_threshold: float = 0.1,
+ dry_multiplier: float = 0.0,
+ dry_allowed_length: int = 2,
+ dry_base: float = 1.75,
+ dry_range: int = 0,
+ dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1918,6 +1997,13 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
+ xtc_probability=xtc_probability,
+ xtc_threshold=xtc_threshold,
+ dry_multiplier=dry_multiplier,
+ dry_allowed_length=dry_allowed_length,
+ dry_base=dry_base,
+ dry_range=dry_range,
+ dry_seq_breakers=dry_seq_breakers,
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
@@ -1949,6 +2035,13 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
+ xtc_probability: float = 0.0,
+ xtc_threshold: float = 0.1,
+ dry_multiplier: float = 0.0,
+ dry_allowed_length: int = 2,
+ dry_base: float = 1.75,
+ dry_range: int = 0,
+ dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
@@ -2022,6 +2115,13 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
+ xtc_probability=xtc_probability,
+ xtc_threshold=xtc_threshold,
+ dry_multiplier=dry_multiplier,
+ dry_allowed_length=dry_allowed_length,
+ dry_base=dry_base,
+ dry_range=dry_range,
+ dry_seq_breakers=dry_seq_breakers,
model=model,
logits_processor=logits_processor,
grammar=grammar,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 66feed8..dbc9cea 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -3244,6 +3244,38 @@ def llama_sampler_init_xtc(
) -> llama_sampler_p:
...
+# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+#         const struct llama_model * model,
+#         float dry_multiplier,
+#         float dry_base,
+#         int32_t dry_allowed_length,
+#         int32_t dry_penalty_last_n,
+#         const char ** seq_breakers,
+#         size_t num_breakers);
+@ctypes_function(
+    "llama_sampler_init_dry",
+    [
+        llama_model_p_ctypes,
+        ctypes.c_float,
+        ctypes.c_float,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_char_p),
+        ctypes.c_size_t,
+    ],
+    llama_sampler_p_ctypes,
+)
+def llama_sampler_init_dry(
+    model: llama_model_p,
+    dry_multiplier: float,
+    dry_base: float,
+    dry_allowed_length: int,
+    dry_penalty_last_n: int,
+    seq_breakers: "ctypes.Array[ctypes.c_char_p]",
+    num_breakers: int,
+) -> llama_sampler_p:
+    """Binding for the C function ``llama_sampler_init_dry``.
+
+    ``seq_breakers`` corresponds to the C ``const char **`` argument, so it
+    must be a ctypes array of ``c_char_p`` (byte strings), not a Python list
+    of ``str``. ``num_breakers`` is the number of entries in that array.
+    """
+    ...
+
# /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@zpin feel free to open a PR!