Changes from all commits
41 commits
6cb1779
more suggestion fixing
RedmanPlus Jun 16, 2024
9d299b7
Merge remote-tracking branch 'origin/master'
RedmanPlus Jun 16, 2024
079da0d
add whisper support to inference
RedmanPlus Jun 22, 2024
38c9078
add routing between search and encode instances
RedmanPlus Jun 22, 2024
9472fdd
Merge remote-tracking branch 'origin/master' into whisper
RedmanPlus Jun 22, 2024
e45ccde
fix docker compose
RedmanPlus Jun 22, 2024
b291d45
fix docker compose
RedmanPlus Jun 22, 2024
afffbc8
add whisper.cpp preload
RedmanPlus Jun 22, 2024
502aa12
change port schematics
RedmanPlus Jun 22, 2024
bc45b6e
fix stash errors
RedmanPlus Jun 22, 2024
940261e
fix import errors in containers
RedmanPlus Jun 22, 2024
dbfc95c
add whisper cpp to dep list
RedmanPlus Jun 22, 2024
e92c926
change db host name
RedmanPlus Jun 22, 2024
ad996e7
fix paths to inference
RedmanPlus Jun 22, 2024
898b711
add tempfiles for whisper
RedmanPlus Jun 22, 2024
f1012c1
fix tempfiles + translate
RedmanPlus Jun 22, 2024
d1caad3
redo file submission to whisper
RedmanPlus Jun 22, 2024
85860a8
fix prompts + add logging
RedmanPlus Jun 23, 2024
2456c34
adding threads to the problem
RedmanPlus Jun 23, 2024
6a7d81e
piping video to wav and adding logging
RedmanPlus Jun 23, 2024
89d75a3
adding summary api
RedmanPlus Jun 23, 2024
86e8aeb
adding more logging to whisper for debug purposes
RedmanPlus Jun 23, 2024
09c2bd4
fix importing error
RedmanPlus Jun 23, 2024
5d848d3
fixing suffix errors + debug
RedmanPlus Jun 23, 2024
20c0f42
ditching wav formatting for now
RedmanPlus Jun 23, 2024
c9e5051
add logging for clip encoding result
RedmanPlus Jun 23, 2024
9cc7753
fix video vectorization
RedmanPlus Jun 23, 2024
22eeb6e
fix suggestion id generation
RedmanPlus Jun 23, 2024
28775d3
redo ports again
RedmanPlus Jun 23, 2024
2fb6192
redo ports again
RedmanPlus Jun 23, 2024
0de8abf
more port meddling
RedmanPlus Jun 23, 2024
784fd54
save urls to documents instead of uris(they dont work)
RedmanPlus Jun 23, 2024
ec62d4f
move to faster-whisper to improve performance + optimizations
RedmanPlus Jun 24, 2024
84affe8
add lost env vars
RedmanPlus Jun 24, 2024
bc710eb
fix snapshot loading
RedmanPlus Jun 24, 2024
2dc3f87
change whisper model
RedmanPlus Jun 24, 2024
e39bc15
change quantization to int8
RedmanPlus Jun 24, 2024
a07935a
add sentencepiece to list of deps
RedmanPlus Jun 24, 2024
87db31f
scrap whisper for time being
RedmanPlus Jun 24, 2024
9aa10fd
(upd) Key frame search, trunc text for features
arseny-chebyshev Jun 24, 2024
7d17375
(upd) Key frame search, trunc text for features
arseny-chebyshev Jun 24, 2024
24 changes: 17 additions & 7 deletions docker-compose.yml
@@ -2,7 +2,6 @@ version: "3.10"

services:
db:
container_name: chroma_db
image: chromadb/chroma:latest
volumes:
- chroma-data:/chroma/chroma
@@ -11,18 +10,16 @@ services:
- "8000:8000"

cache:
container_name: request_cache
image: memcached:latest
ports:
- "11211:11211"
restart: always

inference:
encode:
build:
context: ./inference
dockerfile: Dockerfile
container_name: inference
command: uvicorn clip:app --host "0.0.0.0" --port 8040
command: uvicorn main:app --host "0.0.0.0" --port 8040 --log-config=log_conf.yaml
restart: unless-stopped
volumes:
- inference-model-data:/app/model_data
@@ -31,18 +28,31 @@
ports:
- "8040:8040"

search:
build:
context: ./inference
dockerfile: Dockerfile
command: uvicorn main:app --host "0.0.0.0" --port 8050 --log-config=log_conf.yaml
restart: unless-stopped
volumes:
- inference-model-data:/app/model_data
env_file:
- inference/.env.dist
ports:
- "8050:8050"

main:
build:
context: ./main
dockerfile: Dockerfile
container_name: main_gateway
command: uvicorn main:app --host "0.0.0.0" --port 80
restart: unless-stopped
volumes:
- main-model-data:/app/model_data
depends_on:
- db
- inference
- encode
- search
- cache
env_file:
- main/.env.dist
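The single inference service is split into two instances of the same image: encode on port 8040 and search on port 8050, both launched as uvicorn main:app with the shared log config, and the main gateway now depends on both plus the renamed db host. A quick sanity check once the compose stack is up (a sketch; ports are the ones published above):

# Sketch: confirm both inference instances answer on their published ports.
import requests

for name, port in (("encode", 8040), ("search", 8050)):
    resp = requests.get(f"http://localhost:{port}/")
    print(name, resp.json())  # the root handler in inference/main.py returns {"ok": True}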
1 change: 1 addition & 0 deletions inference/.env.dist
@@ -1 +1,2 @@
CLIP_MODEL=laion/CLIP-ViT-g-14-laion2B-s12B-b42K
TRANSLATION_MODEL=Helsinki-NLP/opus-mt-ru-en
8 changes: 3 additions & 5 deletions inference/Dockerfile
@@ -5,11 +5,9 @@ ENV PYTHONUNBUFFERED 1

WORKDIR /app

COPY requirements.txt /app/

RUN apt-get update && apt-get install ffmpeg -y

COPY requirements.txt /app/
RUN python -m pip install --upgrade pip && pip install -r requirements.txt
COPY ./ /app/

EXPOSE 8040
CMD uvicorn clip:app --port 8040
COPY ./ /app/
91 changes: 44 additions & 47 deletions inference/clip.py
@@ -1,62 +1,59 @@
import torch
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from dataclasses import dataclass
from logging import Logger
from typing import Callable, Literal

from PIL import Image
from pydantic import BaseModel
from deps import Model, Processor, lifespan
from frame_video import create_key_frames_for_video

app = FastAPI(lifespan=lifespan)

class EncodeRequest(BaseModel):
link: Optional[str] = None
description: Optional[str] = None

@app.get("/")
async def root():
return JSONResponse(content={"ok": True})

@app.post("/encode")
async def encode(request: EncodeRequest, processor: Processor, model: Model):
if not any((request.description, request.link)):
raise HTTPException(
status_code=400, detail="Please provide either 'description' as string or 'link' as video URL, or both."
)

text_features, image_features = None, None

if request.description:
text_inputs = processor(text=[request.description], return_tensors="pt", padding=True)
import torch
from transformers import CLIPModel, CLIPProcessor

from frame_video import VideoFrame, create_key_frames_for_video


@dataclass
class CLIP:
processor: CLIPProcessor
model: CLIPModel
logger: Logger

_create_key_frames_for_video: Callable[[str], list[VideoFrame]] = create_key_frames_for_video

def __call__(self, encode_source: str, encode_type: Literal["text"] | Literal["video"]) -> list[float]:
if encode_type == "text":
self.logger.info("Processing text input: %s, input length: %s", encode_source, len(encode_source))
return self._encode_text(encode_source)

if encode_type == "video":
self.logger.info("Processing video input: %s", encode_source)
return self._encode_video(encode_source)

def _encode_text(self, description: str) -> list[float]:
description = description[:65] # meet the processor max length
text_inputs = self.processor(text=[description], return_tensors="pt", padding=True)
with torch.no_grad():
text_features = model.get_text_features(**text_inputs)
text_features = self.model.get_text_features(**text_inputs)
text_features /= text_features.norm(dim=-1, keepdim=True)

if request.link:
images = create_key_frames_for_video(request.link)

result = text_features.tolist()[0]
self.logger.info("Processed result vector - %s", result)
return result

def _encode_video(self, link: str) -> list[float]:
images = self._create_key_frames_for_video(link)
image_inputs = []
for image in images:
image = Image.open(image.file)
image_input = processor(images=image, return_tensors="pt")
image_input = self.processor(images=image, return_tensors="pt")
image_inputs.append(image_input)
with torch.no_grad():
image_features = model.get_image_features(**image_inputs[0])
image_features = self.model.get_image_features(**image_inputs[0])
for image_input in image_inputs[1:]:
image_feature = model.get_image_features(**image_input)
image_feature = self.model.get_image_features(**image_input)
image_features = torch.cat((image_features, image_feature), dim=0)

features = torch.mean(image_features, dim=0)
features /= features.norm(dim=-1, keepdim=True)

if request.description and request.link:
text_weight = 1.0
video_weight = 2.0 # Giving more importance to video
# Merged weighted vectors of text and video didn't work so well, leave off for now
unified_features = (text_features * text_weight + image_features * video_weight) / (text_weight + video_weight)
return {"features": image_features.tolist()[0]}

elif request.description:
return {"features": text_features.tolist()[0]}
result = features.tolist()
self.logger.info("Processed result vector - %s", result)
return result

elif request.link:
return {"features": image_features.tolist()[0]}
4 changes: 4 additions & 0 deletions inference/deps.py
@@ -1,4 +1,5 @@
from contextlib import asynccontextmanager
import logging
from typing import Annotated

from fastapi import Depends, FastAPI, Request
@@ -9,10 +10,13 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
logger = logging.getLogger(__name__)
logger.info("Setting up CLIP model...")
app.state.clip_model = CLIPModel.from_pretrained(
Settings.clip_model,
cache_dir="./model_cache"
)
logger.info("Setting up CLIP processor...")
app.state.processor = CLIPProcessor.from_pretrained(
Settings.clip_model,
cache_dir="./model_cache"
16 changes: 11 additions & 5 deletions inference/frame_video.py
@@ -4,7 +4,7 @@
from io import BytesIO
from dataclasses import dataclass
import requests
from scenedetect import detect, ContentDetector
from scenedetect import detect, ContentDetector, AdaptiveDetector

@dataclass
class VideoFrame:
@@ -14,19 +14,24 @@ class VideoFrame:
def create_key_frames_for_video(
video_link: str,
frame_change_threshold: float = 7.5,
min_scene_len: int = 10,
num_of_thumbnails: int = 10
) -> list[VideoFrame]:
frames: list[VideoFrame] = []
video_data = BytesIO(requests.get(video_link).content)
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
tmp_file.write(video_data.getvalue())
video_path = tmp_file.name
scenes = detect(video_path, ContentDetector(threshold=frame_change_threshold))
scenes = detect(
video_path=video_path,
detector=ContentDetector(threshold=frame_change_threshold, min_scene_len=min_scene_len)
)

# Gradually reduce number of key frames with a sliding window
# Gradually reduce the number of key frames by removing evenly spaced scenes in increasingly smaller steps
while len(scenes) > num_of_thumbnails:
scenes.pop()
scenes.pop(0)
step = len(scenes) / (num_of_thumbnails - 1)
to_remove_indices = [int(round(i * step)) for i in range(num_of_thumbnails)]
scenes = [scenes[i] for i in range(len(scenes)) if i not in to_remove_indices]
for i, scene in enumerate(scenes):
scene_start, _ = scene
frame_data = create_frame_in_ram(video_path, scene_start.get_timecode())
@@ -39,6 +44,7 @@ def create_key_frames_for_video(
return create_key_frames_for_video(
video_link=video_link,
frame_change_threshold=frame_change_threshold - 2.5,
min_scene_len=min_scene_len - 2 if min_scene_len > 2 else min_scene_len,
num_of_thumbnails=num_of_thumbnails
)
return frames
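The reworked loop no longer pops one scene per iteration; it drops evenly spaced scenes so the remaining key frames stay spread across the video, and the recursive retry lowers both the detector threshold and min_scene_len when too few scenes are found. The thinning step in isolation (a sketch of the idea, not the PR code verbatim):

# Sketch of the even-thinning used in create_key_frames_for_video:
# while there are too many scenes, drop evenly spaced ones.
def thin_scenes(scenes: list, num_of_thumbnails: int = 10) -> list:
    while len(scenes) > num_of_thumbnails:
        step = len(scenes) / (num_of_thumbnails - 1)
        to_remove = {int(round(i * step)) for i in range(num_of_thumbnails)}
        scenes = [s for i, s in enumerate(scenes) if i not in to_remove]
    return scenes

print(len(thin_scenes(list(range(47)))))  # <= 10 scenes remain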
34 changes: 34 additions & 0 deletions inference/log_conf.yaml
@@ -0,0 +1,34 @@
version: 1
disable_existing_loggers: False
formatters:
default:
# "()": uvicorn.logging.DefaultFormatter
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
access:
# "()": uvicorn.logging.AccessFormatter
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
handlers:
default:
formatter: default
class: logging.StreamHandler
stream: ext://sys.stderr
access:
formatter: access
class: logging.StreamHandler
stream: ext://sys.stdout
loggers:
uvicorn.error:
level: INFO
handlers:
- default
propagate: no
uvicorn.access:
level: INFO
handlers:
- access
propagate: no
root:
level: INFO
handlers:
- default
propagate: no
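This config is wired in through the --log-config=log_conf.yaml flag in the compose commands above, so module-level loggers created with logging.getLogger(__name__) in the inference code propagate to the root handler. A local-run sketch outside Docker, assuming the file sits next to main.py in the inference directory:

# Sketch: run the inference app locally with the same log config the compose services use.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8040, log_config="log_conf.yaml")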
41 changes: 41 additions & 0 deletions inference/main.py
@@ -0,0 +1,41 @@
import logging
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from deps import Model, Processor, lifespan
from clip import CLIP
from models import EncodeRequest, EncodeSearchRequest

app = FastAPI(lifespan=lifespan)
logger = logging.getLogger(__name__)

@app.get("/")
async def root():
return JSONResponse(content={"ok": True})

@app.post("/encode")
async def encode(
request: EncodeRequest,
processor: Processor,
model: Model,
):
logger.info("Initializing CLIP module...")
clip = CLIP(processor=processor, model=model, logger=logger)
logger.info("CLIP module successfully initialized")

video_features = clip(request.link, encode_type="video")
return {
"video": video_features,
}

@app.post("/encode-search")
async def encode_search(
request: EncodeSearchRequest, processor: Processor, model: Model
):
logger.info("Initializing CLIP module...")
clip = CLIP(processor=processor, model=model, logger=logger)
logger.info("CLIP module successfully initialized")

features = clip(request.query, encode_type="text")

return {"features": features}
10 changes: 10 additions & 0 deletions inference/models.py
@@ -0,0 +1,10 @@
from pydantic import BaseModel


class EncodeRequest(BaseModel):
link: str
description: str | None = None


class EncodeSearchRequest(BaseModel):
query: str
2 changes: 2 additions & 0 deletions inference/requirements.txt
@@ -9,3 +9,5 @@ pillow==10.3.0
scenedetect==0.6.3
opencv-python==4.10.0.82
environs==11.0.0
PyYAML>=6.0
sentencepiece==0.2.0
3 changes: 2 additions & 1 deletion inference/settings.py
@@ -5,4 +5,5 @@


class Settings:
clip_model: str = env.str("CLIP_MODEL")
clip_model: str = env.str("CLIP_MODEL", default="laion/CLIP-ViT-g-14-laion2B-s12B-b42K")
translation_model: str = env.str("TRANSLATION_MODEL", default="Helsinki-NLP/opus-mt-ru-en")
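Both model names now have defaults, so the service can start even without the .env file; overrides still come from the environment. A small sketch of how the environs-backed settings resolve (variable names from .env.dist; the override value is only illustrative):

# Sketch: Settings falls back to the defaults above unless the env vars are set.
import os
os.environ["CLIP_MODEL"] = "openai/clip-vit-base-patch32"  # illustrative override

from settings import Settings
print(Settings.clip_model)          # openai/clip-vit-base-patch32
print(Settings.translation_model)   # Helsinki-NLP/opus-mt-ru-en (default)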
5 changes: 3 additions & 2 deletions main/.env.dist
@@ -1,3 +1,4 @@
CLIP_URL=http://inference:8040/encode
DB_HOST=chroma_db
ENCODE_CLIP_URL=http://encode:8040/
SEARCH_CLIP_URL=http://search:8050/
DB_HOST=db
DB_PORT=8000
9 changes: 6 additions & 3 deletions main/chroma.py
@@ -1,3 +1,4 @@
from uuid import uuid4
import chromadb
from chromadb.server import Settings as ChromaSettings
from models import Feature
@@ -25,22 +26,24 @@ def __init__(

def add_feature(self, feature: Feature) -> None:
self.collection.add(
ids=[feature.link],
ids=[str(uuid4())],
embeddings=[feature.features],
documents=[feature.link],
metadatas=[{"feature_type": feature.feature_type}]
)

def search_relevant_videos(self, search_feature: Feature, top_k: int = 100) -> list[str]:
results = self.collection.query(
query_embeddings=search_feature.features,
n_results=top_k
)
return results['ids'][0]
return results['documents'][0]

def add_text_search_suggestion(self, suggestion_query: str) -> None:
subsearches = suggestion_query.split()
self.desc_collection.add(
documents=[suggestion_query] + subsearches,
ids=[str(hash(query)) for query in [suggestion_query] + subsearches]
ids=[str(uuid4()) for _ in [suggestion_query] + subsearches]
)

def get_text_search_suggestions(self, search_query: str, top_k: int = 20) -> list[str]:
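The storage change switches ids to random UUIDs and keeps the video URL as the stored document, so search_relevant_videos now returns URLs instead of ids. A sketch of the same pattern with chromadb directly (the wrapper's constructor is elided in this diff; the collection name and placeholder values are illustrative):

# Sketch of the id/document change, mirroring add_feature and search_relevant_videos above.
from uuid import uuid4
import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection("videos")

link = "https://example.com/video.mp4"   # placeholder URL
embedding = [0.1] * 512                  # placeholder feature vector

# ids are random UUIDs now; the link is stored as the document instead
collection.add(ids=[str(uuid4())], embeddings=[embedding],
               documents=[link], metadatas=[{"feature_type": "video"}])

results = collection.query(query_embeddings=[embedding], n_results=5)
print(results["documents"][0])           # the stored URLs, not raw ids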