This article introduces how to implement streaming output for model inference with the transformers module. The transformers module ships with built-in streamer classes for streaming output during model inference; alternatively, model deployment frameworks such as vLLM and TGI provide even better streaming support. Below, we detail how to achieve streaming output for model inference with the transformers module.
Streaming Output
- Use the built-in TextStreamer class from the transformers module to stream output to the terminal.
Example code:
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

message = [{"role": "user", "content": "How many subway lines are there in Shenyang?"}]
# Apply the chat template to build the prompt string the model expects.
conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
print(conversion)
encoding = tokenizer(conversion, return_tensors="pt")

# TextStreamer prints decoded tokens to stdout as they are generated.
streamer = TextStreamer(tokenizer)
model.generate(**encoding, max_new_tokens=500, temperature=0.2, do_sample=True, streamer=streamer, pad_token_id=tokenizer.eos_token_id)
Output result:
As of 2022, Shenyang has opened and operated 4 subway lines: Line 1, Line 2, Line 3, and Line 9. Future construction plans for the Shenyang subway are still ongoing, with expectations for further extensions and expansions of the lines.
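Under the hood, the streamer classes follow a simple protocol: during generation, generate() calls the streamer's put() method with each batch of newly produced token ids and end() when generation finishes. The sketch below illustrates that protocol in isolation; it does not depend on transformers, and the toy id-to-text dict is a hypothetical stand-in for a real tokenizer.

```python
class CollectingStreamer:
    """Minimal sketch of the streamer protocol used by model.generate():
    put() receives each batch of new token ids, end() signals completion.
    A toy id->text dict stands in for a real tokenizer here."""

    def __init__(self, vocab):
        self.vocab = vocab          # hypothetical toy vocabulary
        self.pieces = []
        self.finished = False

    def put(self, token_ids):
        # Decode and collect each newly generated token.
        for tid in token_ids:
            self.pieces.append(self.vocab[tid])

    def end(self):
        self.finished = True

    @property
    def text(self):
        return "".join(self.pieces)


# Simulate generate() pushing one token per decoding step.
vocab = {0: "Shenyang ", 1: "has ", 2: "4 ", 3: "subway lines."}
streamer = CollectingStreamer(vocab)
for step in ([0], [1], [2], [3]):
    streamer.put(step)
streamer.end()
print(streamer.text)
```

TextStreamer works the same way, except that put() decodes with the real tokenizer and prints the text immediately.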
- Use the built-in TextIteratorStreamer class from the transformers module to customize streaming output.
Example code:
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
message = [{"role": "user", "content": "How many subway lines are there in Shenyang?"}]
conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
print(conversion)
encoding = tokenizer(conversion, return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)
# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(encoding, streamer=streamer, max_new_tokens=100, do_sample=True, temperature=0.2)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
    # The first chunk echoes the prompt; strip it out
    # (alternatively, pass skip_prompt=True to TextIteratorStreamer).
    output = new_text.replace(conversion, '')
    if output:
        generated_text += output
        print(output)
Output result:
As of 2022, Shenyang has opened and operated 4 subway lines, namely Line 1, Line 2, Line 3, and Line 9. In the future, there are still several lines under construction in the planning, such as Line 6 and Line 8.
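TextIteratorStreamer builds on the same put()/end() protocol, but instead of printing it pushes decoded text into a thread-safe queue so that a consumer thread can iterate over the chunks while generate() runs in the background. A minimal sketch of that queue-backed iterator pattern, independent of transformers (all names here are illustrative):

```python
import queue
from threading import Thread


class QueueStreamer:
    """Sketch of the TextIteratorStreamer idea: the producer thread
    calls put()/end(), the consumer thread iterates with `for`."""

    _SENTINEL = object()   # marks the end of generation

    def __init__(self):
        self.q = queue.Queue()

    def put(self, text):
        self.q.put(text)

    def end(self):
        self.q.put(self._SENTINEL)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.q.get()   # blocks until the producer pushes a chunk
        if item is self._SENTINEL:
            raise StopIteration
        return item


def fake_generate(streamer):
    # Stands in for model.generate() running in a background thread.
    for piece in ["Shenyang ", "has ", "4 ", "lines."]:
        streamer.put(piece)
    streamer.end()


streamer = QueueStreamer()
Thread(target=fake_generate, args=(streamer,)).start()
print("".join(chunk for chunk in streamer))
```

This is why the real example runs model.generate() in a separate thread: the main thread blocks on the queue, consuming text as soon as each token is decoded.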
- Use Gradio to implement web-based streaming output.
# -*- coding: utf-8 -*-
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def question_answer(query):
    message = [{"role": "user", "content": query}]
    conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
    encoding = tokenizer(conversion, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(encoding, streamer=streamer, max_new_tokens=1000, do_sample=True, temperature=0.2)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generate_text = ''
    for new_text in streamer:
        output = new_text.replace(conversion, '')
        if output:
            generate_text += output
            yield generate_text
demo = gr.Interface(
fn=question_answer,
inputs=gr.Textbox(lines=3, placeholder="your question...", label="Question"),
outputs="text",
)
demo.launch(server_name="0.0.0.0", server_port=50079, share=True)
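Gradio streams the output because question_answer is a generator function: each yield re-renders the output box with the accumulated text so far. The same accumulate-and-yield pattern in isolation:

```python
def stream_words(words):
    # Accumulate text and yield the running result after each piece,
    # mirroring how question_answer feeds Gradio partial answers.
    text = ""
    for w in words:
        text += w
        yield text


# Each iteration represents one refresh of the Gradio output box.
for partial in stream_words(["Hello", ", ", "world"]):
    print(partial)
```

Note that yielding the full accumulated string (rather than just the new chunk) is what Gradio expects: each yielded value replaces the previous contents of the output component.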
The resulting web page streams the answer into the output box as it is generated.
References
- Streamers in transformers: https://huggingface.co/docs/transformers/v4.39.3/en/internal/generation_utils#transformers.TextStreamer