# AudioCraft Advanced Usage Guide
|
||
|
|
|
||
|
|
## Fine-tuning MusicGen
|
||
|
|
|
||
|
|
### Custom dataset preparation
|
||
|
|
|
||
|
|
```python
|
||
|
|
import os
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
import torchaudio
|
||
|
|
|
||
|
|
def prepare_dataset(audio_dir, output_dir, metadata_file):
    """
    Prepare dataset for MusicGen fine-tuning.

    Each clip listed in ``metadata_file`` is loaded from ``audio_dir``,
    resampled to 32 kHz, down-mixed to mono and written to
    ``output_dir/audio``. A ``metadata.json`` describing the processed clips
    is written into ``output_dir``.

    Args:
        audio_dir: Directory that the ``"path"`` entries are relative to.
        output_dir: Destination directory (created if it does not exist).
        metadata_file: JSON file containing a list of
            ``{"path": "...", "description": "..."}`` objects.

    Returns:
        The list of processed entries, each holding ``path`` (relative to
        ``output_dir``), ``description`` and ``duration`` in seconds.

    Directory structure:
        output_dir/
        ├── audio/
        │   ├── 0000.wav
        │   ├── 0001.wav
        │   └── ...
        └── metadata.json
    """
    target_sr = 32000

    audio_dir = Path(audio_dir)
    output_dir = Path(output_dir)
    audio_output = output_dir / "audio"
    audio_output.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8 so non-ASCII descriptions load the same on every platform
    with open(metadata_file, encoding="utf-8") as f:
        metadata = json.load(f)

    processed = []

    for idx, item in enumerate(metadata):
        audio_path = audio_dir / item["path"]

        # Load and resample to 32kHz
        wav, sr = torchaudio.load(str(audio_path))
        if sr != target_sr:
            resampler = torchaudio.transforms.Resample(sr, target_sr)
            wav = resampler(wav)

        # Convert to mono if the clip has more than one channel
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Save processed audio; files are numbered from 0000 in metadata order
        output_path = audio_output / f"{idx:04d}.wav"
        torchaudio.save(str(output_path), wav, sample_rate=target_sr)

        processed.append({
            "path": str(output_path.relative_to(output_dir)),
            "description": item["description"],
            "duration": wav.shape[1] / target_sr,
        })

    # Save processed metadata
    with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(processed, f, indent=2)

    print(f"Processed {len(processed)} samples")
    return processed
|
||
|
|
```
|
||
|
|
|
||
|
|
### Fine-tuning with dora
|
||
|
|
|
||
|
|
```bash
|
||
|
|
# AudioCraft uses dora for experiment management
# Install dora
pip install dora-search

# Clone AudioCraft
git clone https://github.com/facebookresearch/audiocraft.git
cd audiocraft

# Create config for fine-tuning
# The quoted 'EOF' delimiter stops the shell from expanding anything inside the YAML.
# NOTE(review): batch_size is set under both `optim` and `dataset` — confirm which
# key the solver actually reads and drop the other.
cat > config/solver/musicgen/finetune.yaml << 'EOF'
defaults:
  - musicgen/musicgen_base
  - /model: lm/musicgen_lm
  - /conditioner: cond_base

solver: musicgen
autocast: true
autocast_dtype: float16

optim:
  epochs: 100
  batch_size: 4
  lr: 1e-4
  ema: 0.999
  optimizer: adamw

dataset:
  batch_size: 4
  num_workers: 4
  train:
    - dset: your_dataset
      root: /path/to/dataset
  valid:
    - dset: your_dataset
      root: /path/to/dataset

checkpoint:
  save_every: 10
  keep_every_states: null
EOF

# Run fine-tuning
dora run solver=musicgen/finetune
|
||
|
|
```
|
||
|
|
|
||
|
|
### LoRA fine-tuning
|
||
|
|
|
||
|
|
```python
|
||
|
|
from peft import LoraConfig, get_peft_model
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# Start from the pretrained checkpoint
model = MusicGen.get_pretrained('facebook/musicgen-small')

# Adapter hyper-parameters: rank-8 updates on the attention projections
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
    lora_dropout=0.05,
    bias="none",
)

# Wrap MusicGen's language-model component with the LoRA adapters and report
# how many parameters remain trainable
lm = get_peft_model(model.lm, lora_config)
lm.print_trainable_parameters()
|
||
|
|
```
|
||
|
|
|
||
|
|
## Multi-GPU Training
|
||
|
|
|
||
|
|
### DataParallel
|
||
|
|
|
||
|
|
```python
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
|
||
|
|
# MusicGen is a plain wrapper object, not an nn.Module, so it has no .to();
# the target device is chosen when the checkpoint is loaded.
model = MusicGen.get_pretrained('facebook/musicgen-small', device='cuda')

# Wrap LM with DataParallel
# NOTE: nn.DataParallel only parallelises forward(); the LM's other methods
# are reachable through model.lm.module once it is wrapped.
if torch.cuda.device_count() > 1:
    model.lm = nn.DataParallel(model.lm)
|
||
|
|
```
|
||
|
|
|
||
|
|
### DistributedDataParallel
|
||
|
|
|
||
|
|
```python
|
||
|
|
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from audiocraft.models import MusicGen


def setup(rank, world_size):
    """Join the NCCL process group and bind this process to GPU ``rank``.

    The launcher (for example ``torchrun`` or ``torch.multiprocessing.spawn``)
    is expected to provide the rendezvous settings for the process group.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def train(rank, world_size):
    """Per-process entry point: wrap the MusicGen LM in DDP and train it.

    Args:
        rank: Index of this process, also used as its CUDA device index.
        world_size: Total number of participating processes.
    """
    setup(rank, world_size)

    try:
        model = MusicGen.get_pretrained('facebook/musicgen-small')
        model.lm = model.lm.to(rank)
        model.lm = DDP(model.lm, device_ids=[rank])

        # Training loop
        # ...
    finally:
        # Always leave the process group, even when training raises
        dist.destroy_process_group()
|
||
|
|
```
|
||
|
|
|
||
|
|
## Custom Conditioning
|
||
|
|
|
||
|
|
### Adding new conditioners
|
||
|
|
|
||
|
|
```python
|
||
|
|
from audiocraft.modules.conditioners import BaseConditioner
|
||
|
|
import torch
|
||
|
|
|
||
|
|
class CustomConditioner(BaseConditioner):
    """Conditioner that linearly projects an extra control signal.

    NOTE(review): audiocraft's built-in conditioners appear to return an
    (embedding, mask) pair from ``forward`` — confirm whether this class needs
    to do the same before wiring it into a model.
    """

    def __init__(self, dim, output_dim):
        super().__init__(dim, output_dim)
        # Projection from the raw control-signal size to the conditioning size
        self.embed = torch.nn.Linear(dim, output_dim)

    def forward(self, x):
        """Project ``x`` with the linear embedding layer."""
        projected = self.embed(x)
        return projected

    def tokenize(self, x):
        """Prepare raw input for ``forward``; the input is passed through unchanged."""
        return x
|
||
|
|
|
||
|
|
# Use with MusicGen
|
||
|
|
from audiocraft.models.builders import get_lm_model
|
||
|
|
|
||
|
|
# Modify model config to include custom conditioner
|
||
|
|
# This requires editing the model configuration
|
||
|
|
```
|
||
|
|
|
||
|
|
### Melody conditioning internals
|
||
|
|
|
||
|
|
```python
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||
|
|
import torch
|
||
|
|
|
||
|
|
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||
|
|
|
||
|
|
# Access chroma extractor
|
||
|
|
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||
|
|
|
||
|
|
# Manual chroma extraction
|
||
|
|
def extract_chroma(audio, sr):
    """Extract chroma features from audio.

    Args:
        audio: Waveform tensor; it may live on any device and may track
            gradients.
        sr: Sample rate of ``audio`` in Hz.

    Returns:
        A float32 CPU tensor holding the output of
        ``librosa.feature.chroma_cqt``.
    """
    import librosa

    # Tensor.numpy() only accepts CPU tensors that do not require grad,
    # so detach and move to host memory before handing over to librosa.
    samples = audio.detach().cpu().numpy()

    # Compute chroma
    chroma = librosa.feature.chroma_cqt(y=samples, sr=sr)

    return torch.from_numpy(chroma).float()
|
||
|
|
|
||
|
|
# Use extracted chroma for conditioning
|
||
|
|
chroma = extract_chroma(melody_audio, sample_rate)
|
||
|
|
```
|
||
|
|
|
||
|
|
## EnCodec Deep Dive
|
||
|
|
|
||
|
|
### Custom compression settings
|
||
|
|
|
||
|
|
```python
|
||
|
|
from audiocraft.models import CompressionModel
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# Load EnCodec
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')

# Access codec parameters
print(f"Sample rate: {encodec.sample_rate}")
print(f"Channels: {encodec.channels}")
print(f"Cardinality: {encodec.cardinality}")  # Codebook size
print(f"Num codebooks: {encodec.num_codebooks}")
print(f"Frame rate: {encodec.frame_rate}")

# Encode at a lower bitrate
# Fewer codebooks = more compression, lower quality.
# audiocraft's CompressionModel controls this through set_num_codebooks();
# set_target_bandwidth() belongs to the standalone `encodec` package.
encodec.set_num_codebooks(2)

audio = torch.randn(1, 1, 32000)  # 1 second
with torch.no_grad():
    # encode() returns (codes, scale); pass both back so any input
    # renormalisation is undone on decode
    encoded = encodec.encode(audio)
    codes, scale = encoded
    decoded = encodec.decode(codes, scale)
|
||
|
|
```
|
||
|
|
|
||
|
|
### Streaming encoding
|
||
|
|
|
||
|
|
```python
|
||
|
|
import torch
|
||
|
|
from audiocraft.models import CompressionModel
|
||
|
|
|
||
|
|
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||
|
|
|
||
|
|
def encode_streaming(audio_stream, chunk_size=32000):
    """Encode an iterable of audio chunks and join the codes along the last axis.

    Args:
        audio_stream: Iterable yielding waveform tensors. A 1-D chunk gets two
            leading singleton axes added before encoding; other ranks are
            passed through unchanged.
        chunk_size: Accepted for API compatibility; not used by this function.

    Returns:
        The per-chunk codes concatenated with ``torch.cat(..., dim=-1)``.
    """
    code_chunks = []

    for piece in audio_stream:
        # Bare waveforms need two extra leading axes before encoding
        batch = piece.unsqueeze(0).unsqueeze(0) if piece.dim() == 1 else piece

        with torch.no_grad():
            code_chunks.append(encodec.encode(batch)[0])

    return torch.cat(code_chunks, dim=-1)
|
||
|
|
|
||
|
|
def decode_streaming(codes_stream, output_stream):
    """Decode each item of ``codes_stream`` and write the audio to ``output_stream``.

    Args:
        codes_stream: Iterable yielding code tensors accepted by
            ``encodec.decode``.
        output_stream: Object with a ``write`` method that receives each
            decoded chunk as a NumPy array.
    """
    for frame_codes in codes_stream:
        with torch.no_grad():
            waveform = encodec.decode(frame_codes)
        output_stream.write(waveform.cpu().numpy())
|
||
|
|
```
|
||
|
|
|
||
|
|
## MultiBand Diffusion
|
||
|
|
|
||
|
|
### Using MBD for enhanced quality
|
||
|
|
|
||
|
|
```python
|
||
|
|
import torch

from audiocraft.models import MusicGen, MultiBandDiffusion

# Load MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')

# Load MultiBand Diffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()

model.set_generation_params(duration=10)

descriptions = ["epic orchestral music"]

with torch.no_grad():
    # Generate once and keep the tokens, so both decoders are compared on
    # exactly the same generation
    wav_standard, gen_tokens = model.generate(descriptions, return_tokens=True)

    # Decode the same tokens with MBD
    wav_mbd = mbd.tokens_to_wav(gen_tokens)

# Compare quality
print(f"Standard shape: {wav_standard.shape}")
print(f"MBD shape: {wav_mbd.shape}")
|
||
|
|
```
|
||
|
|
|
||
|
|
## API Server Deployment
|
||
|
|
|
||
|
|
### FastAPI server
|
||
|
|
|
||
|
|
```python
|
||
|
|
from fastapi import FastAPI, HTTPException
|
||
|
|
from pydantic import BaseModel
|
||
|
|
import torch
|
||
|
|
import torchaudio
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
import io
|
||
|
|
import base64
|
||
|
|
|
||
|
|
app = FastAPI()
|
||
|
|
|
||
|
|
# Load model at startup
|
||
|
|
model = None
|
||
|
|
|
||
|
|
@app.on_event("startup")
async def load_model():
    """Populate the module-level ``model`` once, when the application starts."""
    global model
    # Bind the global first, then apply the default generation settings
    model = MusicGen.get_pretrained('facebook/musicgen-small')
    model.set_generation_params(duration=10)
|
||
|
|
|
||
|
|
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Text description handed to model.generate()
    prompt: str
    # Requested length in seconds; the endpoint caps it at 30
    duration: float = 10.0
    # Forwarded to model.set_generation_params()
    temperature: float = 1.0
    # Forwarded to model.set_generation_params()
    cfg_coef: float = 3.0
|
||
|
|
|
||
|
|
class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``."""

    # Base64-encoded bytes of the generated WAV file
    audio_base64: str
    # Sample rate of the encoded audio in Hz
    sample_rate: int
    # Length of the generated audio in seconds
    duration: float
|
||
|
|
|
||
|
|
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate audio for ``request.prompt`` and return it as base64 WAV.

    Raises:
        HTTPException: 500 when the model is not loaded or generation fails,
            422 when the requested duration is not positive.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    if request.duration <= 0:
        raise HTTPException(status_code=422, detail="duration must be positive")

    # NOTE: generation is blocking; while it runs this coroutine holds the
    # event loop, so other requests (including /health) wait.
    try:
        model.set_generation_params(
            duration=min(request.duration, 30),
            temperature=request.temperature,
            cfg_coef=request.cfg_coef,
        )

        with torch.no_grad():
            wav = model.generate([request.prompt])

        # Use the model's own sample rate instead of assuming 32 kHz
        sample_rate = model.sample_rate

        # Convert to bytes
        buffer = io.BytesIO()
        torchaudio.save(buffer, wav[0].cpu(), sample_rate=sample_rate, format="wav")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode()

        return GenerateResponse(
            audio_base64=audio_base64,
            sample_rate=sample_rate,
            duration=wav.shape[-1] / sample_rate,
        )

    except Exception as e:
        # Keep the original traceback attached for server-side logs
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
|
|
||
|
|
@app.get("/health")
async def health():
    """Report liveness and whether the module-level model has been loaded."""
    is_loaded = model is not None
    return {"status": "ok", "model_loaded": is_loaded}
|
||
|
|
|
||
|
|
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||
|
|
```
|
||
|
|
|
||
|
|
### Batch processing service
|
||
|
|
|
||
|
|
```python
|
||
|
|
import asyncio
|
||
|
|
from concurrent.futures import ThreadPoolExecutor
|
||
|
|
import torch
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
|
||
|
|
class MusicGenService:
    """Async front-end around a single shared MusicGen model.

    Generation runs in a worker thread so the event loop stays responsive.
    Access to the model is serialised with ``self.lock`` because
    ``set_generation_params`` mutates state shared by every request.
    """

    def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
        self.model = MusicGen.get_pretrained(model_name)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.lock = asyncio.Lock()

    async def _generate_locked(self, prompts, duration):
        """Run one generation call for ``prompts`` while holding the model lock."""
        loop = asyncio.get_running_loop()

        def _generate():
            with torch.no_grad():
                self.model.set_generation_params(duration=duration)
                return self.model.generate(prompts)

        # Without the lock, a second request could change the duration
        # between set_generation_params() and generate().
        async with self.lock:
            return await loop.run_in_executor(self.executor, _generate)

    async def generate_async(self, prompt, duration=10):
        """Generate audio for one prompt.

        Returns:
            The generated waveform for ``prompt`` as a CPU tensor.
        """
        wav = await self._generate_locked([prompt], duration)
        return wav[0].cpu()

    async def generate_batch_async(self, prompts, duration=10):
        """Generate audio for several prompts in a single batched model call.

        Batching all prompts raises peak memory use with the batch size.

        Returns:
            A list with one CPU waveform tensor per prompt, in input order
            (empty when ``prompts`` is empty).
        """
        prompts = list(prompts)
        if not prompts:
            return []
        wav = await self._generate_locked(prompts, duration)
        return list(wav.cpu())
|
||
|
|
|
||
|
|
# Usage
|
||
|
|
service = MusicGenService()
|
||
|
|
|
||
|
|
async def main():
    """Generate audio for three sample prompts using the shared service."""
    return await service.generate_batch_async(
        ["jazz piano", "rock guitar", "electronic beats"]
    )
|
||
|
|
```
|
||
|
|
|
||
|
|
## Integration Patterns
|
||
|
|
|
||
|
|
### LangChain tool
|
||
|
|
|
||
|
|
```python
|
||
|
|
from langchain.tools import BaseTool
|
||
|
|
import torch
|
||
|
|
import torchaudio
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
import tempfile
|
||
|
|
|
||
|
|
class MusicGeneratorTool(BaseTool):
    """LangChain tool that renders a text description to a WAV file."""

    # BaseTool is a pydantic model: fields need type annotations, and any
    # attribute assigned in __init__ has to be declared here first.
    name: str = "music_generator"
    description: str = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
    model: object = None

    def __init__(self):
        super().__init__()
        self.model = MusicGen.get_pretrained('facebook/musicgen-small')
        self.model.set_generation_params(duration=15)

    def _run(self, description: str) -> str:
        """Generate music for ``description`` and return a message with the file path."""
        with torch.no_grad():
            wav = self.model.generate([description])

        # Reserve a unique path, then close the handle before torchaudio
        # re-opens the file by name (required on Windows).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            path = f.name
        torchaudio.save(path, wav[0].cpu(), sample_rate=self.model.sample_rate)
        return f"Generated music saved to: {path}"

    async def _arun(self, description: str) -> str:
        """Async variant: run the blocking generation in the default executor."""
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, description)
|
||
|
|
```
|
||
|
|
|
||
|
|
### Gradio with advanced controls
|
||
|
|
|
||
|
|
```python
|
||
|
|
import gradio as gr
|
||
|
|
import torch
|
||
|
|
import torchaudio
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
|
||
|
|
models = {}
|
||
|
|
|
||
|
|
def load_model(model_size):
    """Return the MusicGen model for ``model_size``, loading it on first use.

    Loaded models are kept in the module-level ``models`` dict so each size
    is fetched at most once.
    """
    try:
        return models[model_size]
    except KeyError:
        pass

    models[model_size] = MusicGen.get_pretrained(f"facebook/musicgen-{model_size}")
    return models[model_size]
|
||
|
|
|
||
|
|
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
    """Generate audio for ``prompt`` and return the path of the saved WAV file.

    Args:
        prompt: Text description of the music.
        duration: Clip length in seconds.
        temperature: Sampling temperature.
        cfg_coef: Classifier-free guidance coefficient.
        top_k: Top-k sampling cutoff; coerced to ``int`` since UI sliders may
            deliver floats.
        model_size: Suffix of the checkpoint name passed to ``load_model``.

    Returns:
        Path to a freshly created temporary ``.wav`` file.
    """
    import tempfile

    model = load_model(model_size)

    model.set_generation_params(
        duration=duration,
        temperature=temperature,
        cfg_coef=cfg_coef,
        top_k=int(top_k),
    )

    with torch.no_grad():
        wav = model.generate([prompt])

    # A unique file per request: a fixed "output.wav" would be overwritten
    # when several users generate at the same time.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        path = f.name
    torchaudio.save(path, wav[0].cpu(), sample_rate=model.sample_rate)
    return path
|
||
|
|
|
||
|
|
# Build the UI. The input widgets are listed in the same order as the
# parameters of generate(): prompt, duration, temperature, cfg_coef, top_k,
# model_size.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", lines=3),
        gr.Slider(1, 30, value=10, label="Duration (s)"),
        gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
        gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
        gr.Slider(50, 500, value=250, step=50, label="Top-K"),
        gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
    ],
    # generate() returns a file path, which the Audio component plays back
    outputs=gr.Audio(label="Generated Music"),
    title="MusicGen Advanced",
    allow_flagging="never"
)

# NOTE(review): share=True requests a publicly reachable link — confirm this
# is intended before running on a machine with private data.
demo.launch(share=True)
|
||
|
|
```
|
||
|
|
|
||
|
|
## Audio Processing Pipeline
|
||
|
|
|
||
|
|
### Post-processing chain
|
||
|
|
|
||
|
|
```python
|
||
|
|
import torch
|
||
|
|
import torchaudio
|
||
|
|
import torchaudio.transforms as T
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
class AudioPostProcessor:
    """Post-processing helpers for waveform tensors whose last axis is time.

    None of the methods modify their input; each returns a new tensor.
    """

    def __init__(self, sample_rate=32000):
        self.sample_rate = sample_rate

    def normalize(self, audio, target_db=-14.0):
        """Scale ``audio`` so that its RMS level equals ``target_db`` dBFS."""
        rms = torch.sqrt(torch.mean(audio ** 2))
        target_rms = 10 ** (target_db / 20)
        # The epsilon keeps silent input from dividing by zero
        gain = target_rms / (rms + 1e-8)
        return audio * gain

    def fade_in_out(self, audio, fade_duration=0.1):
        """Apply a linear fade-in and fade-out of ``fade_duration`` seconds.

        The fade is shortened to the clip length when the clip is shorter
        than the requested fade, and a zero-length fade is a no-op.
        """
        out = audio.clone()
        fade_samples = min(int(fade_duration * self.sample_rate), audio.shape[-1])
        if fade_samples <= 0:
            # audio[..., -0:] would select the whole clip, so bail out early
            return out

        # Create the fade curves on the same device as the audio
        fade_in = torch.linspace(0, 1, fade_samples, device=audio.device)
        fade_out = torch.linspace(1, 0, fade_samples, device=audio.device)

        out[..., :fade_samples] *= fade_in
        out[..., -fade_samples:] *= fade_out
        return out

    def apply_reverb(self, audio, decay=0.5):
        """Apply a simple reverb made of two delayed, attenuated echoes.

        Echoes arrive 0.1 s and 0.2 s after the dry signal with gains
        ``decay * 0.5`` and ``decay * 0.25``. The output has the same shape
        as the input, and any number of leading (batch/channel) axes is
        supported.
        """
        out = audio.clone()
        length = audio.shape[-1]

        # Summing shifted copies is the causal convolution with the sparse
        # three-tap impulse response, without a 0.5 s dense kernel.
        for delay_seconds, gain in ((0.1, decay * 0.5), (0.2, decay * 0.25)):
            delay = int(self.sample_rate * delay_seconds)
            if 0 < delay < length:
                out[..., delay:] += gain * audio[..., :length - delay]

        return out

    def process(self, audio):
        """Full processing pipeline: loudness normalisation, then fades."""
        audio = self.normalize(audio)
        audio = self.fade_in_out(audio)
        return audio
|
||
|
|
|
||
|
|
# Usage with MusicGen
|
||
|
|
from audiocraft.models import MusicGen
|
||
|
|
|
||
|
|
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||
|
|
model.set_generation_params(duration=10)
|
||
|
|
|
||
|
|
wav = model.generate(["chill ambient music"])
|
||
|
|
processor = AudioPostProcessor()
|
||
|
|
wav_processed = processor.process(wav[0].cpu())
|
||
|
|
|
||
|
|
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||
|
|
```
|
||
|
|
|
||
|
|
## Evaluation
|
||
|
|
|
||
|
|
### Audio quality metrics
|
||
|
|
|
||
|
|
```python
|
||
|
|
import torch
|
||
|
|
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||
|
|
from audiocraft.data.audio import audio_read
|
||
|
|
|
||
|
|
def evaluate_generation(audio_path, text_prompt):
    """Score one generated clip against the prompt it was generated from.

    Args:
        audio_path: Path of the audio file to load with ``audio_read``.
        text_prompt: Prompt the clip was generated from.

    Returns:
        A dict with ``clap_score`` and ``duration`` (seconds).
    """
    wav, sr = audio_read(audio_path)

    # CLAP consistency (text-audio alignment)
    # NOTE(review): verify the CLAPTextConsistencyMetric constructor and
    # compute() signatures against the installed audiocraft version.
    clap_metric = CLAPTextConsistencyMetric()

    return {
        "clap_score": clap_metric.compute(wav, [text_prompt]),
        "duration": wav.shape[-1] / sr,
    }
|
||
|
|
|
||
|
|
# Batch evaluation
|
||
|
|
def evaluate_batch(generations):
    """Evaluate multiple generations.

    Args:
        generations: Iterable of dicts with ``"path"`` and ``"prompt"`` keys.

    Returns:
        A dict with ``individual`` (one result per generation, each extended
        with its ``prompt``) and ``average_clap`` (mean CLAP score, or
        ``None`` when there is nothing to average).
    """
    results = []
    for gen in generations:
        result = evaluate_generation(gen["path"], gen["prompt"])
        result["prompt"] = gen["prompt"]
        results.append(result)

    # Aggregate; an empty batch has no meaningful average
    if results:
        avg_clap = sum(r["clap_score"] for r in results) / len(results)
    else:
        avg_clap = None

    return {
        "individual": results,
        "average_clap": avg_clap,
    }
|
||
|
|
```
|
||
|
|
|
||
|
|
## Model Comparison
|
||
|
|
|
||
|
|
### MusicGen variants benchmark
|
||
|
|
|
||
|
|
| Model | CLAP Score | Generation Time (10s) | VRAM |
|-------|------------|-----------------------|------|
| musicgen-small | 0.35 | ~5s | 2GB |
| musicgen-medium | 0.42 | ~15s | 4GB |
| musicgen-large | 0.48 | ~30s | 8GB |
| musicgen-melody | 0.45 | ~15s | 4GB |
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
|
||
|
|
|
||
|
|
### Prompt engineering tips
|
||
|
|
|
||
|
|
```python
|
||
|
|
# Good prompts - specific and descriptive
|
||
|
|
good_prompts = [
|
||
|
|
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||
|
|
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||
|
|
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||
|
|
]
|
||
|
|
|
||
|
|
# Bad prompts - too vague
|
||
|
|
bad_prompts = [
|
||
|
|
"nice music",
|
||
|
|
"song",
|
||
|
|
"good beat"
|
||
|
|
]
|
||
|
|
|
||
|
|
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||
|
|
```
|