Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea
/target
.fastembed_cache
**/.claude/settings.local.json
Expand Down
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ tokio = {version = "1.0.0", features = ["full"]}
url = "2.5.4"
thiserror = "2.0.12"
fastembed = "4.9"
faiss = "0.12.1"
faiss = "0.13.0"
askama = "0.14.0"
prometheus = "0.14.0"
chrono = { version = "0.4", features = ["serde"] }
Expand Down
120 changes: 106 additions & 14 deletions src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,127 @@ use crate::clients::client::Client;
use crate::clients::http_client::HttpClient;
use crate::embedding::fastembed::FastEmbedService;
use crate::embedding::service::EmbeddingService;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum AppStateError {
#[error("Lock poisoned: {0}")]
LockPoisoned(&'static str),

#[error("Invalid namespace: {0}")]
InvalidNamespace(&'static str),
}

pub struct AppState {
// client for upstream LLM requests
pub http_client: Box<dyn Client>,
pub embedding_service: Box<dyn EmbeddingService>,
pub cache: Box<dyn Cache<Vec<u8>>>,
caches: RwLock<HashMap<String, Arc<Box<dyn Cache<Vec<u8>>>>>>,
similarity_threshold: f32,
eviction_policy: EvictionPolicy,
}

impl AppState {
pub fn new(semantic_threshold: f32, eviction_policy: EvictionPolicy) -> Self {
// client for upstream LLM requests
let http_client = Box::new(HttpClient::new());
// cache fields
let embedding_service = Box::new(FastEmbedService::new());
const MAX_NAMESPACE_LENGTH: usize = 64;

pub fn new(similarity_threshold: f32, eviction_policy: EvictionPolicy) -> Self {
Self {
http_client: Box::new(HttpClient::new()),
embedding_service: Box::new(FastEmbedService::new()),
caches: RwLock::new(HashMap::new()),
similarity_threshold,
eviction_policy,
}
}

fn validate_namespace(namespace: &str) -> Result<(), AppStateError> {
if namespace.is_empty() {
return Err(AppStateError::InvalidNamespace("namespace cannot be empty"));
}

if namespace.len() > Self::MAX_NAMESPACE_LENGTH {
return Err(AppStateError::InvalidNamespace("namespace too long"));
}

// Check for valid characters: alphanumeric, underscore, hyphen
if !namespace
.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == '-')
{
return Err(AppStateError::InvalidNamespace(
"invalid characters in namespace",
));
}

Ok(())
}

fn create_cache(&self) -> Box<dyn Cache<Vec<u8>>> {
let semantic_store = Box::new(FlatIPFaissStore::new(
embedding_service.get_dimensionality(),
self.embedding_service.get_dimensionality(),
));
let response_store = ResponseStore::new();
// create cache
let cache = Box::new(CacheImpl::new(
Box::new(CacheImpl::new(
semantic_store,
response_store,
semantic_threshold,
eviction_policy,
));
// put service dependencies into app state
self.similarity_threshold,
self.eviction_policy,
))
}

pub fn get_cache(
&self,
namespace: &str,
) -> Result<Arc<Box<dyn Cache<Vec<u8>>>>, AppStateError> {
Self::validate_namespace(namespace)?;

// Try read lock first to check if cache exists
{
let read_guard = self
.caches
.read()
.map_err(|_| AppStateError::LockPoisoned("caches read lock poisoned"))?;
if let Some(cache) = read_guard.get(namespace) {
return Ok(Arc::clone(cache));
}
}

// Cache doesn't exist, acquire write lock to create it
let mut write_guard = self
.caches
.write()
.map_err(|_| AppStateError::LockPoisoned("caches write lock poisoned"))?;

// Double-check: another thread might have created it while we were waiting
let cache = write_guard
.entry(namespace.to_string())
.or_insert_with(|| Arc::new(self.create_cache()))
.clone();

Ok(cache)
}
}

#[cfg(test)]
impl AppState {
pub fn new_with_cache_for_test(
http_client: Box<dyn Client>,
embedding_service: Box<dyn EmbeddingService>,
cache: Box<dyn Cache<Vec<u8>>>,
) -> Self {
use crate::utils::header_utils::DEFAULT_NAMESPACE;
use std::collections::HashMap;
let mut caches = HashMap::new();
caches.insert(DEFAULT_NAMESPACE.to_string(), Arc::new(cache));

Self {
http_client,
embedding_service,
cache,
caches: RwLock::new(caches),
similarity_threshold: 0.9,
eviction_policy: EvictionPolicy::EntryLimit(100),
}
}
}
2 changes: 1 addition & 1 deletion src/cache/cache_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::cache::response_store::ResponseStore;
use crate::metrics::metrics::CACHE_SIZE;
use tracing::{debug, info};

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub enum EvictionPolicy {
EntryLimit(usize),
MemoryLimitMb(usize), // Could also implement a "combined" of both limits
Expand Down
20 changes: 10 additions & 10 deletions src/embedding/fastembed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@ impl FastEmbedService {
model_name: EmbeddingModel::AllMiniLML6V2,
}
}

pub fn get_dimensionality(&self) -> u32 {
match &self.model_name {
EmbeddingModel::AllMiniLML6V2 => 384,
_ => panic!(
"{}",
EmbeddingError::SetupError(String::from("Embedding model with unknown size",))
),
}
}
}

impl EmbeddingService for FastEmbedService {
Expand All @@ -43,4 +33,14 @@ impl EmbeddingService for FastEmbedService {

Ok(embeddings.into_iter().next().unwrap())
}

fn get_dimensionality(&self) -> u32 {
match &self.model_name {
EmbeddingModel::AllMiniLML6V2 => 384,
_ => panic!(
"{}",
EmbeddingError::SetupError(String::from("Embedding model with unknown size",))
),
}
}
}
1 change: 1 addition & 0 deletions src/embedding/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ use crate::embedding::error::EmbeddingError;
#[automock]
pub trait EmbeddingService: Send + Sync {
fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
fn get_dimensionality(&self) -> u32;
}
Loading
Loading