forked from stepfun-ai/Step-Audio2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstepaudio2.py
More file actions
204 lines (181 loc) · 9 KB
/
stepaudio2.py
File metadata and controls
204 lines (181 loc) · 9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels
class StepAudio2Base:
    """Inference wrapper around a Step-Audio 2 causal-LM checkpoint.

    The constructor loads the tokenizer and the model (bfloat16, moved to
    CUDA) from ``model_path``.  Calling the instance turns a list of
    messages into prompt token ids plus optional mel features, runs
    ``generate`` and splits the generated ids into text and audio tokens.
    """

    # Generated ids strictly below this value are decoded as text.
    # NOTE(review): presumably the end of the text vocabulary - confirm
    # against the tokenizer shipped with the checkpoint.
    TEXT_TOKEN_LIMIT = 151688
    # Generated ids at or above this value are returned as audio tokens,
    # rebased so that the first audio code is 0.  Ids that fall between the
    # two thresholds are dropped from both the text and the audio outputs.
    AUDIO_TOKEN_OFFSET = 151696
    # Number of samples per audio window fed to the mel front-end.
    # NOTE(review): looks like 25 s at a 16 kHz sample rate - confirm that
    # ``load_audio`` resamples to 16 kHz.
    AUDIO_CHUNK_SAMPLES = 16000 * 25

    def __init__(self, model_path: str):
        """Load tokenizer and model weights from ``model_path``."""
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, padding_side="right"
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).cuda()
        self.eos_token_id = self.llm_tokenizer.eos_token_id

    def __call__(self, messages: list, **kwargs):
        """Generate a continuation for ``messages``.

        Args:
            messages: Conversation in the format accepted by
                ``apply_chat_template``.
            **kwargs: Overrides / additions for the ``GenerationConfig``
                (defaults: ``max_new_tokens=2048`` plus the tokenizer's pad
                id and this object's ``eos_token_id``).

        Returns:
            A tuple ``(output_token_ids, output_text, output_audio_tokens)``:
            the generated ids (a trailing EOS id is removed when present),
            the decoded text tokens, and the audio token ids rebased by
            ``AUDIO_TOKEN_OFFSET``.

        Raises:
            ValueError: If the chat template yields a segment that is
                neither a string nor a list of token ids.
        """
        segments, mels = self.apply_chat_template(messages)

        # Tokenize string segments; lists are treated as ready-made ids.
        prompt_parts = []
        for segment in segments:
            if isinstance(segment, str):
                encoded = self.llm_tokenizer(text=segment, return_tensors="pt", padding=True)
                prompt_parts.append(encoded["input_ids"])
            elif isinstance(segment, list):
                prompt_parts.append(torch.tensor([segment], dtype=torch.int32))
            else:
                raise ValueError(f"Unsupported content type: {type(segment)}")
        prompt_ids = torch.cat(prompt_parts, dim=-1).cuda()
        attention_mask = torch.ones_like(prompt_ids)

        if len(mels) == 0:
            wavs = None
            wav_lens = None
        else:
            wavs, wav_lens = padding_mels(mels)
            wavs = wavs.cuda()
            wav_lens = wav_lens.cuda()

        generate_inputs = {
            "input_ids": prompt_ids,
            "wavs": wavs,
            "wav_lens": wav_lens,
            "attention_mask": attention_mask,
        }
        settings = dict(
            max_new_tokens=2048,
            pad_token_id=self.llm_tokenizer.pad_token_id,
            eos_token_id=self.eos_token_id,
        )
        settings.update(kwargs)
        generation_config = GenerationConfig(**settings)
        outputs = self.llm.generate(**generate_inputs, generation_config=generation_config)

        generated = outputs[0, prompt_ids.shape[-1]:].tolist()
        # Only drop the final id when it really is EOS: if generation stopped
        # at max_new_tokens the last id is genuine output and must be kept.
        output_token_ids = self._strip_trailing_eos(generated, settings.get("eos_token_id"))
        output_text_tokens, output_audio_tokens = self._split_output_tokens(output_token_ids)
        output_text = self.llm_tokenizer.decode(output_text_tokens)
        return output_token_ids, output_text, output_audio_tokens

    @staticmethod
    def _strip_trailing_eos(token_ids, eos_token_id):
        """Return ``token_ids`` without its last element if that is an EOS id.

        ``eos_token_id`` may be ``None``, a single int, or an iterable of ints.
        """
        if eos_token_id is None:
            eos_ids = ()
        elif isinstance(eos_token_id, int):
            eos_ids = (eos_token_id,)
        else:
            eos_ids = tuple(eos_token_id)
        if token_ids and token_ids[-1] in eos_ids:
            return token_ids[:-1]
        return token_ids

    @classmethod
    def _split_output_tokens(cls, token_ids):
        """Split generated ids into ``(text_ids, rebased_audio_ids)``."""
        text_ids = [i for i in token_ids if i < cls.TEXT_TOKEN_LIMIT]
        audio_ids = [i - cls.AUDIO_TOKEN_OFFSET for i in token_ids if i >= cls.AUDIO_TOKEN_OFFSET]
        return text_ids, audio_ids

    def _encode_audio(self, audio_source):
        """Return ``(placeholders, mels)`` for one audio input.

        The audio is processed in windows of ``AUDIO_CHUNK_SAMPLES``; each
        window yields one mel feature and one
        ``<audio_start>...<audio_end>`` placeholder string whose number of
        ``<audio_patch>`` tokens comes from ``compute_token_num``.
        """
        audio = load_audio(audio_source)
        placeholders = []
        mels = []
        for start in range(0, audio.shape[0], self.AUDIO_CHUNK_SAMPLES):
            window = audio[start:start + self.AUDIO_CHUNK_SAMPLES]
            mel = log_mel_spectrogram(window, n_mels=128, padding=479)
            mels.append(mel)
            audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
            placeholders.append(f"<audio_start>{audio_tokens}<audio_end>")
        return placeholders, mels

    def apply_chat_template(self, messages: list):
        """Flatten ``messages`` into prompt segments and mel features.

        Each message is either a plain string (used verbatim) or a dict with
        a ``"type"`` of ``"text"``, ``"audio"`` or ``"token"``.  Dicts with
        any other ``"type"`` contribute nothing.

        Returns:
            ``(results, mels)`` where ``results`` mixes strings and token-id
            lists in prompt order and ``mels`` holds one feature per audio
            window.

        Raises:
            ValueError: If a message is neither a string nor a dict.
        """
        results = []
        mels = []
        for content in messages:
            if isinstance(content, str):
                results.append(content)
            elif isinstance(content, dict):
                kind = content["type"]
                if kind == "text":
                    results.append(f"{content['text']}")
                elif kind == "audio":
                    placeholders, audio_mels = self._encode_audio(content["audio"])
                    results.extend(placeholders)
                    mels.extend(audio_mels)
                elif kind == "token":
                    results.append(content["token"])
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return results, mels
class StepAudio2(StepAudio2Base):
    """Chat variant of the wrapper that uses ``<|EOT|>`` as end-of-sequence.

    Messages are dicts with a ``"role"`` and a ``"content"``; each one is
    rendered as ``<|BOT|>{role}\\n`` followed by its content and, unless the
    message sets ``"eot"`` to a false value, a closing ``<|EOT|>``.
    """

    def __init__(self, model_path: str):
        """Load the model, then register ``<|EOT|>`` as the EOS token."""
        super().__init__(model_path)
        eot = "<|EOT|>"
        self.llm_tokenizer.eos_token = eot
        eot_id = self.llm_tokenizer.convert_tokens_to_ids(eot)
        self.llm.config.eos_token_id = eot_id
        self.eos_token_id = eot_id

    def apply_chat_template(self, messages: list):
        """Render role-tagged ``messages`` into prompt segments and mels.

        ``content`` may be a string, ``None`` (an open turn for the model to
        complete; no ``<|EOT|>`` is added) or a list of items with a
        ``"type"`` of ``"text"``, ``"audio"`` or ``"token"``.  The role
        ``"user"`` is written as ``"human"``.

        Returns:
            ``(results, mels)``: prompt segments (strings and token-id
            lists) in order, and one mel feature per audio window.

        Raises:
            ValueError: If ``content`` is not a string, a list or ``None``.
        """
        segments = []
        mel_chunks = []
        window = 16000 * 25  # samples handed to the mel front-end at once
        for message in messages:
            speaker = message["role"]
            body = message["content"]
            if speaker == "user":
                speaker = "human"
            header = f"<|BOT|>{speaker}\n"

            if body is None:
                # Open turn: header only, generation continues from here.
                segments.append(header)
                continue

            if isinstance(body, str):
                closing = '<|EOT|>' if message.get('eot', True) else ''
                segments.append(f"{header}{body}{closing}")
                continue

            if not isinstance(body, list):
                raise ValueError(f"Unsupported content type: {type(body)}")

            segments.append(header)
            for part in body:
                kind = part["type"]
                if kind == "text":
                    segments.append(f"{part['text']}")
                elif kind == "audio":
                    waveform = load_audio(part['audio'])
                    for offset in range(0, waveform.shape[0], window):
                        mel = log_mel_spectrogram(
                            waveform[offset:offset + window], n_mels=128, padding=479
                        )
                        mel_chunks.append(mel)
                        patches = "<audio_patch>" * compute_token_num(mel.shape[1])
                        segments.append(f"<audio_start>{patches}<audio_end>")
                elif kind == "token":
                    segments.append(part["token"])
            if message.get('eot', True):
                segments.append('<|EOT|>')
        return segments, mel_chunks
