Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ The main benefit of this method is that the build is reproducible since there is
uv pip install --group pre_build --no-build-isolation
uv pip install --group compile_xformers --no-build-isolation
uv sync
uv run python download_blt_weights.py
uv run python demo.py "A BLT has"
```

Expand Down
6 changes: 3 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@


def main(prompt: str, model_name: str = "blt-1b"):
assert model_name in ['blt-1b', 'blt-7b']
model_name = model_name.replace('-', '_')
distributed_args = DistributedArgs()
distributed_args.configure_world()
if not torch.distributed.is_initialized():
Expand All @@ -25,9 +27,7 @@ def main(prompt: str, model_name: str = "blt-1b"):
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = True
print("Loading entropy model and patcher")
patcher_args.entropy_model_checkpoint_dir = os.path.join(
checkpoint_path, "entropy_model"
)
patcher_args.entropy_model_checkpoint_dir = os.path.join("hf-weights", "entropy_model")
patcher = patcher_args.build()
prompts = [prompt]
outputs = generate_nocache(
Expand Down
5 changes: 2 additions & 3 deletions download_blt_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from huggingface_hub import snapshot_download


def main(models: list[str] = ["blt-1b", "blt-7b"]):
def main():
if not os.path.exists("hf-weights"):
os.makedirs("hf-weights")
for model in models:
snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}")
snapshot_download(f"facebook/blt", local_dir=f"hf-weights")


if __name__ == "__main__":
Expand Down
Loading