Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,32 @@ EOF
--vae models/vae-BF16.gguf
```

**Lego** (`--lego <track>` + `--src-audio`):
**Lego** (`"lego"` in JSON + `--src-audio`):
generates a new instrument track layered over an existing backing track.
Only the **base model** (`acestep-v15-base`) supports lego mode.
The track name is passed on the CLI; set `audio_cover_strength=1.0` in the
request so the source audio guides all DiT steps.
See `examples/lego.json` and `examples/lego.sh`.

```bash
cat > /tmp/lego.json << 'EOF'
{
"caption": "electric guitar riff, funk guitar, house music, instrumental",
"lyrics": "[Instrumental]",
"lego": "guitar",
"inference_steps": 50,
"guidance_scale": 7.0,
"shift": 1.0
}
EOF

./build/dit-vae \
--src-audio backing-track.wav \
--request /tmp/lego.json \
--text-encoder models/Qwen3-Embedding-0.6B-Q8_0.gguf \
--dit models/acestep-v15-base-Q8_0.gguf \
--vae models/vae-BF16.gguf \
--wav
```

Available track names: `vocals`, `backing_vocals`, `drums`, `bass`, `guitar`,
`keyboard`, `percussion`, `strings`, `synth`, `fx`, `brass`, `woodwinds`.

Expand Down Expand Up @@ -295,7 +314,8 @@ the LLM fills them, or a sensible runtime default is applied.
"shift": 3.0,
"audio_cover_strength": 0.5,
"repainting_start": -1,
"repainting_end": -1
"repainting_end": -1,
"lego": ""
}
```

Expand Down Expand Up @@ -363,6 +383,15 @@ the DiT regenerates the `[start, end)` time region while preserving everything
else. `-1` on start means 0s (beginning); `-1` on end means the source duration
(end). It is an error if end <= start after these defaults are resolved. `audio_cover_strength` is ignored.

**`lego`** (string, default `""` = inactive)
Track name for lego mode. Requires `--src-audio` and the **base model**.
Valid names: `vocals`, `backing_vocals`, `drums`, `bass`, `guitar`,
`keyboard`, `percussion`, `strings`, `synth`, `fx`, `brass`, `woodwinds`.
When set, passes the source audio to the DiT as context and builds the
instruction `"Generate the {TRACK} track based on the audio context:"`.
`audio_cover_strength` is forced to 1.0 (all steps see the source audio).
Use `inference_steps=50`, `guidance_scale=7.0`, `shift=1.0` for the base model.

### LM sampling (ace-qwen3)

**`lm_temperature`** (float, default `0.85`)
Expand Down
8 changes: 5 additions & 3 deletions examples/lego.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
{
"caption": "electric guitar riff, funk guitar, house music, instrumental",
"audio_cover_strength": 1.0,
"caption": "",
"lyrics": "[Instrumental]",
"lego": "guitar",
"inference_steps": 50,
"guidance_scale": 7.0
"guidance_scale": 7.0,
"shift": 1.0
}
1 change: 0 additions & 1 deletion examples/lego.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ set -eu
# Step 2: lego guitar on the generated track (base model required)
../build/dit-vae \
--src-audio simple00.wav \
--lego guitar \
--request lego.json \
--text-encoder ../models/Qwen3-Embedding-0.6B-Q8_0.gguf \
--dit ../models/acestep-v15-base-Q8_0.gguf \
Expand Down
9 changes: 9 additions & 0 deletions src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void request_init(AceRequest * r) {
r->audio_cover_strength = 0.5f;
r->repainting_start = -1.0f;
r->repainting_end = -1.0f;
r->lego = "";
}

// JSON string escape / unescape
Expand Down Expand Up @@ -321,6 +322,8 @@ bool request_parse(AceRequest * r, const char * path) {
r->repainting_start = (float) atof(v.c_str());
} else if (k == "repainting_end") {
r->repainting_end = (float) atof(v.c_str());
} else if (k == "lego") {
r->lego = v;
}
}