if __name__ == '__main__':
    # Demo script: runs the model through each supported conversation mode
    # and writes the synthesized speech replies to WAV files.
    from token2wav import Token2wav

    model = StepAudio2('Step-Audio-2-mini')
    token2wav = Token2wav('Step-Audio-2-mini/token2wav')

    # Sampling settings shared by every call below.
    sampling = dict(temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    question_text = "Give me a brief introduction to the Great Wall."
    question_wav = "assets/give_me_a_brief_introduction_to_the_great_wall.wav"

    def system_turn():
        return {"role": "system", "content": "You are a helpful assistant."}

    def open_reply():
        # content=None leaves the assistant turn open for a text answer.
        return {"role": "assistant", "content": None}

    def spoken_reply():
        # Insert <tts_start> for speech response
        return {"role": "assistant", "content": "<tts_start>", "eot": False}

    # Text-to-text conversation
    print()
    messages = [
        system_turn(),
        {"role": "human", "content": question_text},
        open_reply(),
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, **sampling)
    print(text)

    # Text-to-speech conversation
    print()
    messages = [
        system_turn(),
        {"role": "human", "content": question_text},
        spoken_reply(),
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, **sampling)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_male.wav')
    with open('output-male.wav', 'wb') as f:
        f.write(audio)

    # Speech-to-text conversation
    print()
    messages = [
        system_turn(),
        {"role": "human", "content": [{"type": "audio", "audio": question_wav}]},
        open_reply(),
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, **sampling)
    print(text)

    # Speech-to-speech conversation
    print()
    messages = [
        system_turn(),
        {"role": "human", "content": [{"type": "audio", "audio": question_wav}]},
        spoken_reply(),
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, **sampling)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_female.wav')
    with open('output-female.wav', 'wb') as f:
        f.write(audio)

    # Multi-turn conversation: replace the open speech turn with the tokens
    # the model just produced, then ask a follow-up question.
    print()
    messages = messages[:-1] + [
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "<tts_start>"},
                {"type": "token", "token": tokens},
            ],
        },
        {"role": "human", "content": "Now write a 4-line poem about it."},
        open_reply(),
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, **sampling)
    print(text)

    # Multi-modal inputs
    print()
    messages = [
        system_turn(),
        {
            "role": "human",
            "content": [
                {"type": "text", "text": "Translate the speech into Chinese."},
                {"type": "audio", "audio": question_wav},
            ],
        },
        open_reply(),
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, **sampling)
    print(text)