Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions crates/openfang-kernel/src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,8 @@ impl OpenFangKernel {
if let Some(ref provider) = config.memory.embedding_provider {
// Explicit config takes priority — use the configured embedding model
let api_key_env = config.memory.embedding_api_key_env.as_deref().unwrap_or("");
match create_embedding_driver(provider, configured_model, api_key_env) {
let url = config.provider_urls.get(provider).map(|s| s.as_str());
match create_embedding_driver(provider, configured_model, api_key_env, url) {
Ok(d) => {
info!(provider = %provider, model = %configured_model, "Embedding driver configured from memory config");
Some(Arc::from(d))
Expand All @@ -781,7 +782,8 @@ impl OpenFangKernel {
} else {
configured_model.as_str()
};
match create_embedding_driver("openai", model, "OPENAI_API_KEY") {
let url = config.provider_urls.get("openai").map(|s| s.as_str());
match create_embedding_driver("openai", model, "OPENAI_API_KEY", url) {
Ok(d) => {
info!("Embedding driver auto-detected: OpenAI");
Some(Arc::from(d))
Expand All @@ -798,7 +800,8 @@ impl OpenFangKernel {
} else {
configured_model.as_str()
};
match create_embedding_driver("ollama", model, "") {
let url = config.provider_urls.get("ollama").map(|s| s.as_str());
match create_embedding_driver("ollama", model, "", url) {
Ok(d) => {
info!("Embedding driver auto-detected: Ollama (local)");
Some(Arc::from(d))
Expand Down
69 changes: 56 additions & 13 deletions crates/openfang-runtime/src/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,25 +179,30 @@ pub fn create_embedding_driver(
provider: &str,
model: &str,
api_key_env: &str,
base_url_override: Option<&str>,
) -> Result<Box<dyn EmbeddingDriver + Send + Sync>, EmbeddingError> {
let api_key = if api_key_env.is_empty() {
String::new()
} else {
std::env::var(api_key_env).unwrap_or_default()
};

let base_url = match provider {
"openai" => OPENAI_BASE_URL.to_string(),
"groq" => GROQ_BASE_URL.to_string(),
"together" => TOGETHER_BASE_URL.to_string(),
"fireworks" => FIREWORKS_BASE_URL.to_string(),
"mistral" => MISTRAL_BASE_URL.to_string(),
"ollama" => OLLAMA_BASE_URL.to_string(),
"vllm" => VLLM_BASE_URL.to_string(),
"lmstudio" => LMSTUDIO_BASE_URL.to_string(),
other => {
warn!("Unknown embedding provider '{other}', using OpenAI-compatible format");
format!("https://{other}/v1")
let base_url = if let Some(url) = base_url_override {
url.to_string()
} else {
match provider {
"openai" => OPENAI_BASE_URL.to_string(),
"groq" => GROQ_BASE_URL.to_string(),
"together" => TOGETHER_BASE_URL.to_string(),
"fireworks" => FIREWORKS_BASE_URL.to_string(),
"mistral" => MISTRAL_BASE_URL.to_string(),
"ollama" => OLLAMA_BASE_URL.to_string(),
"vllm" => VLLM_BASE_URL.to_string(),
"lmstudio" => LMSTUDIO_BASE_URL.to_string(),
other => {
warn!("Unknown embedding provider '{other}', using OpenAI-compatible format");
format!("https://{other}/v1")
}
}
};

Expand Down Expand Up @@ -351,8 +356,46 @@ mod tests {
#[test]
fn test_create_embedding_driver_ollama() {
// Should succeed even without API key (ollama is local)
let driver = create_embedding_driver("ollama", "all-MiniLM-L6-v2", "");
let driver = create_embedding_driver("ollama", "all-MiniLM-L6-v2", "", None);
assert!(driver.is_ok());
assert_eq!(driver.unwrap().dimensions(), 384);
}

#[test]
fn test_create_embedding_driver_with_url_override() {
let driver = create_embedding_driver(
"ollama",
"nomic-embed-text",
"",
Some("http://192.168.1.100:11434/v1"),
);
assert!(driver.is_ok());

let driver = create_embedding_driver(
"ollama",
"all-MiniLM-L6-v2",
"",
Some("https://ollama.remote.com/v1"),
);
assert!(driver.is_ok());
assert_eq!(driver.unwrap().dimensions(), 384);
}

#[test]
fn test_create_embedding_driver_openai_with_url_override() {
// Test that URL override works for OpenAI provider
let driver = create_embedding_driver(
"openai",
"text-embedding-3-small",
"OPENAI_API_KEY",
Some("https://custom-openai-compatible.example.com/v1"),
);
assert!(driver.is_ok());
}

#[test]
fn test_create_embedding_driver_fallback_to_default_url() {
let driver = create_embedding_driver("ollama", "nomic-embed-text", "", None);
assert!(driver.is_ok());
}
}