Expand Down Expand Up @@ -356,6 +359,9 @@ bool request_write(const AceRequest * r, const char * path) {
fprintf(f, " \"audio_cover_strength\": %.2f,\n", r->audio_cover_strength);
fprintf(f, " \"repainting_start\": %.1f,\n", r->repainting_start);
fprintf(f, " \"repainting_end\": %.1f,\n", r->repainting_end);
if (!r->lego.empty()) {
fprintf(f, " \"lego\": \"%s\",\n", json_escape(r->lego).c_str());
}
// audio_codes last (no trailing comma)
fprintf(f, " \"audio_codes\": \"%s\"\n", json_escape(r->audio_codes).c_str());
fprintf(f, "}\n");
Expand All @@ -380,5 +386,8 @@ void request_dump(const AceRequest * r, FILE * f) {
if (r->repainting_start >= 0.0f || r->repainting_end >= 0.0f) {
fprintf(f, " repaint: start=%.1f end=%.1f\n", r->repainting_start, r->repainting_end);
}
if (!r->lego.empty()) {
fprintf(f, " lego: %s\n", r->lego.c_str());
}
fprintf(f, " audio_codes: %s\n", r->audio_codes.empty() ? "(none)" : "(present)");
}
6 changes: 6 additions & 0 deletions src/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ struct AceRequest {
// -1 on start means 0s, -1 on end means source duration.
float repainting_start; // -1
float repainting_end; // -1

// lego mode (requires --src-audio, base model only)
// Track name from TRACK_NAMES: vocals, backing_vocals, drums, bass, guitar,
// keyboard, percussion, strings, synth, fx, brass, woodwinds.
// Empty = not lego. Sets the instruction and forces audio_cover_strength to 1.0.
std::string lego; // ""
};

// Initialize all fields to defaults (matches Python GenerationParams defaults)
Expand Down
91 changes: 55 additions & 36 deletions tools/dit-vae.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "vae-enc.h"
#include "vae.h"

#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
Expand All @@ -32,11 +33,6 @@ static void print_usage(const char * prog) {
" --vae <gguf> VAE GGUF file\n\n"
"Reference audio:\n"
" --src-audio <file> Source audio (WAV or MP3, any sample rate)\n\n"
"Lego mode (base model only, requires --src-audio):\n"
" --lego <track> Generate a track over the source audio context\n"
" Track names: vocals, backing_vocals, drums, bass,\n"
" guitar, keyboard, percussion, strings, synth,\n"
" fx, brass, woodwinds\n\n"
"LoRA:\n"
" --lora <path> LoRA safetensors file or directory\n"
" --lora-scale <float> LoRA scaling factor (default: 1.0)\n\n"
Expand Down Expand Up @@ -88,7 +84,6 @@ int main(int argc, char ** argv) {
const char * dit_gguf = NULL;
const char * vae_gguf = NULL;
const char * src_audio_path = NULL;
const char * lego_track = NULL; // --lego <track>
const char * dump_dir = NULL;
const char * lora_path = NULL;
float lora_scale = 1.0f;
Expand All @@ -113,8 +108,6 @@ int main(int argc, char ** argv) {
vae_gguf = argv[++i];
} else if (strcmp(argv[i], "--src-audio") == 0 && i + 1 < argc) {
src_audio_path = argv[++i];
} else if (strcmp(argv[i], "--lego") == 0 && i + 1 < argc) {
lego_track = argv[++i];
} else if (strcmp(argv[i], "--lora") == 0 && i + 1 < argc) {
lora_path = argv[++i];
} else if (strcmp(argv[i], "--lora-scale") == 0 && i + 1 < argc) {
Expand Down Expand Up @@ -152,10 +145,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "[CLI] ERROR: --batch must be 1..9\n");
return 1;
}
if (lego_track && !src_audio_path) {
fprintf(stderr, "[CLI] ERROR: --lego requires --src-audio\n");
return 1;
}
if (!dit_gguf) {
fprintf(stderr, "[CLI] ERROR: --dit required\n");
print_usage(argv[0]);
Expand Down Expand Up @@ -198,12 +187,6 @@ int main(int argc, char ** argv) {
if (gf_load(&gf, dit_gguf)) {
is_turbo = gf_get_bool(gf, "acestep.is_turbo");
const void * sl_data = gf_get_data(gf, "silence_latent");
if (lego_track && is_turbo) {
fprintf(stderr, "[CLI] ERROR: --lego requires the base DiT model\n");
gf_close(&gf);
dit_ggml_free(&model);
return 1;
}
if (sl_data) {
silence_full.resize(15000 * 64);
memcpy(silence_full.data(), sl_data, 15000 * 64 * sizeof(float));
Expand Down Expand Up @@ -301,11 +284,43 @@ int main(int argc, char ** argv) {
fprintf(stderr, "[Request] ERROR: failed to parse %s, skipping\n", rpath);
continue;
}
if (req.caption.empty()) {
if (req.caption.empty() && req.lego.empty()) {
fprintf(stderr, "[Request] ERROR: caption is empty in %s, skipping\n", rpath);
continue;
}

// Lego mode validation (base model only, requires --src-audio)
bool is_lego = !req.lego.empty();
if (is_lego) {
if (!src_audio_path) {
fprintf(stderr, "[Lego] ERROR: lego requires --src-audio\n");
return 1;
}
if (is_turbo) {
fprintf(stderr, "[Lego] ERROR: lego requires the base DiT model (turbo detected)\n");
return 1;
}
// Reference project: TRACK_NAMES (constants.py)
static const char * allowed[] = {
"vocals", "backing_vocals", "drums", "bass", "guitar", "keyboard",
"percussion", "strings", "synth", "fx", "brass", "woodwinds",
};
bool valid = false;
for (int k = 0; k < 12; k++) {
if (req.lego == allowed[k]) {
valid = true;
break;
}
}
if (!valid) {
fprintf(stderr, "[Lego] ERROR: '%s' is not a valid track name\n", req.lego.c_str());
fprintf(stderr,
" Valid: vocals, backing_vocals, drums, bass, guitar, keyboard,\n"
" percussion, strings, synth, fx, brass, woodwinds\n");
return 1;
}
}

// Extract params
const char * caption = req.caption.c_str();
const char * lyrics = req.lyrics.c_str();
Expand Down Expand Up @@ -424,32 +439,36 @@ int main(int argc, char ** argv) {
}

// 2. Build formatted prompts
// Reference project uses opposite-sounding instructions (constants.py):
// Reference project instruction templates (constants.py TASK_INSTRUCTIONS):
// text2music = "Fill the audio semantic mask..."
// cover = "Generate audio semantic tokens..."
// repaint = "Repaint the mask area..."
// lego = "Generate the {track} track based on the audio context:"
// lego = "Generate the {TRACK_NAME} track based on the audio context:"
// Auto-switches to cover when audio_codes are present
bool is_cover = have_cover || !codes_vec.empty();

// Lego: build instruction from the track name supplied via --lego <track>
char lego_instruction[256] = {};
const char * instruction;
if (lego_track) {
snprintf(lego_instruction, sizeof(lego_instruction),
"Generate the %s track based on the audio context:", lego_track);
instruction = lego_instruction;
fprintf(stderr, "[Lego] track=%s\n", lego_track);
bool is_cover = have_cover || !codes_vec.empty();
std::string instruction_str;
if (is_lego) {
// Lego mode: force audio_cover_strength=1.0 so all DiT steps see the source audio
req.audio_cover_strength = 1.0f;
fprintf(stderr, "[Lego] track=%s, cover path, strength=1.0\n", req.lego.c_str());
// Reference project (task_utils.py:86): track name is UPPERCASE
std::string track_upper = req.lego;
for (char & c : track_upper) {
c = (char) toupper((unsigned char) c);
}
instruction_str = "Generate the " + track_upper + " track based on the audio context:";
} else if (is_repaint) {
instruction_str = "Repaint the mask area based on the given conditions:";
} else if (is_cover) {
instruction_str = "Generate audio semantic tokens based on the given conditions:";
} else {
instruction = is_repaint ? "Repaint the mask area based on the given conditions:" :
is_cover ? "Generate audio semantic tokens based on the given conditions:" :
"Fill the audio semantic mask based on the given conditions:";
instruction_str = "Fill the audio semantic mask based on the given conditions:";
}

char metas[512];
snprintf(metas, sizeof(metas), "- bpm: %s\n- timesignature: %s\n- keyscale: %s\n- duration: %d seconds\n", bpm,
timesig, keyscale, (int) duration);
std::string text_str = std::string("# Instruction\n") + instruction + "\n\n" + "# Caption\n" + caption +
std::string text_str = std::string("# Instruction\n") + instruction_str + "\n\n" + "# Caption\n" + caption +
"\n\n" + "# Metas\n" + metas + "<|endoftext|>\n";

std::string lyric_str = std::string("# Languages\n") + language + "\n\n# Lyric\n" + lyrics + "<|endoftext|>";
Expand Down Expand Up @@ -567,7 +586,7 @@ int main(int argc, char ** argv) {
}

// Build context: [T, ctx_ch] = src_latents[64] + chunk_mask[64]
// Cover: src = cover_latents, mask = 1.0 everywhere
// Cover/Lego: src = cover_latents, mask = 1.0 everywhere
// Repaint: src = silence in region / cover outside, mask = 1.0 in region / 0.0 outside
// Passthrough: detokenized FSQ codes + silence padding, mask = 1.0
// Text2music: silence only, mask = 1.0
Expand Down
Loading