Understanding Logits Processors

code
analysis
Author

Alonso Silva

Published

July 16, 2025

Logits processors are incredibly powerful, and I think they should receive more attention from the community. As their name implies, logits processors process the logits: the raw, unnormalized scores that the last layer of the neural network assigns to every token in the vocabulary before they are turned into probabilities. By modifying these raw scores, we can get a completely different result from the one the language model would have generated on its own. We will clarify this with some examples.
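To make this concrete without loading a model, here is a framework-free sketch in plain Python (no torch or transformers). The four-token vocabulary and its logits are made up for illustration. Softmax turns the logits into probabilities, and greedy decoding picks the highest score. Setting a score to minus infinity, which is what the processors below do, gives that token zero probability, so it can never be picked.

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    """Convert raw scores (logits) into probabilities; a -inf logit gets probability 0."""
    m = max(x for x in logits if x != -math.inf)  # subtract the max for numerical stability
    exps = [0.0 if x == -math.inf else math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(vocab: List[str], scores: List[float]) -> str:
    """Greedy decoding (do_sample=False) simply picks the highest-scoring token."""
    return vocab[max(range(len(scores)), key=scores.__getitem__)]

# Hypothetical vocabulary and next-token logits
vocab = ["4", ".", "<eos>", "Heck"]
logits = [1.0, 0.5, 3.0, -1.0]
print(greedy_pick(vocab, logits))  # <eos>: the model wants to stop

# A logits processor maps scores to modified scores before the token is picked
processed = list(logits)
processed[vocab.index("<eos>")] = -math.inf  # forbid ending the sequence
print(greedy_pick(vocab, processed))  # 4
print(softmax(processed)[vocab.index("<eos>")])  # 0.0
```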

In this post, we will look at some simple logits processor examples (minimum length and minimum new tokens length), as well as some more complex ones (replacing the end-of-sequence token with a word, and replacing it with a phrase). We conclude with two practical applications of logits processors: making reasoning models stop thinking once they exceed a thinking budget, and forcing reasoning models to think for longer on particularly difficult questions.

Basic Example

Let’s start with a basic example of a logits processor. In this section, we won’t use the thinking capabilities of the language model.

We first download a small language model (0.6B parameters) and its tokenizer.

import torch
from typing import List
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.generation import LogitsProcessor

model_id = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(
    model_id, cache_dir="/big_storage/llms/hf_models/"
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

We can ask the question:

What’s 2 + 2?

and see what the language model would have responded without any logits processor.

user_input = "What's 2 + 2?"

def generate_response(user_input, logits_processor=None, enable_thinking=False):
    messages = [
        {"role": "user", "content": user_input},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )

    model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_length = model_inputs['input_ids'].shape[-1]

    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        logits_processor=logits_processor,
        max_new_tokens=4 * 1024,
        do_sample=False,
        temperature=1.0,
        top_p=1.0,
        top_k=50,
    )

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_response = ""
    for chunk in streamer:
        assistant_response += chunk
        # print(chunk, end="")

    clean_assistant_response = assistant_response.split("<|im_end|>")[0]

    if enable_thinking:
        reasoning_trace = assistant_response.split("<think>")[-1].split("</think>")[0]
        thinking_length = len(tokenizer.encode(reasoning_trace))
        if "</think>" in assistant_response:
            response_without_reasoning_trace = assistant_response.split("</think>")[-1]
            response_length = len(tokenizer.encode(response_without_reasoning_trace))
        else:
            response_length = 0
    else:
        thinking_length = 0
        response_length = len(tokenizer.encode(clean_assistant_response))
    thread.join()
    return clean_assistant_response, prompt_length, thinking_length, response_length


assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input
)
print(assistant_response)
2 + 2 equals 4.

The answer is quite straightforward: 2 + 2 equals 4.

The number of tokens is the following:

print(
    f"# prompt tokens: {prompt_length}\n# thinking tokens: {thinking_length}\n# response tokens: {response_length}"
)
# prompt tokens: 20
# thinking tokens: 0
# response tokens: 8

There are \(20\) prompt tokens and \(8\) response tokens. Here are the \(8\) response tokens:

for token in tokenizer.encode(assistant_response):
    print(f"id: {token}; token: {tokenizer.decode(token).replace(' ', '⎵')}")
id: 17; token: 2
id: 488; token: ⎵+
id: 220; token: ⎵
id: 17; token: 2
id: 16819; token: ⎵equals
id: 220; token: ⎵
id: 19; token: 4
id: 13; token: .

Minimum Length

Let’s force the language model to generate a longer answer. In order to do that, we can define the following logits processor:

class MinLengthLogitsProcessor(LogitsProcessor):
    def __init__(self, min_length: int, eos_token_ids: List[int]):
        
        self.min_length = min_length
        self.eos_token_ids = eos_token_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        
        scores_processed = scores.clone()
        token_count = input_ids.shape[-1]
        if token_count < self.min_length:
            for eos_token_id in self.eos_token_ids:
                scores_processed[:, eos_token_id] = -torch.inf
        return scores_processed

This logits processor consists of two parts:

  • The first is the constructor, which just stores the minimum length required and the list of end-of-sequence tokens.
  • The second is the __call__ method, which clones the original scores and, if the minimum length has not been reached yet, sets the scores of the end-of-sequence tokens to minus infinity. This prevents the language model from choosing them and therefore from ending the sequence: the model has to keep generating until the minimum length is reached. That’s exactly what we will see.

Let’s instantiate this logits processor with a required minimum length of \(40\):

logits_processor = [
    MinLengthLogitsProcessor(
        min_length=40, eos_token_ids=[tokenizer.eos_token_id, tokenizer.pad_token_id]
    )
]

This is the assistant response:

assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input,
    logits_processor=logits_processor,
)
print(assistant_response)
2 + 2 equals 4. Let me know if you have any other questions! 😊

Compared with the previous response (2 + 2 equals 4.), the language model appended the phrase Let me know if you have any other questions! 😊

We made the language model do that!

The number of tokens is the following:

print(
    f"# prompt tokens: {prompt_length}\n# thinking tokens: {thinking_length}\n# response tokens: {response_length}"
)
# prompt tokens: 20
# thinking tokens: 0
# response tokens: 20

The language model generated \(20\) response tokens even though we asked for a minimum length of \(40\) tokens. The reason is that input_ids includes the prompt, so the logits processor also counts the \(20\) prompt tokens, and \(20+20\ge40\): the constraint is satisfied.
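The counting subtlety can be reproduced with a toy, framework-free sketch. Everything here is hypothetical: toy_model stands in for the language model (it always prefers the end-of-sequence token), and MinLength mirrors MinLengthLogitsProcessor using plain Python lists instead of tensors. Because len(input_ids) includes the prompt, a 3-token prompt with min_length=6 only forces 3 new tokens.

```python
import math
from typing import List

EOS = 0  # hypothetical end-of-sequence token id in a 5-token vocabulary

def toy_model(input_ids: List[int]) -> List[float]:
    """Hypothetical model: always prefers EOS, then the token after the last one."""
    scores = [0.0] * 5
    scores[EOS] = 10.0
    scores[(input_ids[-1] % 4) + 1] = 5.0  # a token id in 1..4, never EOS
    return scores

class MinLength:
    """Mirror of MinLengthLogitsProcessor: counts prompt AND generated tokens."""

    def __init__(self, min_length: int, eos_token_ids: List[int]):
        self.min_length = min_length
        self.eos_token_ids = eos_token_ids

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        scores = list(scores)  # copy, like scores.clone()
        if len(input_ids) < self.min_length:
            for eos in self.eos_token_ids:
                scores[eos] = -math.inf
        return scores

def greedy_generate(prompt: List[int], processors, max_new_tokens: int = 20) -> List[int]:
    """Greedy decoding loop that applies each processor before picking a token."""
    ids = list(prompt)
    for _ in range(max_new_tokens):
        scores = toy_model(ids)
        for processor in processors:
            scores = processor(ids, scores)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == EOS:
            break
    return ids[len(prompt):]

print(greedy_generate([1, 2, 3], []))                     # [0]
print(greedy_generate([1, 2, 3], [MinLength(6, [EOS])]))  # [4, 1, 2, 0]
```

Only 3 tokens are generated before EOS, because 3 prompt tokens plus 3 new tokens already reach min_length = 6, the same effect as \(20+20\ge40\) in the real run.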

Minimum New Tokens Length

Let’s remove the prompt tokens from the computation. The logits processor is slightly more complex:

class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
    def __init__(
        self, min_new_tokens_length: int, eos_token_ids: List[int]
    ):
        self.min_new_tokens_length = min_new_tokens_length
        self.eos_token_ids = eos_token_ids
        self.prompt_length_to_skip = None

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        first_time = self.prompt_length_to_skip is None
        if first_time:
            self.prompt_length_to_skip = input_ids.shape[-1]
        scores_processed = scores.clone()
        token_count = input_ids.shape[-1] - self.prompt_length_to_skip
        if token_count < self.min_new_tokens_length:
            for eos_token_id in self.eos_token_ids:
                scores_processed[:, eos_token_id] = -torch.inf
        return scores_processed

We have added a prompt_length_to_skip attribute, which is set to the length of input_ids only the first time the logits processor is called. At that point input_ids contains just the prompt, so this effectively stores the prompt length. We then subtract prompt_length_to_skip from the token count.
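Because the prompt length is stored on the instance, the processor is stateful. The plain-Python sketch below mirrors MinNewTokensLengthLogitsProcessor (lists instead of tensors; token ids and scores are made up) and also shows a pitfall we infer from the code: if the same instance were reused for a second generation, the stale prompt length would be applied to the new prompt. Creating a fresh instance for each call, as this post does, avoids that.

```python
import math
from typing import List

class MinNewTokens:
    """Plain-list mirror of MinNewTokensLengthLogitsProcessor."""

    def __init__(self, min_new_tokens_length: int, eos_token_ids: List[int]):
        self.min_new_tokens_length = min_new_tokens_length
        self.eos_token_ids = eos_token_ids
        self.prompt_length_to_skip = None  # set on the first call only

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        if self.prompt_length_to_skip is None:
            # The first call sees only the prompt, so its length is the prompt length
            self.prompt_length_to_skip = len(input_ids)
        scores = list(scores)
        token_count = len(input_ids) - self.prompt_length_to_skip
        if token_count < self.min_new_tokens_length:
            for eos in self.eos_token_ids:
                scores[eos] = -math.inf
        return scores

p = MinNewTokens(min_new_tokens_length=2, eos_token_ids=[0])
print(p([7, 7, 7], [5.0, 1.0]))        # 0 new tokens -> [-inf, 1.0]
print(p([7, 7, 7, 1], [5.0, 1.0]))     # 1 new token  -> [-inf, 1.0]
print(p([7, 7, 7, 1, 1], [5.0, 1.0]))  # 2 new tokens -> [5.0, 1.0]

# Pitfall: reusing p for a new 10-token prompt treats 7 prompt tokens as "new",
# so EOS is allowed immediately.
print(p([7] * 10, [5.0, 1.0]))         # [5.0, 1.0]
```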

Let’s instantiate this logits processor with a minimum of \(40\) new tokens.

logits_processor = [
    MinNewTokensLengthLogitsProcessor(
        min_new_tokens_length=40,
        eos_token_ids=[tokenizer.eos_token_id, tokenizer.pad_token_id],
    )
]

This is the assistant response:

assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input,
    logits_processor=logits_processor
)
print(assistant_response)
2 + 2 equals 4. Let me know if you have any other questions! 😊. 🎉. 🔍. 🧠. 🧠. 🧠.

Compared with the previous response (2 + 2 equals 4. Let me know if you have any other questions! 😊), the language model appended . 🎉. 🔍. 🧠. 🧠. 🧠.

The number of tokens is the following:

print(
    f"# prompt tokens: {prompt_length}\n# thinking tokens: {thinking_length}\n# response tokens: {response_length}"
)
# prompt tokens: 20
# thinking tokens: 0
# response tokens: 40

The language model added some emojis in order to reach the required \(40\) response tokens. That’s OK.

Replacements

Replace the end of sequence by a word

Now let’s do something slightly more complex. This time, when the language model wants to finish its answer (in our case, after the phrase 2 + 2 equals 4.), we are going to replace the end-of-sequence token with another token, in this case the token ⎵Heck (my first choice was the F-word). Note that we could use any other token as the replacement and see how the model would continue the phrase.

class MinNewTokensLengthWithReplacementTokenLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        min_new_tokens_length: int,
        eos_token_ids: List[int],
        replacement_token_id: int
    ):
        self.min_new_tokens_length = min_new_tokens_length
        self.eos_token_ids = eos_token_ids
        self.replacement_token_id = replacement_token_id
        self.prompt_length_to_skip = None
        self.very_large_number = 10_000

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        first_time = self.prompt_length_to_skip is None
        if first_time:
            self.prompt_length_to_skip = input_ids.shape[-1]
        scores_processed = scores.clone()
        token_count = input_ids.shape[-1] - self.prompt_length_to_skip
        if token_count < self.min_new_tokens_length:
            token_chosen_id = torch.argmax(scores_processed).item()
            if token_chosen_id in self.eos_token_ids:
                scores_processed[:, self.replacement_token_id] = self.very_large_number
                for eos_token_id in self.eos_token_ids:
                    scores_processed[:, eos_token_id] = -torch.inf
        return scores_processed

Let’s instantiate this logits processor:

logits_processor=[
    MinNewTokensLengthWithReplacementTokenLogitsProcessor(
        min_new_tokens_length=40,
        eos_token_ids=[tokenizer.eos_token_id, tokenizer.pad_token_id],
        replacement_token_id=tokenizer.encode(" Heck")[0],
    )
]

Here is the assistant response:

assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input,
    logits_processor=logits_processor
)
print(assistant_response)
2 + 2 equals 4. Heck, that's a simple math problem. Heck, I'm just a AI assistant here. Heck, I'm not going to do that. Heck, I'm just going to tell you that 2 + 2 is 4.

Compared with the very first response (2 + 2 equals 4.), the language model appended the following phrases: Heck, that's a simple math problem. Heck, I'm just a AI assistant here. Heck, I'm not going to do that. Heck, I'm just going to tell you that 2 + 2 is 4.

These phrases are completely different from the previous ones!

By replacing the end-of-sequence token with the token ⎵Heck, we made the model take a completely different path from the one it would have taken by itself.
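The replacement processors decide whether the model "wants to stop" by taking the argmax of the raw scores. That guess is exact only under greedy decoding, which is what generate_response uses (do_sample=False); with sampling, an end-of-sequence token could still be drawn even when it is not the argmax. Also, torch.argmax without a dim argument flattens the whole tensor, so the processors implicitly assume a batch of one. The framework-free sketch below mirrors one step of the single-token replacement on a plain list; the four-token vocabulary and its scores are made up for illustration.

```python
import math
from typing import List

def replace_eos(scores: List[float], eos_token_ids: List[int],
                replacement_token_id: int, very_large_number: float = 10_000.0) -> List[float]:
    """Plain-list mirror of one step of the single-token replacement processor."""
    scores = list(scores)
    token_chosen_id = max(range(len(scores)), key=scores.__getitem__)
    # Under greedy decoding, the argmax is exactly the token about to be emitted
    if token_chosen_id in eos_token_ids:
        scores[replacement_token_id] = very_large_number  # boost the replacement
        for eos in eos_token_ids:
            scores[eos] = -math.inf  # and forbid every end-of-sequence token
    return scores

vocab = ["<eos>", "<pad>", "Heck", "4"]
step = replace_eos([9.0, 2.0, -3.0, 1.0], eos_token_ids=[0, 1], replacement_token_id=2)
print(vocab[max(range(len(step)), key=step.__getitem__)])  # Heck
```

When the argmax is not an end-of-sequence token, the scores are returned unchanged, so the model continues on its own path.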

Replace the end of sequence by a phrase

We just replaced the end-of-sequence token with a single word (token), but it might be interesting to replace it with a whole phrase. For example, we might want the language model to check its answer. Let’s do that by replacing the end-of-sequence token with the phrase Wait, let me check my answer.

The logits processor is slightly more complex since we need to generate a sequence of tokens and therefore find a way to keep the state (here the state will be kept by the index variable):

class MinNewTokensLengthWithReplacementLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        min_new_tokens_length: int,
        eos_token_ids: List[int],
        replacement_tokens_ids: List[int],
    ):
        self.min_new_tokens_length = min_new_tokens_length
        self.eos_token_ids = eos_token_ids
        self.replacement_tokens_ids = replacement_tokens_ids
        self.prompt_length_to_skip = None
        self.very_large_number = 10_000
        self.index = -1

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        first_time = self.prompt_length_to_skip is None
        if first_time:
            self.prompt_length_to_skip = input_ids.shape[-1]
        scores_processed = scores.clone()
        token_count = input_ids.shape[-1] - self.prompt_length_to_skip
        if token_count < self.min_new_tokens_length:
            token_chosen_id = torch.argmax(scores_processed).item()
            if (token_chosen_id in self.eos_token_ids) and (self.index == -1):
                for eos_token_id in self.eos_token_ids:
                    scores_processed[:, eos_token_id] = -torch.inf
                self.index = 0

            if len(self.replacement_tokens_ids) > self.index >= 0:
                scores_processed[:, self.replacement_tokens_ids[self.index]] = (
                    self.very_large_number
                )
                self.index += 1

            if self.index == len(self.replacement_tokens_ids):
                self.index = -1

        return scores_processed

Let’s instantiate the logits processor:

logits_processor = [
    MinNewTokensLengthWithReplacementLogitsProcessor(
        min_new_tokens_length=30,
        eos_token_ids=[tokenizer.eos_token_id, tokenizer.pad_token_id],
        replacement_tokens_ids=tokenizer.encode(" Wait, let me check my answer")
    )
]
assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input,
    logits_processor=logits_processor
)
print(assistant_response)
2 + 2 equals 4. Wait, let me check my answer again. 2 + 2 is indeed 4. So the correct answer is 4.

Compared with the first response (2 + 2 equals 4.), the language model appended Wait, let me check my answer again. 2 + 2 is indeed 4. So the correct answer is 4.

We forced the model to check its answer by replacing the end-of-sequence token with the phrase Wait, let me check my answer and letting the language model continue from there. That’s great!
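The index bookkeeping in MinNewTokensLengthWithReplacementLogitsProcessor is a small state machine: index == -1 means idle, and 0 ≤ index < len(replacement_tokens_ids) means "force the next token of the phrase". The plain-Python sketch below mirrors that logic. The toy model is hypothetical (it always wants to emit EOS, id 0), and ids 3, 4, 5 are made-up stand-ins for the tokens of the phrase.

```python
import math
from typing import List

class PhraseInjector:
    """Plain-list mirror of MinNewTokensLengthWithReplacementLogitsProcessor."""

    def __init__(self, min_new_tokens_length: int, eos_token_ids: List[int],
                 replacement_tokens_ids: List[int], very_large_number: float = 10_000.0):
        self.min_new_tokens_length = min_new_tokens_length
        self.eos_token_ids = eos_token_ids
        self.replacement_tokens_ids = replacement_tokens_ids
        self.very_large_number = very_large_number
        self.prompt_length_to_skip = None
        self.index = -1  # -1: idle; otherwise position of the next phrase token

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        if self.prompt_length_to_skip is None:
            self.prompt_length_to_skip = len(input_ids)
        scores = list(scores)
        if len(input_ids) - self.prompt_length_to_skip < self.min_new_tokens_length:
            chosen = max(range(len(scores)), key=scores.__getitem__)
            if chosen in self.eos_token_ids and self.index == -1:
                for eos in self.eos_token_ids:
                    scores[eos] = -math.inf
                self.index = 0  # start injecting the phrase
            if 0 <= self.index < len(self.replacement_tokens_ids):
                scores[self.replacement_tokens_ids[self.index]] = self.very_large_number
                self.index += 1
            if self.index == len(self.replacement_tokens_ids):
                self.index = -1  # phrase fully forced: back to idle
        return scores

def toy_model(input_ids: List[int]) -> List[float]:
    """Hypothetical model over 6 tokens that always wants to emit EOS (id 0)."""
    return [5.0, 1.0, 0.0, 0.0, 0.0, 0.0]

def greedy_generate(prompt: List[int], processor: PhraseInjector, max_new_tokens: int = 10) -> List[int]:
    ids = list(prompt)
    for _ in range(max_new_tokens):
        scores = processor(ids, toy_model(ids))
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == 0:
            break
    return ids[len(prompt):]

print(greedy_generate([9, 9], PhraseInjector(3, [0], [3, 4, 5])))  # [3, 4, 5, 0]
print(greedy_generate([9, 9], PhraseInjector(5, [0], [3, 4, 5])))  # [3, 4, 5, 3, 4, 0]
```

With min_new_tokens_length=5, the second copy of the phrase is cut off after two tokens, because the whole intervention sits inside the token_count check. As far as we can tell from the code, the processor above would behave the same way if its budget ran out mid-phrase.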

Thinking Budget

The previous sections were fun for me, and I hope they were fun for you as well, but let’s now look at some very practical applications of logits processors.

Many people have noticed that reasoning models are very verbose when thinking and have looked for practical ways to limit that. Qwen3 even lets you disable thinking altogether (that’s what we did above by passing enable_thinking=False). But what if we want to let the model think, just not for too long? Couldn’t we define a thinking budget and make the model stop thinking once it reaches that budget?

Of course, we can, thanks to logits processors!

If we ask the question What's 2 + 2? and let the language model think (enable_thinking=True) without any constraint, this is what we get:

assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input,
    enable_thinking=True
)
print(assistant_response)
<think>
Okay, the user is asking, "What's 2 + 2?" Let me think about how to approach this. First, I need to make sure I understand the question correctly. The user is probably looking for the sum of 2 plus 2, which is 4. But maybe they're trying to get a different answer, like a joke or something else. Let me check if there's any context I'm missing.

Wait, sometimes people use "2 + 2" in a different way. For example, in some languages, numbers are written differently, but in English, it's straightforward. Also, maybe the user is testing if I can recognize that 2 + 2 equals 4. But I should also consider if there's any trick here. For instance, if they're using a calculator, the result would be 4. But since the question is simple, the answer is 4.

I should also make sure there's no hidden meaning or cultural context. In most basic math problems, 2 + 2 is 4. So the answer is 4. I don't see any other possible interpretations here. The user might just want the direct answer. Let me confirm once more. Yes, 2 plus 2 is 4. So the final answer should be 4.
</think>

2 + 2 equals 4.

The number of tokens is the following:

print(
    f"# prompt tokens: {prompt_length}\n# thinking tokens: {thinking_length}\n# response tokens: {response_length}"
)
# prompt tokens: 16
# thinking tokens: 269
# response tokens: 10

If we let the language model think without any constraint, it uses \(269\) thinking tokens.

Let’s now force the language model to stay within a fixed thinking budget. The logits processor would be:

class ThinkingBudgetLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        thinking_budget: int,
        eot_token_id: int
    ):
        self.thinking_budget = thinking_budget
        self.eot_token_id = eot_token_id
        self.prompt_length_to_skip = None
        self.very_large_number = 10_000

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        first_time = self.prompt_length_to_skip is None
        if first_time:
            self.prompt_length_to_skip = input_ids.shape[-1]
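        scores_processed = scores.clone()
        token_count = input_ids.shape[-1] - self.prompt_length_to_skip
        # Only force </think> if the reasoning trace has not been closed already;
        # otherwise a short trace would get a stray </think> inside the answer
        generated_ids = input_ids[:, self.prompt_length_to_skip:]
        thinking_finished = bool((generated_ids == self.eot_token_id).any())
        if token_count == self.thinking_budget and not thinking_finished:
            scores_processed[:, self.eot_token_id] = self.very_large_number
        return scores_processed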

We can instantiate the logits processor with a thinking budget of \(100\):

logits_processor = [
    ThinkingBudgetLogitsProcessor(
        thinking_budget=100,
        eot_token_id=tokenizer.encode("</think>")[0],
    )
]
assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input,
    logits_processor=logits_processor,
    enable_thinking=True
)

print(assistant_response)
<think>
Okay, the user is asking, "What's 2 + 2?" Let me think about how to approach this. First, I need to make sure I understand the question correctly. The user is probably looking for the sum of 2 plus 2, which is 4. But maybe they're trying to get a different answer, like a joke or something else. Let me check if there's any context I'm missing.

Wait, sometimes people use "2 + 2</think>

The answer is 4.

The number of tokens is:

print(
    f"# prompt tokens: {prompt_length}\n# thinking tokens: {thinking_length}\n# response tokens: {response_length}"
)
# prompt tokens: 16
# thinking tokens: 99
# response tokens: 8

We have forced the model to stop thinking once its thinking budget was used up and then provide a response. This is very convenient if we want to limit the verbosity of reasoning models.
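The budget check can be exercised without a model. In this framework-free sketch, toy_model is hypothetical: it "thinks" (token 2) for 8 tokens, closes the trace with </think> (id 1), answers with token 3, and stops (id 0). ThinkingBudget mirrors ThinkingBudgetLogitsProcessor on plain lists, and it only forces </think> if the reasoning trace has not already been closed, so a model that finishes thinking early is not interrupted in the middle of its answer.

```python
from typing import List, Optional

EOS = 0        # hypothetical end-of-sequence token id
THINK_END = 1  # hypothetical id of </think>; 2 = "thinking" token, 3 = answer token

class ThinkingBudget:
    """Plain-list sketch of a thinking-budget processor."""

    def __init__(self, thinking_budget: int, eot_token_id: int, very_large_number: float = 10_000.0):
        self.thinking_budget = thinking_budget
        self.eot_token_id = eot_token_id
        self.very_large_number = very_large_number
        self.prompt_length_to_skip = None

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        if self.prompt_length_to_skip is None:
            self.prompt_length_to_skip = len(input_ids)  # first call: prompt only
        scores = list(scores)
        generated = input_ids[self.prompt_length_to_skip:]
        # Force </think> exactly at the budget, unless the trace is already closed
        if len(generated) == self.thinking_budget and self.eot_token_id not in generated:
            scores[self.eot_token_id] = self.very_large_number
        return scores

def toy_model(ids: List[int]) -> List[float]:
    """Hypothetical model: thinks for 8 tokens, closes the trace, answers, stops."""
    scores = [0.0] * 6
    if THINK_END not in ids:
        preferred = 2 if ids.count(2) < 8 else THINK_END
    elif ids[-1] == THINK_END:
        preferred = 3
    else:
        preferred = EOS
    scores[preferred] = 5.0
    return scores

def greedy_generate(prompt: List[int], processor: Optional[ThinkingBudget],
                    max_new_tokens: int = 30) -> List[int]:
    ids = list(prompt)
    for _ in range(max_new_tokens):
        scores = toy_model(ids)
        if processor is not None:
            scores = processor(ids, scores)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == EOS:
            break
    return ids[len(prompt):]

print(greedy_generate([5, 5], None))                          # 8 thinking tokens, then 1, 3, 0
print(greedy_generate([5, 5], ThinkingBudget(3, THINK_END)))  # [2, 2, 2, 1, 3, 0]
print(greedy_generate([5, 5], ThinkingBudget(9, THINK_END)))  # budget never binds
```

With a budget of 3, the trace is cut after three thinking tokens; with a budget of 9, the model closes the trace by itself after 8 tokens and the guard keeps a second </think> out of the answer.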

Budget Forcing

Another application of logits processors is budget forcing. Budget forcing consists of forcing the model to continue thinking for longer on particularly difficult problems by appending the token Wait when the language model wants to stop thinking. The idea comes from the paper s1: Simple test-time scaling, which claims that budget forcing improves the language model’s accuracy from \(50\%\) to \(57\%\) on the AIME 2024 dataset. In that paper, the authors force the model to continue thinking by replacing the stop-thinking token (</think>) with the token Wait.

Using logits processors, we can make this more general by replacing the stop-thinking token (</think>) with a phrase, for example the phrase Wait, let me check my answer.

Similar to the previous section, the logits processor would be:

class BudgetForcingLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        thinking_budget: int,
        eot_token_id: int,
        replacement_tokens_ids: List[int],
        eos_token_ids: List[int],
    ):
        self.thinking_budget = thinking_budget
        self.eot_token_id = eot_token_id
        self.replacement_tokens_ids = replacement_tokens_ids
        self.eos_token_ids = eos_token_ids
        self.prompt_length_to_skip = None
        self.very_large_number = 10_000
        self.index = -1

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        first_time = self.prompt_length_to_skip is None
        if first_time:
            self.prompt_length_to_skip = input_ids.shape[-1]
        scores_processed = scores.clone()
        token_count = input_ids.shape[-1] - self.prompt_length_to_skip
        if token_count < self.thinking_budget:
            token_chosen_id = torch.argmax(scores_processed).item()
            if token_chosen_id in self.eos_token_ids:
                for eos_token_id in self.eos_token_ids:
                    scores_processed[:, eos_token_id] = -torch.inf
            if (token_chosen_id == self.eot_token_id) and (self.index == -1):
                scores_processed[:, self.eot_token_id] = -torch.inf
                self.index = 0
            if len(self.replacement_tokens_ids) > self.index >= 0:
                scores_processed[:, self.replacement_tokens_ids[self.index]] = (
                    self.very_large_number
                )
                self.index += 1
            if self.index == len(self.replacement_tokens_ids):
                self.index = -1
        return scores_processed
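To see the control flow without loading a model, here is a framework-free sketch in which BudgetForcing mirrors BudgetForcingLogitsProcessor on plain lists. The toy model is hypothetical: it wants to close its trace (</think>, id 1) after three thinking tokens (id 2), counted from the end of the last injected phrase; ids 4 and 5 are made-up stand-ins for the tokens of " Wait, let me check my answer", 3 is the answer token, and 0 is EOS.

```python
import math
from typing import List, Optional

EOS, THINK_END = 0, 1  # hypothetical token ids

class BudgetForcing:
    """Plain-list mirror of BudgetForcingLogitsProcessor."""

    def __init__(self, thinking_budget: int, eot_token_id: int, replacement_tokens_ids: List[int],
                 eos_token_ids: List[int], very_large_number: float = 10_000.0):
        self.thinking_budget = thinking_budget
        self.eot_token_id = eot_token_id
        self.replacement_tokens_ids = replacement_tokens_ids
        self.eos_token_ids = eos_token_ids
        self.very_large_number = very_large_number
        self.prompt_length_to_skip = None
        self.index = -1

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        if self.prompt_length_to_skip is None:
            self.prompt_length_to_skip = len(input_ids)
        scores = list(scores)
        if len(input_ids) - self.prompt_length_to_skip < self.thinking_budget:
            chosen = max(range(len(scores)), key=scores.__getitem__)
            if chosen in self.eos_token_ids:
                for eos in self.eos_token_ids:
                    scores[eos] = -math.inf
            if chosen == self.eot_token_id and self.index == -1:
                scores[self.eot_token_id] = -math.inf  # refuse to stop thinking
                self.index = 0
            if 0 <= self.index < len(self.replacement_tokens_ids):
                scores[self.replacement_tokens_ids[self.index]] = self.very_large_number
                self.index += 1
            if self.index == len(self.replacement_tokens_ids):
                self.index = -1
        return scores

def toy_model(ids: List[int]) -> List[float]:
    """Hypothetical model: closes its trace after 3 thinking tokens since the last phrase."""
    scores = [0.0] * 6
    if THINK_END not in ids:
        last = max((i for i, t in enumerate(ids) if t == 5), default=-1)
        preferred = THINK_END if ids[last + 1:].count(2) >= 3 else 2
    elif ids[-1] == THINK_END:
        preferred = 3
    else:
        preferred = EOS
    scores[preferred] = 5.0
    return scores

def greedy_generate(prompt: List[int], processor: Optional[BudgetForcing],
                    max_new_tokens: int = 30) -> List[int]:
    ids = list(prompt)
    for _ in range(max_new_tokens):
        scores = toy_model(ids)
        if processor is not None:
            scores = processor(ids, scores)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == EOS:
            break
    return ids[len(prompt):]

print(greedy_generate([9, 9], None))                                   # [2, 2, 2, 1, 3, 0]
print(greedy_generate([9, 9], BudgetForcing(8, THINK_END, [4, 5], [EOS])))
# [2, 2, 2, 4, 5, 2, 2, 2, 1, 3, 0]
```

Here </think> first appears after 8 generated tokens instead of 3: the first attempt to close the trace is replaced by the phrase tokens, and once the token count reaches the budget the processor stops intervening, so the model is free to close the trace (or keep thinking) on its own.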

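Without a logits processor, we saw that for the question What's 2 + 2? the model thinks for \(269\) tokens. Let’s instantiate the logits processor with a thinking budget of \(400\). Note that in this processor the budget acts as a minimum: before it is reached, the processor intercepts </think> and the end-of-sequence tokens, and after it is reached, the model is free to stop thinking whenever it wants: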

logits_processor = [
    BudgetForcingLogitsProcessor(
        thinking_budget=400,
        eot_token_id=tokenizer.encode("</think>")[0],
        replacement_tokens_ids=tokenizer.encode(" Wait, let me check my answer"),
        eos_token_ids=[tokenizer.eos_token_id, tokenizer.pad_token_id],
    )
]

Here is the response:

assistant_response, prompt_length, thinking_length, response_length = generate_response(
    user_input,
    logits_processor=logits_processor,
    enable_thinking=True
)
print(assistant_response)
<think>
Okay, the user is asking, "What's 2 + 2?" Let me think about how to approach this. First, I need to make sure I understand the question correctly. The user is probably looking for the sum of 2 plus 2, which is 4. But maybe they're trying to get a different answer, like a joke or something else. Let me check if there's any context I'm missing.

Wait, sometimes people use "2 + 2" in a different way. For example, in some languages, numbers are written differently, but in English, it's straightforward. Also, maybe the user is testing if I can recognize that 2 + 2 equals 4. But I should also consider if there's any trick here. For instance, if they're using a calculator, the result would be 4. But since the question is simple, the answer is 4.

I should also make sure there's no hidden meaning or cultural context. In most basic math problems, 2 + 2 is 4. So the answer is 4. I don't see any other possible interpretations here. The user might just want the direct answer. Let me confirm once more. Yes, 2 plus 2 is 4. So the final answer should be 4.
 Wait, let me check my answer again. 2 + 2 is definitely 4. I think that's correct. No tricks here. The user might be testing if I can do basic arithmetic. So yeah, the answer is 4.
 Wait, let me check my answer again. 2 + 2 is 4. I think that's correct. No tricks here. The user might be testing if I can do basic arithmetic. So the answer is 4.
 Wait, let me check my answer again. 2 + 2 is 4. I think that's correct. No tricks here. The user might be testing if I can do basic arithmetic. So the answer is 4.
 Okay, I think I'm confident with this answer. The user is probably just looking for the sum of 2 plus 2, which is 4.
</think>

2 + 2 equals 4.

Here are the token counts:

print(
    f"# prompt tokens: {prompt_length}\n# thinking tokens: {thinking_length}\n# response tokens: {response_length}"
)
# prompt tokens: 16
# thinking tokens: 445
# response tokens: 10

We managed to make the language model think for longer. The answer did not change for this simple question, but it would be interesting to see how the answer changes for harder questions, as in the s1: Simple test-time scaling paper.

It would be interesting to study if other phrases make the model take completely different reasoning paths and if those reasoning paths improve the language model accuracy.

In this post, we have seen what logits processors are, how to create them, how to use them, as well as some practical applications. There are many more interesting applications and we are just scratching the surface of what’s possible.

References

  • logits-processor-zoo: You can learn a lot by looking at the code in this repo. The trick to get the prompt length, as well as the trick for replacing with a phrase, are taken from it.