Streaming Output for Model Inference in Transformers

This article will introduce how to implement streaming output for model inference in the transformers module.

The transformers library provides built-in streamer classes (such as TextStreamer and TextIteratorStreamer) for streaming output during model inference. Additionally, model deployment frameworks such as vLLM and TGI offer more robust streaming support, typically through OpenAI-compatible APIs (see the sketch below).
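
For reference, here is a minimal sketch of consuming a streaming response from a vLLM server through its OpenAI-compatible API. The server address, port, and launch command are assumptions; adjust them to your own deployment.

from openai import OpenAI

# Assumption: a vLLM OpenAI-compatible server was started separately, e.g.
#   python -m vllm.entrypoints.openai.api_server --model ./models/Qwen1.5-7B-Chat --port 8000
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="./models/Qwen1.5-7B-Chat",
    messages=[{"role": "user", "content": "How many subway lines are there in Shenyang?"}],
    stream=True,
)
for chunk in stream:
    # Each chunk carries an incremental piece of the answer in choices[0].delta.content.
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)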

Below, we will detail how to achieve streaming output for model inference in the transformers module.

Streaming Output

  • Use the built-in TextStreamer from the transformers module to perform streaming output in the terminal.

Example code:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
message = [{"role": "user", "content": "How many subway lines are there in Shenyang?"}]
# Build the chat-formatted prompt string from the message list.
conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
print(conversion)
encoding = tokenizer(conversion, return_tensors="pt")
# TextStreamer prints decoded tokens to stdout as soon as they are generated.
streamer = TextStreamer(tokenizer)
model.generate(**encoding, max_new_tokens=500, temperature=0.2, do_sample=True, streamer=streamer, pad_token_id=tokenizer.eos_token_id)

Output result:

As of 2022, Shenyang has opened and operated 4 subway lines: Line 1, Line 2, Line 3, and Line 9. Future construction plans for the Shenyang subway are still ongoing, with expectations for further extensions and expansions of the lines.
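
Note that by default the streamer also echoes the prompt before the answer. To print only the newly generated text, TextStreamer accepts a skip_prompt flag and decode keyword arguments such as skip_special_tokens, for example:

# Print only newly generated tokens, without the prompt or special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)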
  • Use the built-in TextIteratorStreamer from the transformers module to customize streaming output.

Example code:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
message = [{"role": "user", "content": "How many subway lines are there in Shenyang?"}]
conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
print(conversion)
encoding = tokenizer(conversion, return_tensors="pt")
# TextIteratorStreamer exposes the generated text as a Python iterator.
streamer = TextIteratorStreamer(tokenizer)
# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(encoding, streamer=streamer, max_new_tokens=100, do_sample=True, temperature=0.2)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    # Strip the echoed prompt so that only newly generated text is printed.
    output = new_text.replace(conversion, '')
    if output:
        generated_text += output
        print(output, end="", flush=True)

Output result:

As of 2022, Shenyang has opened and operated 4 subway lines, namely Line 1, Line 2, Line 3, and Line 9. In the future, there are still several lines under construction in the planning, such as Line 6 and Line 8.
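
As a variant of the loop above, TextIteratorStreamer also accepts skip_prompt, decode keyword arguments, and a timeout (in seconds) for its internal queue, which removes the need to strip the prompt manually. A minimal sketch:

# Iterate only over newly generated text; raise an error if no token arrives within 60 seconds.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=60.0)
generation_kwargs = dict(encoding, streamer=streamer, max_new_tokens=100, do_sample=True, temperature=0.2)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    print(new_text, end="", flush=True)
thread.join()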
  • Use Gradio to implement web-based streaming output.

Example code:

# -*- coding: utf-8 -*-
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def question_answer(query):
    message = [{"role": "user", "content": query}]
    conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
    encoding = tokenizer(conversion, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(encoding, streamer=streamer, max_new_tokens=1000, do_sample=True, temperature=0.2)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generate_text = ''
    for new_text in streamer:
        output = new_text.replace(conversion, '')
        if output:
            generate_text += output
            # Yield the accumulated text so the Gradio output box is updated incrementally.
            yield generate_text


demo = gr.Interface(
    fn=question_answer,
    inputs=gr.Textbox(lines=3, placeholder="your question...", label="Question"),
    outputs="text",
)

demo.launch(server_name="0.0.0.0", server_port=50079, share=True)

The effect: after submitting a question on the web page, the answer is streamed into the output box incrementally as it is generated.
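
Gradio also ships a ready-made chat UI. As a sketch (assuming a Gradio version that provides gr.ChatInterface, whose function receives the message and the chat history), the same streaming generator can be reused:

def chat_fn(message, history):
    # Reuse the streaming generator above; the chat history is ignored in this sketch.
    yield from question_answer(message)

# The port here is arbitrary for this sketch.
gr.ChatInterface(fn=chat_fn).launch(server_name="0.0.0.0", server_port=50080)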

References

  1. Streamers in transformers: https://huggingface.co/docs/transformers/v4.39.3/en/internal/generation_utils#transformers.TextStreamer
