
Reference

build_regex

Convert a Pydantic model or JSON schema to a regex.

Examples:

>>> from typing import Literal
>>> from pydantic import BaseModel, Field
>>> from litelines import build_regex
>>>
>>> class Sentiment(BaseModel):
...     "Correctly inferred `Sentiment` with all the required parameters with correct types."
...
...     label: Literal["positive", "negative"] = Field(
...         ..., description="Sentiment of the text"
...     )
>>> build_regex(Sentiment, whitespace_pattern="")
'\\{"label":("positive"|"negative")\\}'
>>> build_regex(Sentiment, whitespace_pattern="[ ]?")
'[ ]?\\{[ ]?"label"[ ]?:[ ]?("positive"|"negative")[ ]?\\}'
>>> build_regex(Sentiment)
'[\\n\\t ]*\\{[\\n\\t ]*"label"[\\n\\t ]*:[\\n\\t ]*("positive"|"negative")[\\n\\t ]*\\}'
>>> build_regex(Sentiment, include_tool_call=True, whitespace_pattern="")
'<tool_call>\\{"name":"Sentiment","arguments":\\{"label":("positive"|"negative")\\}\\}</tool_call>'

Parameters:

  schema : Union[dict, str, Type[Any]], required
      The Pydantic model or JSON schema.
  include_tool_call : bool, optional (default: False)
      Whether the model is expected to emit a tool call.
  tool_call_start : str, optional (default: '<tool_call>')
      The expected tool call start.
  tool_call_end : str, optional (default: '</tool_call>')
      The expected tool call end.
  whitespace_pattern : str, optional (default: '[\\n\\t ]*')
      Pattern to use for JSON syntactic whitespace.

Returns:

  str
      The JSON schema converted to a regex.

Raises:

  ValueError
      Raised if the schema is not a Pydantic model, a dictionary, or a string containing a JSON schema.
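
Because the return value is an ordinary regular-expression string, it can be checked directly with Python's re module. A minimal sketch, assuming the Sentiment model defined in the examples above:

    import re

    from litelines import build_regex

    # Pattern generated for the Sentiment model (see the examples above)
    pattern = build_regex(Sentiment, whitespace_pattern="")

    # Only the two allowed labels match the generated pattern
    assert re.fullmatch(pattern, '{"label":"positive"}') is not None
    assert re.fullmatch(pattern, '{"label":"neutral"}') is None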

Source code in src/litelines/build_regex.py
def build_regex(
    schema: Union[dict, str, Type[Any]],
    include_tool_call: bool = False,
    tool_call_start: str = "<tool_call>",
    tool_call_end: str = "</tool_call>",
    whitespace_pattern: str = r"[\n\t ]*",
) -> str:
    """Convert a Pydantic model or JSON schema to a regex.

    Examples:
        >>> from typing import Literal
        >>> from pydantic import BaseModel, Field
        >>> from litelines import build_regex
        >>>
        >>> class Sentiment(BaseModel):
        ...     "Correctly inferred `Sentiment` with all the required parameters with correct types."
        ...
        ...     label: Literal["positive", "negative"] = Field(
        ...         ..., description="Sentiment of the text"
        ...     )
        >>> build_regex(Sentiment, whitespace_pattern="")
        '\\\\{"label":("positive"|"negative")\\\\}'
        >>> build_regex(Sentiment, whitespace_pattern="[ ]?")
        '[ ]?\\\\{[ ]?"label"[ ]?:[ ]?("positive"|"negative")[ ]?\\\\}'
        >>> build_regex(Sentiment)
        '[\\\\n\\\\t ]*\\\\{[\\\\n\\\\t ]*"label"[\\\\n\\\\t ]*:[\\\\n\\\\t ]*("positive"|"negative")[\\\\n\\\\t ]*\\\\}'
        >>> build_regex(Sentiment, include_tool_call=True, whitespace_pattern="")
        '<tool_call>\\\\{"name":"Sentiment","arguments":\\\\{"label":("positive"|"negative")\\\\}\\\\}</tool_call>'

    Args:
        schema: The Pydantic model or JSON schema.
        include_tool_call (optional): Is the Pydantic model expecting a tool call or not.
        tool_call_start (optional): The expected tool call start.
        tool_call_end (optional): The expected tool call end.
        whitespace_pattern (optional): Pattern to use for JSON syntactic whitespace.

    Returns:
        The JSON schema converted to a regex.

    Raises:
        ValueError: An error occurs if the schema is not a Pydantic model, a dictionary, or a string.
    """
    if isinstance(schema, dict):
        schema_str = json.dumps(schema)
        name_str = schema["title"]
    elif isinstance(schema, str):
        schema_str = schema
        name_str = json.loads(schema)["title"]
    elif hasattr(schema, "model_json_schema"):
        schema_str = json.dumps(schema.model_json_schema())
        name_str = schema.__name__
    else:
        raise ValueError(
            f"Cannot parse schema {schema}. The schema must be either "
            + "a Pydantic class, a dictionary or a string that contains the JSON "
            + "schema specification"
        )
    _regex_str = build_regex_from_schema(
        schema_str, whitespace_pattern=whitespace_pattern
    )
    if include_tool_call:
        regex_str = (
            whitespace_pattern
            + tool_call_start
            + whitespace_pattern
            + "\\{"
            + whitespace_pattern
            + '"name"'
            + whitespace_pattern
            + ":"
            + whitespace_pattern
            + '"'
            + name_str
            + '"'
            + whitespace_pattern
            + ","
            + whitespace_pattern
            + '"arguments"'
            + whitespace_pattern
            + ":"
            + whitespace_pattern
            + _regex_str
            + whitespace_pattern
            + "\\}"
            + whitespace_pattern
            + tool_call_end
        )
    else:
        regex_str = whitespace_pattern + _regex_str
    return regex_str

build_dfa

Build a deterministic finite automaton that fulfills the response format requirement.

Examples:

>>> from typing import Literal
>>> from pydantic import BaseModel, Field
>>> from transformers import AutoTokenizer
>>> from litelines import build_dfa
>>>
>>> model_id = "Qwen/Qwen3-0.6B"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> build_dfa("A|B", tokenizer)
{0: {33: 1, 32: 1}}
>>> build_dfa("A0|B0", tokenizer)
{1: {15: 3}, 2: {15: 3}, 0: {33: 1, 32: 2}}
>>>
>>> class Sentiment(BaseModel):
...     "Correctly inferred `Sentiment` with all the required parameters with correct types."
...
...     label: Literal["positive", "negative"] = Field(
...         ..., description="Sentiment of the text"
...     )
>>> build_dfa(Sentiment, tokenizer, whitespace_pattern="")
{18: {72: 15, 344: 17, 533: 16}, 9: {92: 28}, 20: {72: 21, 12303: 7, 275: 6, 3404: 8}, 23: {2974: 5, 25: 24}, 1: {14380: 2, 75: 25, 4260: 26, 1502: 4}, 14: {10251: 15, 83: 18}, 8: {9207: 28, 1: 9}, 22: {82: 20, 6321: 21, 46865: 6}, 4: {3252: 5, 1: 23, 788: 24}, 0: {4913: 1, 90: 27}, 13: {64: 14, 19488: 17, 266: 18, 1388: 16, 9307: 15}, 10: {68: 8}, 19: {436: 20, 78: 22, 34054: 6, 30724: 21}, 3: {75: 4}, 16: {9207: 28, 1: 9}, 12: {70: 13, 6743: 14}, 7: {586: 8, 85: 10}, 11: {68: 12, 15060: 16, 11188: 14, 791: 13}, 2: {68: 3, 301: 4}, 17: {68: 16}, 27: {92667: 4, 1: 1}, 6: {72: 7, 344: 10, 533: 8}, 5: {2724: 6, 77: 11, 28775: 13, 42224: 16, 79: 19, 5368: 22, 30487: 8, 966: 20, 811: 12}, 26: {1371: 3, 65: 2, 9779: 4}, 15: {586: 16, 85: 17}, 21: {10251: 7, 83: 6}, 24: {1: 5}, 25: {370: 2, 64: 26, 780: 4, 8229: 3}}

Parameters:

  response_format : Union[dict, str, Type[Any]], required
      A Pydantic model, a dictionary, or a regular expression (as a string) that defines the expected response format.
  tokenizer : Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast], required
      The model's tokenizer or the model name (as a string).
  include_tool_call : bool, optional (default: False)
      Whether the model is expected to emit a tool call.
  tool_call_start : str, optional (default: '<tool_call>')
      The expected tool call start.
  tool_call_end : str, optional (default: '</tool_call>')
      The expected tool call end.
  whitespace_pattern : str, optional (default: '[\\n\\t\\r ]*')
      Pattern to use for JSON syntactic whitespace.

Returns:

  dict[int, dict[int, int]]
      The deterministic finite automaton as a dictionary.

Raises:

  ValueError
      Raised if the response format is not a Pydantic model, a dictionary, or a string corresponding to a JSON schema or regular expression.
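
The returned dictionary maps each non-final state to its outgoing transitions (token ID to next state); states that never appear as keys have no outgoing edges and are accepting. A minimal sketch that walks the "A0|B0" automaton from the examples above, using token IDs taken from that example output:

    dfa = build_dfa("A0|B0", tokenizer)

    state = 0
    for token_id in (33, 15):  # token IDs taken from the example output above
        state = dfa[state][token_id]

    # States without outgoing transitions are accepting
    is_accepting = state not in dfa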

Source code in src/litelines/build_dfa.py
def build_dfa(
    response_format: Union[dict, str, Type[Any]],
    tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast],
    include_tool_call: bool = False,
    tool_call_start: str = "<tool_call>",
    tool_call_end: str = "</tool_call>",
    whitespace_pattern: str = r"[\n\t\r ]*",
) -> dict[int, dict[int, int]]:
    """Build a deterministic finite automaton that fullfils the response format requirement

    Examples:
        >>> from typing import Literal
        >>> from pydantic import BaseModel, Field
        >>> from transformers import AutoTokenizer
        >>> from litelines import build_dfa
        >>>
        >>> model_id = "Qwen/Qwen3-0.6B"
        >>> tokenizer = AutoTokenizer.from_pretrained(model_id)
        >>> build_dfa("A|B", tokenizer)
        {0: {33: 1, 32: 1}}
        >>> build_dfa("A0|B0", tokenizer)
        {1: {15: 3}, 2: {15: 3}, 0: {33: 1, 32: 2}}
        >>>
        >>> class Sentiment(BaseModel):
        ...     "Correctly inferred `Sentiment` with all the required parameters with correct types."
        ...
        ...     label: Literal["positive", "negative"] = Field(
        ...         ..., description="Sentiment of the text"
        ...     )
        >>> build_dfa(Sentiment, tokenizer, whitespace_pattern="")
        {18: {72: 15, 344: 17, 533: 16}, 9: {92: 28}, 20: {72: 21, 12303: 7, 275: 6, 3404: 8}, 23: {2974: 5, 25: 24}, 1: {14380: 2, 75: 25, 4260: 26, 1502: 4}, 14: {10251: 15, 83: 18}, 8: {9207: 28, 1: 9}, 22: {82: 20, 6321: 21, 46865: 6}, 4: {3252: 5, 1: 23, 788: 24}, 0: {4913: 1, 90: 27}, 13: {64: 14, 19488: 17, 266: 18, 1388: 16, 9307: 15}, 10: {68: 8}, 19: {436: 20, 78: 22, 34054: 6, 30724: 21}, 3: {75: 4}, 16: {9207: 28, 1: 9}, 12: {70: 13, 6743: 14}, 7: {586: 8, 85: 10}, 11: {68: 12, 15060: 16, 11188: 14, 791: 13}, 2: {68: 3, 301: 4}, 17: {68: 16}, 27: {92667: 4, 1: 1}, 6: {72: 7, 344: 10, 533: 8}, 5: {2724: 6, 77: 11, 28775: 13, 42224: 16, 79: 19, 5368: 22, 30487: 8, 966: 20, 811: 12}, 26: {1371: 3, 65: 2, 9779: 4}, 15: {586: 16, 85: 17}, 21: {10251: 7, 83: 6}, 24: {1: 5}, 25: {370: 2, 64: 26, 780: 4, 8229: 3}}

    Args:
        response_format: A Pydantic model, a dictionary, or a regular expression (as a string) that defines the expected response format
        tokenizer: The model's tokenizer or the model name (as a string)
        include_tool_call (optional): Is the Pydantic model expecting a tool call or not.
        tool_call_start (optional): The expected tool call start.
        tool_call_end (optional): The expected tool call end.
        whitespace_pattern (optional): Pattern to use for JSON syntactic whitespace.

    Returns:
        The deterministic finite automaton as a dictionary.

    Raises:
        ValueError: An error occurs if the response format is not a Pydantic model, a dictionary, or a string that corresponds to the regular expression.
    """
    if isinstance(response_format, str):
        if is_valid_json(response_format):
            regex_str = build_regex(
                response_format,
                include_tool_call=include_tool_call,
                whitespace_pattern=whitespace_pattern,
            )
        elif is_valid_regex(response_format):
            regex_str = response_format
        else:
            invalid_schema_error(response_format)
    elif isinstance(response_format, dict) or hasattr(
        response_format, "model_json_schema"
    ):
        regex_str = build_regex(
            response_format,
            include_tool_call=include_tool_call,
            tool_call_start=tool_call_start,
            tool_call_end=tool_call_end,
            whitespace_pattern=whitespace_pattern,
        )
    else:
        invalid_schema_error(response_format)

    if isinstance(tokenizer, str):
        model_name = tokenizer
    elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        model_name = getattr(tokenizer, "name_or_path", None)
        if model_name is None:
            raise ValueError(
                "Could not determine model name from tokenizer. "
                + "You can pass it directly to the build_dfa function."
            )
    else:
        raise ValueError(
            "The tokenizer must be either "
            + "a PreTrainedTokenizer, a PreTrainedTokenizerFast "
            + "or a string that corresponds to the model name."
        )

    vocabulary = Vocabulary.from_pretrained(model_name)
    index = Index(regex_str, vocabulary)
    dfa = get_dfa(index)
    return dfa

draw_dfa

Create a graphical representation of a Deterministic Finite Automaton (DFA) using Graphviz DOT language.

The function visualizes the DFA with:

  • states as circles (double circles for final states)
  • directed edges showing transitions between states
  • edge labels containing tables of token IDs and their corresponding text
  • optional red highlighting for edges in the provided trajectory

Examples:

>>> from typing import Literal
>>> from pydantic import BaseModel, Field
>>> from transformers import AutoTokenizer
>>> from litelines import draw_dfa
>>>
>>> model_id = "Qwen/Qwen3-0.6B"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> draw_dfa("A|B", tokenizer, render=False)
#
>>> draw_dfa("A0|B0", tokenizer, render=False)
#
>>>
>>> class Sentiment(BaseModel):
...     "Correctly inferred `Sentiment` with all the required parameters with correct types."
...
...     label: Literal["positive", "negative"] = Field(
...         ..., description="Sentiment of the text"
...     )
>>> draw_dfa(Sentiment, tokenizer, whitespace_pattern="")
#

Parameters:

  dfa : Union[dict[int, dict[int, int]], str, Type[Any]], required
      The DFA representation, which can be one of: a dictionary mapping states to their transitions, a JSON schema string, or a Pydantic schema.
  tokenizer : Union[PreTrainedTokenizer, PreTrainedTokenizerFast], required
      The tokenizer used to decode token IDs into readable text.
  trajectory : list, optional (default: [])
      Optional list of tokens representing a path through the DFA.
  include_tool_call : bool, optional (default: False)
      Whether the model is expected to emit a tool call.
  tool_call_start : str, optional (default: '<tool_call>')
      The expected tool call start.
  tool_call_end : str, optional (default: '</tool_call>')
      The expected tool call end.
  whitespace_pattern : str, optional (default: '[\\n\\t ]*')
      Pattern to use for JSON syntactic whitespace.
  max_labels_per_edge : int, optional (default: 3)
      Maximum number of labels to show per edge.
  remove_outer_whitespace : bool, optional (default: True)
      Whether to strip whitespace from token labels in the table.
  render : bool, optional (default: True)
      Whether to return a rendered Graphviz Source object or the raw DOT string.

Returns:

  str | None
      A Graphviz Source object if render=True, otherwise the DOT language string.
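
With render=False the function returns the raw DOT string, which can be saved or rendered separately. A minimal sketch, assuming the graphviz Python package is installed (its Source class is only needed for the optional rendering step):

    # Raw DOT output; no rendering backend is required for this call
    dot = draw_dfa("A|B", tokenizer, render=False)

    with open("dfa.dot", "w") as f:
        f.write(dot)

    # Optional: render to SVG with the graphviz package
    from graphviz import Source
    Source(dot).render("dfa", format="svg")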

Source code in src/litelines/draw_dfa.py
def draw_dfa(
    dfa: Union[dict[int, dict[int, int]], str, Type[Any]],
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    trajectory: list = [],
    include_tool_call: bool = False,
    tool_call_start: str = "<tool_call>",
    tool_call_end: str = "</tool_call>",
    whitespace_pattern: str = r"[\n\t ]*",
    max_labels_per_edge: int = 3,
    remove_outer_whitespace: bool = True,
    ratio: Optional[Union[float, str]] = None,
    size: Optional[Union[Tuple[float, float], str]] = None,
    render: bool = True,
) -> str | None:
    """Create a graphical representation of a Deterministic Finite Automaton (DFA) using Graphviz DOT language.

    The function visualizes the DFA with:

    - states as circles (double circles for final states)
    - directed edges showing transitions between states
    - edge labels containing tables of token IDs and their corresponding text
    - optional red highlighting for edges in the provided trajectory

    Examples:
        >>> from typing import Literal
        >>> from pydantic import BaseModel, Field
        >>> from transformers import AutoTokenizer
        >>> from litelines import draw_dfa
        >>>
        >>> model_id = "Qwen/Qwen3-0.6B"
        >>> tokenizer = AutoTokenizer.from_pretrained(model_id)
        >>> draw_dfa("A|B", tokenizer, render=False)
        #
        >>> draw_dfa("A0|B0", tokenizer, render=False)
        #
        >>>
        >>> class Sentiment(BaseModel):
        ...     "Correctly inferred `Sentiment` with all the required parameters with correct types."
        ...
        ...     label: Literal["positive", "negative"] = Field(
        ...         ..., description="Sentiment of the text"
        ...     )
        >>> draw_dfa(Sentiment, tokenizer, whitespace_pattern="")
        #

    Args:
        dfa: The DFA representation, which can be either:
            A dictionary mapping states to their transitions
            A JSON schema string
            A Pydantic schema
        tokenizer: The tokenizer used to decode token IDs into readable text
        trajectory: Optional list of tokens representing a path through the DFA
        include_tool_call (optional): Is the Pydantic model expecting a tool call or not.
        tool_call_start (optional): The expected tool call start.
        tool_call_end (optional): The expected tool call end.
        whitespace_pattern (optional): Pattern to use for JSON syntactic whitespace.
        max_labels_per_edge (optional): Maximum number of labels to show per edge
        remove_outer_whitespace (optional): Whether to strip whitespace from token labels in the table.
        render (optional): Whether to return a rendered Graphviz Source object or raw DOT string

    Returns:
        A Graphviz Source object if render=True, otherwise the DOT language string
    """

    if isinstance(dfa, dict) and all(
        isinstance(k, int)
        and isinstance(v, dict)
        and all(isinstance(k2, int) and isinstance(v2, int) for k2, v2 in v.items())
        for k, v in dfa.items()
    ):
        regex = ""
        dfa = dfa
    elif isinstance(dfa, str):
        if is_valid_json(dfa):
            regex = build_regex(
                dfa,
                include_tool_call=include_tool_call,
                whitespace_pattern=whitespace_pattern,
            )
            dfa = build_dfa(
                dfa,
                tokenizer=tokenizer,
                include_tool_call=include_tool_call,
                whitespace_pattern=whitespace_pattern,
            )
        elif is_valid_regex(dfa):
            regex = dfa
            dfa = build_dfa(
                dfa,
                tokenizer=tokenizer,
                include_tool_call=include_tool_call,
                whitespace_pattern=whitespace_pattern,
            )
        else:
            invalid_schema_error(dfa)
    elif hasattr(dfa, "model_json_schema"):
        regex = build_regex(
            dfa,
            include_tool_call=include_tool_call,
            whitespace_pattern=whitespace_pattern,
        )
        dfa = build_dfa(
            dfa,
            tokenizer=tokenizer,
            include_tool_call=include_tool_call,
            whitespace_pattern=whitespace_pattern,
        )
    else:
        invalid_schema_error(dfa)

    if trajectory != []:
        state_trajectory = from_token_trajectory_to_state_trajectory(trajectory, dfa)

    states = range(len(dfa) + 1)
    final_states = {state for state in states if state not in list(dfa.keys())}
    graph_str = "// Allowed Transitions Graph\ndigraph {"
    if regex != "":
        graph_str += f'\n\tgraph [label="Allowed Paths\nRegular expression: {build_escaped_title(regex)}",labelloc="t",labeljust="l"]'
    else:
        graph_str += '\n\tgraph [label="Allowed Paths",labelloc="t",labeljust="l"]'
    graph_str += f'\n\trankdir=LR;size="{size}";ratio={ratio};'
    # Add states to the graph
    for state in states:
        if state in final_states:
            # Shape the final states with double circle
            graph_str += f'\n\t{state} [label="{state}" shape=doublecircle]'
        else:
            # Shape the other states with a circle
            graph_str += f'\n\t{state} [label="{state}" shape=circle]'
    # Add empty fake node for initial arrow
    graph_str += '\n\tnode [shape=none]\n\t"" [label=""]\n\t"" -> 0'
    # Put together all edges from state to next_state to the graph
    all_edges = defaultdict(list)
    for state, transitions in dfa.items():
        for key, next_state in transitions.items():
            all_edges[(state, next_state)].append(key)
    # Add edges to the graph
    for state in states:
        for next_state in states:
            if all_edges[(state, next_state)] != []:
                table_str = create_table(
                    all_edges[(state, next_state)],
                    tokenizer,
                    max_labels_per_edge=max_labels_per_edge,
                    remove_outer_whitespace=remove_outer_whitespace,
                )
                if (
                    trajectory != []
                    and state_trajectory != {}
                    and state in state_trajectory.keys()
                    and next_state in state_trajectory[state]
                ):
                    graph_str += f"\n\t{state} -> {next_state} [label=<{table_str}> color=red penwidth=3.0]"
                else:
                    graph_str += f"\n\t{state} -> {next_state} [label=<{table_str}>]"
    graph_str += "\n}\n"
    return display_dot_graph(dot=graph_str, render=render)

SchemaProcessor

Bases: LogitsProcessor

Build the Logits Processor that enforces the response format

Examples:

Parameters:

  input_ids : torch.LongTensor, required
      Token IDs of shape (batch_size, sequence_length).
  scores : torch.FloatTensor, required
      Logits of shape (batch_size, vocab_size).

Returns:

  torch.FloatTensor
      The processed logits, with tokens that would violate the response format masked to -inf.
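
A minimal usage sketch with the transformers generate API. The import path follows the source location shown below (litelines.transformers) and may need adjusting; the prompt, model, and generation arguments are illustrative, and Sentiment is the Pydantic model from the earlier examples:

    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

    from litelines.transformers import SchemaProcessor

    model_id = "Qwen/Qwen3-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # Constrain generation to valid Sentiment JSON
    processor = SchemaProcessor(Sentiment, tokenizer, whitespace_pattern="")

    inputs = tokenizer("Classify the sentiment: I love this!", return_tensors="pt")
    output = model.generate(
        **inputs,
        logits_processor=LogitsProcessorList([processor]),
        max_new_tokens=32,
    )
    print(tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:]))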

Source code in src/litelines/transformers/schemaprocessor.py
class SchemaProcessor(LogitsProcessor):
    """Build the Logits Processor that enforces the response format

    Examples:

    Args:
        input_ids: Token IDs of shape (batch_size, sequence_length)
        scores: Logits of shape (batch_size, vocab_size)

    Returns:
        The logits processor that enforces the response format
    """

    def __init__(
        self,
        response_format: Union[str, dict[int, dict[int, int]], Type[Any]],
        tokenizer: PreTrainedTokenizer,
        include_tool_call: bool = False,
        tool_call_start: str = "<tool_call>",
        tool_call_end: str = "</tool_call>",
        allow_preamble: bool = False,
        whitespace_pattern: str = r"[\n\t\r ]*",
        verbose: bool = False,
        max_same_state_visit_count: int = 5,
    ) -> None:
        self.response_format = response_format
        self.tokenizer = tokenizer
        self.include_tool_call = include_tool_call
        self.tool_call_start = tool_call_start
        self.tool_call_end = tool_call_end
        self.allow_preamble = allow_preamble
        self.whitespace_pattern = whitespace_pattern
        self.dfa = None
        self.verbose = verbose
        self.max_same_state_visit_count = max_same_state_visit_count
        self.same_state_visit_count = 0
        self.current_state = None
        self.previous_state = None
        self.final_states = None
        self.selected_token = None
        self.trajectory = []
        self.previous_input_ids = None
        self.trigger_token_ids = []
        self.triggered = None

    def __build_dfa(self):
        self.dfa = build_dfa(
            self.response_format,
            self.tokenizer,
            include_tool_call=self.include_tool_call,
            tool_call_start=self.tool_call_start,
            tool_call_end=self.tool_call_end,
            whitespace_pattern=self.whitespace_pattern,
        )

    def __create_dfa(self):
        if isinstance(self.response_format, dict) and all(
            isinstance(k, int)
            and isinstance(v, dict)
            and all(isinstance(k2, int) and isinstance(v2, int) for k2, v2 in v.items())
            for k, v in (self.response_format).items()
        ):
            self.dfa = self.response_format
        elif isinstance(self.response_format, str):
            self.__build_dfa()
        elif hasattr(self.response_format, "model_json_schema"):
            self.__build_dfa()
        else:
            raise ValueError(
                f"Cannot parse schema {self.response_format}. The schema must be either "
                + "a Pydantic model, a dict[int, dict[int, int]] or a string that contains the JSON "
                + "schema specification"
            )

    def show_graph(self):
        if self.trajectory == []:  # first time
            self.__create_dfa()
        return draw_dfa(
            self.response_format,
            self.tokenizer,
            self.trajectory,
            self.include_tool_call,
            self.tool_call_start,
            self.tool_call_end,
            self.whitespace_pattern,
        )

    def reset_state(self):
        """Reset the processor to its initial state"""
        self.current_state = 0
        self.final_states = None
        self.selected_token = None
        self.trajectory = []

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        scores_processed = scores.clone()
        token_chosen_id = torch.argmax(scores_processed).item()

        if self.previous_input_ids is not None:
            # Check if we're continuing from the previous sequence
            if not torch.equal(input_ids[:, :-1], self.previous_input_ids):
                # If the history doesn't match, reset the state
                self.reset_state()

        if self.final_states is None:  # first time
            if self.dfa is None:
                self.__create_dfa()
            states = range(len(self.dfa) + 1)
            self.final_states = {
                state for state in states if state not in list((self.dfa).keys())
            }
            if self.verbose:
                print(f"states: {states}")
                print(f"final states: {self.final_states}")
            self.previous_input_ids = input_ids.clone()
            if not self.allow_preamble:
                self.current_state = 0  # dfa active
                self.triggered = True
            else:
                self.current_state = -1  # inactive
                self.triggered = False
                # add eos to triggers
                self.trigger_token_ids = [
                    self.tokenizer.eos_token_id,
                    self.tokenizer.pad_token_id,
                ]
                if self.include_tool_call:  # it should be a tool call
                    if (
                        len(
                            self.tokenizer.encode(
                                self.tool_call_start, add_special_tokens=False
                            )
                        )
                        > 1
                    ):
                        raise ValueError(
                            f"{self.tool_call_start} is not a valid token."
                        )
                    self.trigger_token_ids.append(
                        self.tokenizer.encode(
                            self.tool_call_start, add_special_tokens=False
                        )[0]
                    )
                    # not the best solution since it excludes '<' in the preamble
                    tokens_containing_open_tool_call = [
                        token_id
                        for token_id in range(self.tokenizer.vocab_size)
                        if "<" in self.tokenizer.decode(token_id)
                    ]
                    self.trigger_token_ids += tokens_containing_open_tool_call
                else:  # it should be json
                    tokens_containing_open_curly_bracket = [
                        token_id
                        for token_id in range(self.tokenizer.vocab_size)
                        if "{" in self.tokenizer.decode(token_id)
                    ]
                    self.trigger_token_ids += tokens_containing_open_curly_bracket

        else:  # not the first time
            self.selected_token = input_ids[:, -1].item()
            if self.current_state != -1:
                self.trajectory.append(self.selected_token)
            if self.verbose:
                print(
                    f"\x1b[32mselected token: {self.selected_token}: {repr(self.tokenizer.decode([self.selected_token]))}\x1b[0m"
                )
            if self.verbose and self.current_state != -1:
                print(f"mapping: {self.dfa[self.current_state]}")

        # activate it if it is triggered
        if (
            self.current_state == -1 and token_chosen_id in self.trigger_token_ids
        ):  # if dfa is inactive
            if self.verbose:
                print(
                    f"\x1b[31mtrigger token: {token_chosen_id}: {self.tokenizer.decode([token_chosen_id])}\x1b[0m"
                )
            self.triggered = True
            self.current_state = 0

        if self.current_state != -1:  # if dfa is active
            if self.triggered:
                self.current_state = 0
                self.triggered = False
            else:
                self.previous_state = self.current_state
                self.current_state = self.dfa[self.current_state][self.selected_token]
                if (
                    self.previous_state == self.current_state
                    and re.fullmatch(
                        self.whitespace_pattern,
                        self.tokenizer.decode([self.selected_token]),
                    )
                    is not None
                ):
                    self.same_state_visit_count += 1
                else:
                    self.same_state_visit_count = 0

        self.previous_input_ids = input_ids.clone()

        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)

        if self.verbose:
            print(f"\x1b[34mcurrent state: {self.current_state}\x1b[0m")
            print(
                f"\x1b[33msame state visit count: {self.same_state_visit_count}\x1b[0m"
            )

        if self.current_state == -1:
            # print(self.trigger_token_ids)
            forbidden_tokens = torch.tensor(
                self.trigger_token_ids, device=scores.device
            )
            forbidden_tokens_mask = torch.isin(vocab_tensor, forbidden_tokens)
        else:  # if dfa is active
            if self.current_state in self.final_states:
                allowed_tokens = [self.tokenizer.eos_token_id]
            else:
                if self.same_state_visit_count < self.max_same_state_visit_count:
                    allowed_tokens = list(self.dfa[self.current_state].keys())
                else:
                    # Remove tokens that send you to the same current state
                    if self.verbose:
                        print(
                            f"\x1b[31mmaximum same state visit count reached for state {self.current_state}\x1b[0m"
                        )
                    mapping = self.dfa[self.current_state]
                    allowed_tokens = [
                        key
                        for key, value in mapping.items()
                        if value != self.current_state
                    ]
            allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
            forbidden_tokens_mask = ~torch.isin(vocab_tensor, allowed_tokens)

        scores_processed = torch.where(forbidden_tokens_mask, -torch.inf, scores)
        if self.verbose:
            print(
                f"\x1b[35mwill be chosen: {torch.argmax(scores_processed).item()}\x1b[0m"
            )

        return scores_processed

reset_state()

Reset the processor to its initial state

Source code in src/litelines/transformers/schemaprocessor.py
def reset_state(self):
    """Reset the processor to its initial state"""
    self.current_state = 0
    self.final_states = None
    self.selected_token = None
    self.trajectory = []