diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 04e29c1b..9c03756f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,7 @@ jobs: miner: ${{ steps.changes.outputs.miner }} basilica-cli: ${{ steps.changes.outputs.basilica-cli }} basilica-sdk-python: ${{ steps.changes.outputs.basilica-sdk-python }} + basilica-sdk-rust: ${{ steps.changes.outputs.basilica-sdk-rust }} workspace: ${{ steps.changes.outputs.workspace }} steps: - uses: actions/checkout@v4 @@ -71,6 +72,12 @@ jobs: - 'crates/basilica-common/**' - 'Cargo.toml' - 'Cargo.lock' + basilica-sdk-rust: + - 'crates/basilica-sdk-rust/**' + - 'crates/basilica-sdk/**' + - 'crates/basilica-common/**' + - 'Cargo.toml' + - 'Cargo.lock' workspace: - 'Cargo.toml' - 'Cargo.lock' @@ -143,6 +150,7 @@ jobs: cargo clippy -p basilica-validator --all-targets --all-features -- -D warnings cargo clippy -p basilica-cli --all-targets --all-features -- -D warnings cargo clippy -p basilica-sdk --all-targets --all-features -- -D warnings + cargo clippy -p basilica-sdk-rust --all-targets --all-features -- -D warnings # Build and test validator build-validator: @@ -373,6 +381,49 @@ jobs: source .venv/bin/activate python -c 'import basilica; print(f"SDK imported successfully. 
API URL: {basilica.DEFAULT_API_URL}")' + # Build and test Rust SDK + test-rust-sdk: + runs-on: blacksmith-32vcpu-ubuntu-2404 + needs: changes + if: needs.changes.outputs.basilica-sdk-rust == 'true' || needs.changes.outputs.workspace == 'true' + strategy: + matrix: + rust-version: [stable] + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust-version }} + components: rustfmt, clippy + - uses: Swatinem/rust-cache@v2 + with: + shared-key: "shared-cache" + save-if: ${{ github.ref == 'refs/heads/main' }} + - name: Install protoc + uses: arduino/setup-protoc@v3 + with: + version: "25.x" + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Cache and install system dependencies + uses: awalsh128/cache-apt-pkgs-action@latest + with: + packages: pkg-config libssl-dev xxd mold clang + version: 1.0 + - name: Install cargo-nextest + uses: taiki-e/install-action@v2 + with: + tool: nextest + - name: Check formatting + run: cargo fmt -p basilica-sdk-rust -- --check + - name: Run clippy on Rust SDK + run: cargo clippy -p basilica-sdk-rust --all-targets --all-features -- -D warnings + - name: Build Rust SDK + run: cargo build -p basilica-sdk-rust --all-features + - name: Run tests + run: cargo nextest run -p basilica-sdk-rust --all-features --no-fail-fast + - name: Run doc tests + run: cargo test -p basilica-sdk-rust --doc + # Final status check ci-success: runs-on: blacksmith-32vcpu-ubuntu-2404 @@ -383,6 +434,7 @@ jobs: - build-miner - build-cli - test-python-sdk + - test-rust-sdk if: always() steps: - name: Check if all jobs succeeded @@ -392,7 +444,8 @@ jobs: ("${{ needs.build-validator.result }}" == "success" || "${{ needs.build-validator.result }}" == "skipped") && \ ("${{ needs.build-miner.result }}" == "success" || "${{ needs.build-miner.result }}" == "skipped") && \ ("${{ needs.build-cli.result }}" == "success" || "${{ needs.build-cli.result }}" == "skipped") && \ - ("${{ needs.test-python-sdk.result }}" == 
"success" || "${{ needs.test-python-sdk.result }}" == "skipped") ]]; then + ("${{ needs.test-python-sdk.result }}" == "success" || "${{ needs.test-python-sdk.result }}" == "skipped") && \ + ("${{ needs.test-rust-sdk.result }}" == "success" || "${{ needs.test-rust-sdk.result }}" == "skipped") ]]; then echo "All CI checks passed!" exit 0 else diff --git a/Cargo.lock b/Cargo.lock index 796dc08e..89d3ae6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1621,6 +1621,29 @@ dependencies = [ "tokio", ] +[[package]] +name = "basilica-sdk-rust" +version = "0.17.0" +dependencies = [ + "base64 0.21.7", + "basilica-common", + "basilica-sdk", + "chrono", + "once_cell", + "regex", + "reqwest 0.11.27", + "serde", + "serde_json", + "shellexpand", + "tempfile", + "thiserror 1.0.69", + "tokio", + "tracing", + "url", + "uuid", + "wiremock", +] + [[package]] name = "basilica-validator" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 048d4469..6ba17fcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "crates/basilica-validator", "crates/basilica-cli", "crates/basilica-sdk", + "crates/basilica-sdk-rust", "crates/basilica-sdk-python", "crates/collateral-contract", ] diff --git a/crates/basilica-sdk-rust/Cargo.toml b/crates/basilica-sdk-rust/Cargo.toml new file mode 100644 index 00000000..2cf016fd --- /dev/null +++ b/crates/basilica-sdk-rust/Cargo.toml @@ -0,0 +1,46 @@ +[package] +name = "basilica-sdk-rust" +version = "0.17.0" +edition = "2021" +authors = ["Basilica Team"] +description = "High-level Rust SDK for the Basilica GPU cloud platform" +license = "MIT OR Apache-2.0" +readme = "README.md" + +[dependencies] +# Internal dependencies - reuse the low-level SDK +basilica-sdk = { path = "../basilica-sdk" } +basilica-common = { path = "../basilica-common" } + +# Async runtime +tokio = { workspace = true } + +# Serialization +serde = { workspace = true } +serde_json = { workspace = true } + +# Error handling +thiserror = { workspace = true } + +# Logging 
+tracing = { workspace = true } + +# Utilities +regex = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +once_cell = { workspace = true } +shellexpand = { workspace = true } +url = { workspace = true } +base64 = { workspace = true } + +# HTTP client for DNS/HTTP verification +reqwest = { workspace = true } + +[dev-dependencies] +wiremock = { workspace = true } +tokio = { workspace = true, features = ["full", "test-util"] } +tempfile = { workspace = true } + +[features] +default = [] diff --git a/crates/basilica-sdk-rust/src/client.rs b/crates/basilica-sdk-rust/src/client.rs new file mode 100644 index 00000000..c28594e3 --- /dev/null +++ b/crates/basilica-sdk-rust/src/client.rs @@ -0,0 +1,1134 @@ +//! High-level client for the Basilica GPU cloud platform. + +use crate::{ + deployment::Deployment, + error::{BasilicaError, DeploymentError, Result}, + model_size::estimate_gpu_requirements, +}; +use basilica_sdk::{ + BalanceResponse, CpuOffering, CreateDeploymentRequest, GpuOffering, GpuRequirementsSpec, + HealthCheckConfig, HealthCheckResponse, ListAvailableNodesQuery, ProbeConfig, + ResourceRequirements, SshKeyResponse, StorageBackend, StorageSpec, TopologySpreadConfig, + UsageHistoryResponse, +}; +use basilica_sdk::{BasilicaClient as LowLevelClient, ClientBuilder as LowLevelClientBuilder}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use uuid::Uuid; + +// Re-export types from low-level SDK for convenience +pub use basilica_sdk::types::{ + AvailableNode, ListSecureCloudRentalsResponse, PersistentStorageSpec, + SecureCloudRentalResponse, StartSecureCloudRentalRequest, StopSecureCloudRentalResponse, +}; + +/// Default Python image for source deployments. +pub const DEFAULT_PYTHON_IMAGE: &str = "python:3.11-slim"; + +/// Default API URL. +pub use basilica_sdk::client::DEFAULT_API_URL; + +/// Build default health check config for inference servers. 
+fn build_inference_health_check(port: u16) -> HealthCheckConfig { + HealthCheckConfig { + liveness: Some(ProbeConfig { + path: "/health".to_string(), + port: Some(port), + initial_delay_seconds: 60, + period_seconds: 30, + timeout_seconds: 10, + failure_threshold: 3, + }), + readiness: Some(ProbeConfig { + path: "/health".to_string(), + port: Some(port), + initial_delay_seconds: 30, + period_seconds: 10, + timeout_seconds: 5, + failure_threshold: 3, + }), + startup: Some(ProbeConfig { + path: "/health".to_string(), + port: Some(port), + initial_delay_seconds: 0, + period_seconds: 10, + timeout_seconds: 5, + failure_threshold: 60, + }), + } +} + +/// High-level client for deploying and managing applications on Basilica. +/// +/// This client wraps the low-level `basilica_sdk::BasilicaClient` and provides +/// a high-level, ergonomic API for common operations. +/// +/// # Example +/// +/// ```no_run +/// use basilica_sdk_rust::{BasilicaClient, VllmConfig}; +/// +/// # async fn example() -> Result<(), Box> { +/// let client = BasilicaClient::builder() +/// .with_api_key("your-api-key") +/// .build()?; +/// +/// // Deploy a vLLM inference server +/// let deployment = client.deploy_vllm(VllmConfig { +/// model: "meta-llama/Llama-3.1-8B-Instruct".to_string(), +/// ..Default::default() +/// }).await?; +/// +/// println!("Deployment ready at: {}", deployment.url()); +/// # Ok(()) +/// # } +/// ``` +pub struct BasilicaClient { + inner: Arc, + base_url: String, +} + +impl BasilicaClient { + /// Create a new client builder. + /// + /// # Example + /// + /// ```no_run + /// use basilica_sdk_rust::BasilicaClient; + /// + /// let client = BasilicaClient::builder() + /// .with_api_key("your-api-key") + /// .build() + /// .expect("Failed to create client"); + /// ``` + pub fn builder() -> BasilicaClientBuilder { + BasilicaClientBuilder::default() + } + + /// Get the base URL. 
+ pub fn base_url(&self) -> &str { + &self.base_url + } + + // ========================================================================= + // High-Level Deployment API + // ========================================================================= + + /// Deploy an application to Basilica. + /// + /// This is the recommended high-level method for deploying applications. + /// It creates the deployment, waits for it to be ready, and returns a + /// Deployment object with convenient methods. + /// + /// # Arguments + /// + /// * `config` - Deployment configuration. + /// + /// # Returns + /// + /// A `Deployment` object representing the deployed application. + /// + /// # Errors + /// + /// Returns an error if the deployment fails or times out. + pub async fn deploy(&self, config: DeployConfig) -> Result { + let request = self.build_deploy_request(&config)?; + + let response = self.inner.create_deployment(request).await?; + + let mut deployment = Deployment::from_response(Arc::clone(&self.inner), response); + + deployment.wait_until_ready(config.timeout).await?; + deployment.refresh().await?; + + Ok(deployment) + } + + /// Deploy a vLLM inference server. + /// + /// This method deploys a vLLM server with the specified model and configuration. + /// It automatically estimates GPU requirements based on the model size. + /// + /// # Arguments + /// + /// * `config` - vLLM server configuration. + /// + /// # Returns + /// + /// A `Deployment` object representing the deployed vLLM server. + /// + /// # Errors + /// + /// Returns an error if the deployment fails or times out. 
+ /// + /// # Example + /// + /// ```no_run + /// use basilica_sdk_rust::{BasilicaClient, VllmConfig}; + /// + /// # async fn example() -> Result<(), Box> { + /// let client = BasilicaClient::builder() + /// .with_api_key("key") + /// .build()?; + /// + /// let deployment = client.deploy_vllm(VllmConfig { + /// model: "Qwen/Qwen3-0.6B".to_string(), + /// trust_remote_code: true, + /// ..Default::default() + /// }).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn deploy_vllm(&self, config: VllmConfig) -> Result { + let reqs = estimate_gpu_requirements(&config.model); + let gpu_count = config.gpu_count.unwrap_or(reqs.gpu_count); + + let name = config.name.unwrap_or_else(|| { + let model_part = config + .model + .split('/') + .next_back() + .unwrap_or(&config.model) + .to_lowercase(); + let model_part: String = model_part + .chars() + .map(|c| if c.is_alphanumeric() { c } else { '-' }) + .take(40) + .collect(); + let model_part = model_part.trim_matches('-'); + format!("vllm-{}-{}", model_part, &Uuid::new_v4().to_string()[..8]) + }); + + let mut args = vec![ + "serve".to_string(), + config.model.clone(), + "--host".to_string(), + "0.0.0.0".to_string(), + "--port".to_string(), + "8000".to_string(), + ]; + + if let Some(tp) = config.tensor_parallel_size { + args.extend(["--tensor-parallel-size".to_string(), tp.to_string()]); + } + if let Some(max_len) = config.max_model_len { + args.extend(["--max-model-len".to_string(), max_len.to_string()]); + } + if let Some(ref dtype) = config.dtype { + args.extend(["--dtype".to_string(), dtype.clone()]); + } + if let Some(ref quant) = config.quantization { + args.extend(["--quantization".to_string(), quant.clone()]); + } + if let Some(ref served_name) = config.served_model_name { + args.extend(["--served-model-name".to_string(), served_name.clone()]); + } + if let Some(ref api_key) = config.api_key { + args.extend(["--api-key".to_string(), api_key.clone()]); + } + if let Some(gpu_util) = config.gpu_memory_utilization { + 
args.extend(["--gpu-memory-utilization".to_string(), gpu_util.to_string()]); + } + if config.enforce_eager { + args.push("--enforce-eager".to_string()); + } + if config.trust_remote_code { + args.push("--trust-remote-code".to_string()); + } + + let storage_spec = if config.storage { + Some(StorageSpec { + persistent: Some(PersistentStorageSpec { + enabled: true, + backend: StorageBackend::R2, + bucket: String::new(), + region: None, + endpoint: None, + credentials_secret: Some("basilica-r2-credentials".to_string()), + sync_interval_ms: 1000, + cache_size_mb: 4096, + mount_path: "/root/.cache".to_string(), + }), + }) + } else { + None + }; + + let gpu_spec = Some(GpuRequirementsSpec { + count: gpu_count, + model: config.gpu_models.clone().unwrap_or_default(), + min_cuda_version: None, + min_gpu_memory_gb: Some(reqs.memory_gb), + }); + + let resources = Some(ResourceRequirements { + cpu: "4".to_string(), + memory: config.memory.clone(), + cpu_request: None, + memory_request: None, + gpus: gpu_spec, + }); + + let health_check = config + .health_check + .clone() + .or_else(|| Some(build_inference_health_check(8000))); + + let request = CreateDeploymentRequest { + instance_name: name.clone(), + image: "vllm/vllm-openai:latest".to_string(), + replicas: 1, + port: 8000, + command: Some(vec!["vllm".to_string()]), + args: Some(args), + env: config.env.clone(), + resources, + ttl_seconds: config.ttl_seconds, + public: true, + storage: storage_spec, + health_check, + enable_billing: true, + queue_name: None, + suspended: false, + priority: None, + topology_spread: None, + }; + + let response = self.inner.create_deployment(request).await?; + let mut deployment = Deployment::from_response(Arc::clone(&self.inner), response); + + deployment.wait_until_ready(config.timeout).await?; + deployment.refresh().await?; + + Ok(deployment) + } + + /// Deploy an SGLang inference server. + /// + /// This method deploys an SGLang server with the specified model and configuration. 
+ /// It automatically estimates GPU requirements based on the model size. + /// + /// # Arguments + /// + /// * `config` - SGLang server configuration. + /// + /// # Returns + /// + /// A `Deployment` object representing the deployed SGLang server. + /// + /// # Errors + /// + /// Returns an error if the deployment fails or times out. + pub async fn deploy_sglang(&self, config: SglangConfig) -> Result { + let reqs = estimate_gpu_requirements(&config.model); + let gpu_count = config.gpu_count.unwrap_or(reqs.gpu_count); + + let name = config.name.unwrap_or_else(|| { + let model_part = config + .model + .split('/') + .next_back() + .unwrap_or(&config.model) + .to_lowercase(); + let model_part: String = model_part + .chars() + .map(|c| if c.is_alphanumeric() { c } else { '-' }) + .take(40) + .collect(); + let model_part = model_part.trim_matches('-'); + format!("sglang-{}-{}", model_part, &Uuid::new_v4().to_string()[..8]) + }); + + let mut args = vec![ + "-m".to_string(), + "sglang.launch_server".to_string(), + "--model-path".to_string(), + config.model.clone(), + "--host".to_string(), + "0.0.0.0".to_string(), + "--port".to_string(), + "30000".to_string(), + ]; + + if let Some(tp) = config.tensor_parallel_size { + args.extend(["--tp".to_string(), tp.to_string()]); + } + if let Some(ctx_len) = config.context_length { + args.extend(["--context-length".to_string(), ctx_len.to_string()]); + } + if let Some(ref quant) = config.quantization { + args.extend(["--quantization".to_string(), quant.clone()]); + } + if let Some(mem_frac) = config.mem_fraction_static { + args.extend(["--mem-fraction-static".to_string(), mem_frac.to_string()]); + } + if config.trust_remote_code { + args.push("--trust-remote-code".to_string()); + } + + let storage_spec = if config.storage { + Some(StorageSpec { + persistent: Some(PersistentStorageSpec { + enabled: true, + backend: StorageBackend::R2, + bucket: String::new(), + region: None, + endpoint: None, + credentials_secret: 
Some("basilica-r2-credentials".to_string()), + sync_interval_ms: 1000, + cache_size_mb: 4096, + mount_path: "/root/.cache".to_string(), + }), + }) + } else { + None + }; + + let gpu_spec = Some(GpuRequirementsSpec { + count: gpu_count, + model: config.gpu_models.clone().unwrap_or_default(), + min_cuda_version: None, + min_gpu_memory_gb: Some(reqs.memory_gb), + }); + + let resources = Some(ResourceRequirements { + cpu: "4".to_string(), + memory: config.memory.clone(), + cpu_request: None, + memory_request: None, + gpus: gpu_spec, + }); + + let health_check = config + .health_check + .clone() + .or_else(|| Some(build_inference_health_check(30000))); + + let request = CreateDeploymentRequest { + instance_name: name.clone(), + image: "lmsysorg/sglang:latest".to_string(), + replicas: 1, + port: 30000, + command: Some(vec!["python".to_string()]), + args: Some(args), + env: config.env.clone(), + resources, + ttl_seconds: config.ttl_seconds, + public: true, + storage: storage_spec, + health_check, + enable_billing: true, + queue_name: None, + suspended: false, + priority: None, + topology_spread: None, + }; + + let response = self.inner.create_deployment(request).await?; + let mut deployment = Deployment::from_response(Arc::clone(&self.inner), response); + + deployment.wait_until_ready(config.timeout).await?; + deployment.refresh().await?; + + Ok(deployment) + } + + /// Get an existing deployment by name. + /// + /// # Arguments + /// + /// * `name` - The deployment instance name. + /// + /// # Returns + /// + /// A `Deployment` object if found. + /// + /// # Errors + /// + /// Returns `DeploymentError::NotFound` if the deployment doesn't exist. + pub async fn get(&self, name: &str) -> Result { + let response = self.inner.get_deployment(name).await.map_err(|e| { + if matches!(e, basilica_sdk::ApiError::NotFound { .. 
}) { + BasilicaError::Deployment(DeploymentError::NotFound { + instance_name: name.to_string(), + }) + } else { + BasilicaError::Api(e) + } + })?; + + Ok(Deployment::from_response(Arc::clone(&self.inner), response)) + } + + /// List all deployments. + /// + /// # Returns + /// + /// A vector of all deployments for the authenticated user. + /// + /// # Errors + /// + /// Returns an error if the API request fails. + pub async fn list(&self) -> Result> { + let response = self.inner.list_deployments().await?; + let mut deployments = Vec::new(); + + for summary in response.deployments { + match self.inner.get_deployment(&summary.instance_name).await { + Ok(full_response) => { + deployments.push(Deployment::from_response( + Arc::clone(&self.inner), + full_response, + )); + } + Err(basilica_sdk::ApiError::NotFound { .. }) => continue, + Err(e) => { + tracing::warn!( + "Failed to fetch deployment '{}': {}", + summary.instance_name, + e + ); + } + } + } + + Ok(deployments) + } + + // ========================================================================= + // Low-Level API Methods + // ========================================================================= + + /// Check API health status. + /// + /// # Returns + /// + /// Health check response with API status information. + pub async fn health_check(&self) -> Result { + Ok(self.inner.health_check().await?) + } + + /// List available compute nodes. + /// + /// # Arguments + /// + /// * `query` - Optional query parameters for filtering nodes. + /// + /// # Returns + /// + /// A vector of available compute nodes. + pub async fn list_nodes( + &self, + query: Option, + ) -> Result> { + let response = self.inner.list_available_nodes(query).await?; + Ok(response.available_nodes) + } + + /// Get account balance. + /// + /// # Returns + /// + /// The current account balance. + pub async fn get_balance(&self) -> Result { + Ok(self.inner.get_balance().await?) + } + + /// Get usage history. 
+ /// + /// # Arguments + /// + /// * `limit` - Maximum number of records to return. + /// * `offset` - Number of records to skip. + /// + /// # Returns + /// + /// Usage history response with rental usage records. + pub async fn list_usage_history( + &self, + limit: Option, + offset: Option, + ) -> Result { + Ok(self.inner.list_usage_history(limit, offset).await?) + } + + // ========================================================================= + // SSH Key Management + // ========================================================================= + + /// Register an SSH key for secure cloud rentals. + /// + /// # Arguments + /// + /// * `name` - A name for the SSH key. + /// * `public_key` - The SSH public key content. + /// + /// # Returns + /// + /// The registered SSH key information. + pub async fn register_ssh_key(&self, name: &str, public_key: &str) -> Result { + Ok(self.inner.register_ssh_key(name, public_key).await?) + } + + /// Get the user's registered SSH key. + /// + /// # Returns + /// + /// The SSH key if registered, or None. + pub async fn get_ssh_key(&self) -> Result> { + Ok(self.inner.get_user_ssh_key().await?) + } + + /// Delete the user's SSH key. + pub async fn delete_ssh_key(&self) -> Result<()> { + Ok(self.inner.delete_ssh_key().await?) + } + + // ========================================================================= + // Secure Cloud GPU Rentals + // ========================================================================= + + /// List available GPU offerings from secure cloud providers. + /// + /// # Returns + /// + /// A vector of available GPU offerings. + pub async fn list_secure_cloud_gpus(&self) -> Result> { + Ok(self.inner.list_secure_cloud_gpus().await?) + } + + /// Start a secure cloud GPU rental. + /// + /// # Arguments + /// + /// * `offering_id` - The GPU offering ID. + /// * `ssh_public_key_id` - Optional SSH key ID. If not provided, uses the registered key. 
+ /// + /// # Returns + /// + /// The rental response with connection details. + /// + /// # Errors + /// + /// Returns an error if no SSH key is registered and none is provided. + pub async fn start_secure_cloud_rental( + &self, + offering_id: &str, + ssh_public_key_id: Option<&str>, + ) -> Result { + let ssh_key_id = match ssh_public_key_id { + Some(id) => id.to_string(), + None => { + let key = self + .get_ssh_key() + .await? + .ok_or_else(|| BasilicaError::Validation { + message: "No SSH key registered. Use register_ssh_key() first.".to_string(), + field: Some("ssh_public_key_id".to_string()), + value: None, + })?; + key.id + } + }; + + let request = StartSecureCloudRentalRequest { + offering_id: offering_id.to_string(), + ssh_public_key_id: ssh_key_id, + }; + + Ok(self.inner.start_secure_cloud_rental(request).await?) + } + + /// Stop a secure cloud GPU rental. + /// + /// # Arguments + /// + /// * `rental_id` - The rental ID to stop. + /// + /// # Returns + /// + /// The stop response with final cost information. + pub async fn stop_secure_cloud_rental( + &self, + rental_id: &str, + ) -> Result { + Ok(self.inner.stop_secure_cloud_rental(rental_id).await?) + } + + /// List secure cloud rentals. + /// + /// # Returns + /// + /// A list of all secure cloud rentals. + pub async fn list_secure_cloud_rentals(&self) -> Result { + Ok(self.inner.list_secure_cloud_rentals().await?) + } + + // ========================================================================= + // CPU Rentals + // ========================================================================= + + /// List CPU-only offerings. + /// + /// # Returns + /// + /// A vector of available CPU offerings. + pub async fn list_cpu_offerings(&self) -> Result> { + Ok(self.inner.list_cpu_offerings().await?) + } + + /// Start a CPU-only rental. + /// + /// # Arguments + /// + /// * `offering_id` - The CPU offering ID. + /// * `ssh_public_key_id` - Optional SSH key ID. If not provided, uses the registered key. 
+ /// + /// # Returns + /// + /// The rental response with connection details. + pub async fn start_cpu_rental( + &self, + offering_id: &str, + ssh_public_key_id: Option<&str>, + ) -> Result { + let ssh_key_id = match ssh_public_key_id { + Some(id) => id.to_string(), + None => { + let key = self + .get_ssh_key() + .await? + .ok_or_else(|| BasilicaError::Validation { + message: "No SSH key registered. Use register_ssh_key() first.".to_string(), + field: Some("ssh_public_key_id".to_string()), + value: None, + })?; + key.id + } + }; + + let request = StartSecureCloudRentalRequest { + offering_id: offering_id.to_string(), + ssh_public_key_id: ssh_key_id, + }; + + Ok(self.inner.start_cpu_rental(request).await?) + } + + /// Stop a CPU rental. + /// + /// # Arguments + /// + /// * `rental_id` - The rental ID to stop. + /// + /// # Returns + /// + /// The stop response with final cost information. + pub async fn stop_cpu_rental(&self, rental_id: &str) -> Result { + Ok(self.inner.stop_cpu_rental(rental_id).await?) 
+ } + + // ========================================================================= + // Private Helpers + // ========================================================================= + + fn build_deploy_request(&self, config: &DeployConfig) -> Result { + let command = if let Some(ref source) = config.source { + let packager = crate::source::SourcePackager::new(source)?; + Some(packager.build_command(config.pip_packages.as_deref())) + } else { + None + }; + + let storage_spec = match &config.storage { + StorageConfig::None => None, + StorageConfig::Enabled => Some(StorageSpec { + persistent: Some(PersistentStorageSpec { + enabled: true, + backend: StorageBackend::R2, + bucket: String::new(), + region: None, + endpoint: None, + credentials_secret: None, + sync_interval_ms: 1000, + cache_size_mb: 1024, + mount_path: "/data".to_string(), + }), + }), + StorageConfig::CustomPath(path) => Some(StorageSpec { + persistent: Some(PersistentStorageSpec { + enabled: true, + backend: StorageBackend::R2, + bucket: String::new(), + region: None, + endpoint: None, + credentials_secret: None, + sync_interval_ms: 1000, + cache_size_mb: 1024, + mount_path: path.clone(), + }), + }), + }; + + let gpu_spec = config.gpu_count.map(|count| GpuRequirementsSpec { + count, + model: config.gpu_models.clone().unwrap_or_default(), + min_cuda_version: config.min_cuda_version.clone(), + min_gpu_memory_gb: config.min_gpu_memory_gb, + }); + + let resources = Some(ResourceRequirements { + cpu: config.cpu.clone(), + memory: config.memory.clone(), + cpu_request: None, + memory_request: None, + gpus: gpu_spec, + }); + + Ok(CreateDeploymentRequest { + instance_name: config.name.clone(), + image: config.image.clone(), + replicas: config.replicas, + port: config.port, + command, + args: None, + env: config.env.clone(), + resources, + ttl_seconds: config.ttl_seconds, + public: config.public, + storage: storage_spec, + health_check: config.health_check.clone(), + enable_billing: true, + queue_name: None, + 
suspended: false, + priority: None, + topology_spread: config.topology_spread.clone(), + }) + } +} + +/// Builder for BasilicaClient. +/// +/// Provides a fluent interface for configuring and building a `BasilicaClient`. +/// +/// # Example +/// +/// ```no_run +/// use basilica_sdk_rust::BasilicaClient; +/// use std::time::Duration; +/// +/// let client = BasilicaClient::builder() +/// .base_url("https://api.basilica.ai") +/// .with_api_key("your-api-key") +/// .timeout(Duration::from_secs(120)) +/// .build() +/// .expect("Failed to create client"); +/// ``` +#[derive(Default)] +pub struct BasilicaClientBuilder { + base_url: Option, + api_key: Option, + access_token: Option, + refresh_token: Option, + use_file_auth: bool, + timeout: Option, +} + +impl BasilicaClientBuilder { + /// Set the base URL. + /// + /// Defaults to `https://api.basilica.ai` if not specified. + pub fn base_url(mut self, url: impl Into) -> Self { + self.base_url = Some(url.into()); + self + } + + /// Set API key authentication. + /// + /// This is the recommended authentication method for programmatic access. + pub fn with_api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + /// Set token-based authentication. + /// + /// Use this method when you have both an access token and refresh token. + pub fn with_tokens( + mut self, + access_token: impl Into, + refresh_token: impl Into, + ) -> Self { + self.access_token = Some(access_token.into()); + self.refresh_token = Some(refresh_token.into()); + self + } + + /// Use file-based authentication. + /// + /// Reads tokens from the standard Basilica CLI token file location. + pub fn with_file_auth(mut self) -> Self { + self.use_file_auth = true; + self + } + + /// Set request timeout. + /// + /// Defaults to 1200 seconds (20 minutes) if not specified. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// Build the client. 
+ /// + /// # Returns + /// + /// A configured `BasilicaClient` instance. + /// + /// # Errors + /// + /// Returns an error if no authentication method was provided. + pub fn build(self) -> Result { + let base_url = self.base_url.unwrap_or_else(|| DEFAULT_API_URL.to_string()); + + let mut builder = LowLevelClientBuilder::new().base_url(&base_url); + + if let Some(api_key) = self.api_key { + builder = builder.with_api_key(&api_key); + } else if let (Some(access), Some(refresh)) = (self.access_token, self.refresh_token) { + builder = builder.with_tokens(access, refresh); + } else if self.use_file_auth { + builder = builder.with_file_auth(); + } else { + return Err(BasilicaError::Authentication { + message: "No authentication method provided. Use with_api_key(), with_tokens(), or with_file_auth()".to_string(), + }); + } + + if let Some(timeout) = self.timeout { + builder = builder.timeout(timeout); + } + + let inner = builder.build().map_err(BasilicaError::Api)?; + + Ok(BasilicaClient { + inner: Arc::new(inner), + base_url, + }) + } +} + +/// Configuration for deploy(). +/// +/// Contains all configuration options for deploying a general application. +#[derive(Debug, Clone)] +pub struct DeployConfig { + /// The deployment instance name (must be unique). + pub name: String, + /// Optional Python source code to deploy. + pub source: Option, + /// Container image to use. + pub image: String, + /// Port to expose. + pub port: u32, + /// Environment variables. + pub env: Option>, + /// CPU resource limit (e.g., "500m", "2"). + pub cpu: String, + /// Memory resource limit (e.g., "512Mi", "2Gi"). + pub memory: String, + /// Storage configuration. + pub storage: StorageConfig, + /// Number of GPUs to request. + pub gpu_count: Option, + /// GPU model names to request (e.g., ["H100", "A100"]). + pub gpu_models: Option>, + /// Minimum CUDA version required. + pub min_cuda_version: Option, + /// Minimum GPU memory in GB. + pub min_gpu_memory_gb: Option, + /// Number of replicas. 
+ pub replicas: u32, + /// Time-to-live in seconds (auto-delete after this time). + pub ttl_seconds: Option, + /// Whether the deployment is publicly accessible. + pub public: bool, + /// Timeout in seconds for waiting until ready. + pub timeout: u64, + /// Pip packages to install (for source deployments). + pub pip_packages: Option>, + /// Topology spreading configuration for replicas. + pub topology_spread: Option, + /// Health check configuration. + pub health_check: Option, +} + +impl Default for DeployConfig { + fn default() -> Self { + Self { + name: String::new(), + source: None, + image: DEFAULT_PYTHON_IMAGE.to_string(), + port: 8000, + env: None, + cpu: "500m".to_string(), + memory: "512Mi".to_string(), + storage: StorageConfig::None, + gpu_count: None, + gpu_models: None, + min_cuda_version: None, + min_gpu_memory_gb: None, + replicas: 1, + ttl_seconds: None, + public: true, + timeout: 300, + pip_packages: None, + topology_spread: None, + health_check: None, + } + } +} + +/// Storage configuration for deployments. +#[derive(Debug, Clone, Default)] +pub enum StorageConfig { + /// No persistent storage. + #[default] + None, + /// Enable storage with default mount path (/data). + Enabled, + /// Enable storage with a custom mount path. + CustomPath(String), +} + +/// Configuration for deploy_vllm(). +/// +/// Contains all configuration options for deploying a vLLM inference server. +#[derive(Debug, Clone)] +pub struct VllmConfig { + /// The model to serve (e.g., "meta-llama/Llama-3.1-8B-Instruct"). + pub model: String, + /// Optional deployment name (auto-generated if not provided). + pub name: Option, + /// Number of GPUs (auto-estimated from model if not provided). + pub gpu_count: Option, + /// GPU model names to request. + pub gpu_models: Option>, + /// Memory resource limit. + pub memory: String, + /// Enable persistent storage for model cache. + pub storage: bool, + /// Tensor parallel size (for multi-GPU). 
+ pub tensor_parallel_size: Option, + /// Maximum model context length. + pub max_model_len: Option, + /// Data type (e.g., "auto", "float16", "bfloat16"). + pub dtype: Option, + /// Quantization method (e.g., "awq", "gptq"). + pub quantization: Option, + /// Model name to serve as (for API compatibility). + pub served_model_name: Option, + /// API key for the vLLM server. + pub api_key: Option, + /// GPU memory utilization fraction (0.0-1.0). + pub gpu_memory_utilization: Option, + /// Disable CUDA graphs for debugging. + pub enforce_eager: bool, + /// Trust remote code from model repository. + pub trust_remote_code: bool, + /// Additional environment variables. + pub env: Option>, + /// Time-to-live in seconds. + pub ttl_seconds: Option, + /// Timeout for waiting until ready. + pub timeout: u64, + /// Custom health check configuration. + pub health_check: Option, +} + +impl Default for VllmConfig { + fn default() -> Self { + Self { + model: "Qwen/Qwen3-0.6B".to_string(), + name: None, + gpu_count: None, + gpu_models: None, + memory: "16Gi".to_string(), + storage: true, + tensor_parallel_size: None, + max_model_len: None, + dtype: None, + quantization: None, + served_model_name: None, + api_key: None, + gpu_memory_utilization: None, + enforce_eager: false, + trust_remote_code: false, + env: None, + ttl_seconds: None, + timeout: 600, + health_check: None, + } + } +} + +/// Configuration for deploy_sglang(). +/// +/// Contains all configuration options for deploying an SGLang inference server. +#[derive(Debug, Clone)] +pub struct SglangConfig { + /// The model to serve (e.g., "Qwen/Qwen2.5-0.5B-Instruct"). + pub model: String, + /// Optional deployment name (auto-generated if not provided). + pub name: Option, + /// Number of GPUs (auto-estimated from model if not provided). + pub gpu_count: Option, + /// GPU model names to request. + pub gpu_models: Option>, + /// Memory resource limit. + pub memory: String, + /// Enable persistent storage for model cache. 
+ pub storage: bool, + /// Tensor parallel size (for multi-GPU). + pub tensor_parallel_size: Option, + /// Context length for the model. + pub context_length: Option, + /// Quantization method. + pub quantization: Option, + /// Memory fraction for static allocation. + pub mem_fraction_static: Option, + /// Trust remote code from model repository. + pub trust_remote_code: bool, + /// Additional environment variables. + pub env: Option>, + /// Time-to-live in seconds. + pub ttl_seconds: Option, + /// Timeout for waiting until ready. + pub timeout: u64, + /// Custom health check configuration. + pub health_check: Option, +} + +impl Default for SglangConfig { + fn default() -> Self { + Self { + model: "Qwen/Qwen2.5-0.5B-Instruct".to_string(), + name: None, + gpu_count: None, + gpu_models: None, + memory: "16Gi".to_string(), + storage: true, + tensor_parallel_size: None, + context_length: None, + quantization: None, + mem_fraction_static: None, + trust_remote_code: false, + env: None, + ttl_seconds: None, + timeout: 600, + health_check: None, + } + } +} diff --git a/crates/basilica-sdk-rust/src/deployment.rs b/crates/basilica-sdk-rust/src/deployment.rs new file mode 100644 index 00000000..20abdbdf --- /dev/null +++ b/crates/basilica-sdk-rust/src/deployment.rs @@ -0,0 +1,508 @@ +//! Deployment facade for managing deployments. +//! +//! This module provides the `Deployment` struct which represents a deployed +//! application on Basilica and provides convenient methods for managing it. + +use crate::error::{BasilicaError, DeploymentError, Result}; +use basilica_sdk::{BasilicaClient as LowLevelClient, DeploymentResponse}; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; +use url::Url; + +/// HTTP readiness verification timeout. +const HTTP_READY_TIMEOUT: Duration = Duration::from_secs(10); + +/// Progress information for long-running operations. +#[derive(Debug, Clone)] +pub struct ProgressInfo { + /// Number of bytes synced (for storage sync). 
+ pub bytes_synced: Option, + /// Total bytes to sync. + pub bytes_total: Option, + /// Progress percentage (0-100). + pub percentage: Option, + /// Current step description. + pub current_step: String, + /// Timestamp when the operation started. + pub started_at: String, + /// Elapsed time in seconds. + pub elapsed_seconds: u64, +} + +/// Current status of a deployment. +#[derive(Debug, Clone)] +pub struct DeploymentStatus { + /// The deployment state (e.g., "Pending", "Active", "Running", "Failed"). + pub state: String, + /// Number of replicas that are ready. + pub replicas_ready: u32, + /// Number of replicas that are desired. + pub replicas_desired: u32, + /// Optional status message. + pub message: Option, + /// Current deployment phase. + pub phase: Option, + /// Progress information for long-running operations. + pub progress: Option, +} + +impl DeploymentStatus { + /// Check if the deployment is fully ready. + /// + /// A deployment is ready when: + /// - State is "Active" or "Running" + /// - All desired replicas are ready + /// - At least one replica is running + pub fn is_ready(&self) -> bool { + matches!(self.state.as_str(), "Active" | "Running") + && self.replicas_ready == self.replicas_desired + && self.replicas_ready > 0 + } + + /// Check if the deployment has failed. + pub fn is_failed(&self) -> bool { + self.state == "Failed" || self.phase.as_deref() == Some("failed") + } + + /// Check if the deployment is still starting. + pub fn is_pending(&self) -> bool { + matches!(self.state.as_str(), "Pending" | "Provisioning") + || matches!( + self.phase.as_deref(), + Some("pending") + | Some("scheduling") + | Some("pulling") + | Some("initializing") + | Some("storage_sync") + | Some("starting") + | Some("health_check") + ) + } +} + +/// A facade for managing a Basilica deployment. +/// +/// This struct provides a convenient interface for interacting with a deployed +/// application, including checking status, getting logs, and deleting. 
+/// +/// # Example +/// +/// ```no_run +/// use basilica_sdk_rust::BasilicaClient; +/// +/// # async fn example() -> Result<(), Box> { +/// let client = BasilicaClient::builder() +/// .with_api_key("key") +/// .build()?; +/// +/// let deployment = client.get("my-deployment").await?; +/// println!("URL: {}", deployment.url()); +/// +/// let status = deployment.status().await?; +/// println!("State: {}, Replicas: {}/{}", +/// status.state, status.replicas_ready, status.replicas_desired); +/// +/// let logs = deployment.logs(Some(100)).await?; +/// println!("Logs:\n{}", logs); +/// +/// deployment.delete().await?; +/// # Ok(()) +/// # } +/// ``` +pub struct Deployment { + client: Arc, + name: String, + url: String, + namespace: String, + user_id: String, + state: String, + created_at: String, + updated_at: Option, + replicas_ready: u32, + replicas_desired: u32, +} + +impl Deployment { + /// Create from API response. + pub(crate) fn from_response(client: Arc, response: DeploymentResponse) -> Self { + Self { + client, + name: response.instance_name, + url: response.url, + namespace: response.namespace, + user_id: response.user_id, + state: response.state, + created_at: response.created_at, + updated_at: response.updated_at, + replicas_ready: response.replicas.ready, + replicas_desired: response.replicas.desired, + } + } + + /// The deployment instance name. + pub fn name(&self) -> &str { + &self.name + } + + /// The public URL for accessing the deployment. + pub fn url(&self) -> &str { + &self.url + } + + /// The Kubernetes namespace. + pub fn namespace(&self) -> &str { + &self.namespace + } + + /// The owner's user ID. + pub fn user_id(&self) -> &str { + &self.user_id + } + + /// The last known deployment state. + pub fn state(&self) -> &str { + &self.state + } + + /// Creation timestamp. + pub fn created_at(&self) -> &str { + &self.created_at + } + + /// Last update timestamp, if available. 
+ pub fn updated_at(&self) -> Option<&str> { + self.updated_at.as_deref() + } + + /// Number of ready replicas (last known). + pub fn replicas_ready(&self) -> u32 { + self.replicas_ready + } + + /// Number of desired replicas (last known). + pub fn replicas_desired(&self) -> u32 { + self.replicas_desired + } + + /// Get the current deployment status from the API. + /// + /// This method fetches the latest status information from the API. + /// + /// # Returns + /// + /// The current deployment status. + pub async fn status(&self) -> Result { + let response = self.client.get_deployment(&self.name).await?; + + let progress = response.progress.map(|p| ProgressInfo { + bytes_synced: None, + bytes_total: None, + percentage: p.percentage, + current_step: p.current_step, + started_at: String::new(), + elapsed_seconds: p.elapsed_seconds, + }); + + Ok(DeploymentStatus { + state: response.state, + replicas_ready: response.replicas.ready, + replicas_desired: response.replicas.desired, + message: response.message, + phase: response.phase, + progress, + }) + } + + /// Get deployment logs. + /// + /// # Arguments + /// + /// * `tail` - Optional number of lines to return from the end. + /// + /// # Returns + /// + /// The deployment logs as a string. + pub async fn logs(&self, tail: Option) -> Result { + let response = self + .client + .get_deployment_logs(&self.name, false, tail) + .await?; + response.text().await.map_err(|e| BasilicaError::Network { + message: format!("Failed to read logs: {}", e), + source: Some(Box::new(e)), + }) + } + + /// Wait for the deployment to become ready. + /// + /// This method polls the deployment status until it becomes ready, + /// fails, or times out. + /// + /// # Arguments + /// + /// * `timeout_secs` - Maximum time to wait in seconds. + /// + /// # Returns + /// + /// The final deployment status when ready. + /// + /// # Errors + /// + /// Returns an error if the deployment fails or times out. 
+ pub async fn wait_until_ready(&mut self, timeout_secs: u64) -> Result<DeploymentStatus> {
+ self.wait_until_ready_with_callback(timeout_secs, None::<fn(DeploymentStatus)>)
+ .await
+ }
+
+ /// Wait for the deployment to become ready with a progress callback.
+ ///
+ /// # Arguments
+ ///
+ /// * `timeout_secs` - Maximum time to wait in seconds.
+ /// * `on_progress` - Optional callback invoked when the deployment phase changes.
+ ///
+ /// # Returns
+ ///
+ /// The final deployment status when ready.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if the deployment fails or times out.
+ pub async fn wait_until_ready_with_callback<F>(
+ &mut self,
+ timeout_secs: u64,
+ on_progress: Option<F>,
+ ) -> Result<DeploymentStatus>
+ where
+ F: Fn(DeploymentStatus),
+ {
+ let poll_interval = Duration::from_secs(5);
+ let timeout = Duration::from_secs(timeout_secs);
+ let start = std::time::Instant::now();
+ let mut last_phase: Option<String> = None;
+ let mut last_status: Option<DeploymentStatus> = None;
+
+ while start.elapsed() < timeout {
+ let status = self.status().await?;
+
+ // Update cached state
+ self.state = status.state.clone();
+ self.replicas_ready = status.replicas_ready;
+ self.replicas_desired = status.replicas_desired;
+
+ // Call progress callback if phase changed
+ if let Some(ref callback) = on_progress {
+ if last_phase.as_deref() != status.phase.as_deref() {
+ callback(status.clone());
+ }
+ }
+
+ last_phase = status.phase.clone();
+ last_status = Some(status.clone());
+
+ if status.is_ready() {
+ // Verify DNS resolution and HTTP endpoint
+ if let Some(hostname) = self.get_hostname() {
+ if !self.is_dns_resolvable(&hostname).await {
+ sleep(poll_interval).await;
+ continue;
+ }
+ if !self.is_http_ready().await {
+ sleep(poll_interval).await;
+ continue;
+ }
+ }
+ return Ok(status);
+ }
+
+ if status.is_failed() {
+ return Err(BasilicaError::Deployment(DeploymentError::Failed {
+ instance_name: self.name.clone(),
+ reason: status.message,
+ }));
+ }
+
+ // Dynamic sleep based on phase
+ let sleep_duration = match
status.phase.as_deref() { + Some("scheduling") | Some("pulling") => Duration::from_secs(10), + Some("storage_sync") => Duration::from_secs(3), + _ => poll_interval, + }; + + sleep(sleep_duration).await; + } + + let last = last_status.unwrap_or_else(|| DeploymentStatus { + state: "Unknown".to_string(), + replicas_ready: 0, + replicas_desired: 1, + message: None, + phase: None, + progress: None, + }); + + Err(BasilicaError::Deployment(DeploymentError::Timeout { + instance_name: self.name.clone(), + timeout_seconds: timeout_secs, + last_state: last.state, + replicas_ready: last.replicas_ready, + replicas_desired: last.replicas_desired, + })) + } + + /// Delete the deployment. + /// + /// This method initiates deletion of the deployment and all associated resources. + pub async fn delete(&self) -> Result<()> { + self.client.delete_deployment(&self.name).await?; + Ok(()) + } + + /// Refresh the deployment data from the API. + /// + /// This method updates the cached deployment information with the latest + /// data from the API. + pub async fn refresh(&mut self) -> Result<()> { + let response = self.client.get_deployment(&self.name).await?; + self.url = response.url; + self.state = response.state; + self.replicas_ready = response.replicas.ready; + self.replicas_desired = response.replicas.desired; + self.updated_at = response.updated_at; + Ok(()) + } + + /// Extract hostname from URL. + fn get_hostname(&self) -> Option { + Url::parse(&self.url) + .ok() + .and_then(|u| u.host_str().map(|s| s.to_string())) + } + + /// Check if hostname resolves. + async fn is_dns_resolvable(&self, hostname: &str) -> bool { + tokio::net::lookup_host(format!("{}:80", hostname)) + .await + .is_ok() + } + + /// Check if HTTP endpoint is responding. 
+ async fn is_http_ready(&self) -> bool { + let client = reqwest::Client::builder() + .timeout(HTTP_READY_TIMEOUT) + .danger_accept_invalid_certs(true) + .build(); + + let client = match client { + Ok(c) => c, + Err(_) => return false, + }; + + match client.head(&self.url).send().await { + Ok(_) => true, + Err(e) if e.is_status() => true, // HTTP error means server is responding + Err(_) => false, + } + } +} + +impl std::fmt::Debug for Deployment { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Deployment") + .field("name", &self.name) + .field("state", &self.state) + .field("url", &self.url) + .field("namespace", &self.namespace) + .field("replicas_ready", &self.replicas_ready) + .field("replicas_desired", &self.replicas_desired) + .finish() + } +} + +impl std::fmt::Display for Deployment { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Deployment '{}' ({}) at {}", + self.name, self.state, self.url + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deployment_status_is_ready() { + let status = DeploymentStatus { + state: "Active".to_string(), + replicas_ready: 1, + replicas_desired: 1, + message: None, + phase: Some("ready".to_string()), + progress: None, + }; + assert!(status.is_ready()); + assert!(!status.is_failed()); + assert!(!status.is_pending()); + } + + #[test] + fn test_deployment_status_is_failed() { + let status = DeploymentStatus { + state: "Failed".to_string(), + replicas_ready: 0, + replicas_desired: 1, + message: Some("Pod crashed".to_string()), + phase: Some("failed".to_string()), + progress: None, + }; + assert!(!status.is_ready()); + assert!(status.is_failed()); + assert!(!status.is_pending()); + } + + #[test] + fn test_deployment_status_is_pending() { + let status = DeploymentStatus { + state: "Pending".to_string(), + replicas_ready: 0, + replicas_desired: 1, + message: None, + phase: Some("scheduling".to_string()), + progress: None, + 
}; + assert!(!status.is_ready()); + assert!(!status.is_failed()); + assert!(status.is_pending()); + } + + #[test] + fn test_deployment_status_partial_ready() { + let status = DeploymentStatus { + state: "Active".to_string(), + replicas_ready: 1, + replicas_desired: 3, + message: None, + phase: None, + progress: None, + }; + // Not fully ready because not all replicas are up + assert!(!status.is_ready()); + } + + #[test] + fn test_deployment_status_running_state() { + let status = DeploymentStatus { + state: "Running".to_string(), + replicas_ready: 2, + replicas_desired: 2, + message: None, + phase: None, + progress: None, + }; + assert!(status.is_ready()); + } +} diff --git a/crates/basilica-sdk-rust/src/error.rs b/crates/basilica-sdk-rust/src/error.rs new file mode 100644 index 00000000..7917832e --- /dev/null +++ b/crates/basilica-sdk-rust/src/error.rs @@ -0,0 +1,360 @@ +//! Error types for the high-level Basilica Rust SDK. +//! +//! This module provides a comprehensive error hierarchy that mirrors the Python SDK's +//! error handling approach, with specific error types for common failure scenarios. + +use thiserror::Error; + +/// Base error for all Basilica SDK errors. +/// +/// This enum represents all possible errors that can occur when using the SDK. +/// It provides detailed error messages and context where available. +#[derive(Debug, Error)] +pub enum BasilicaError { + /// Authentication failed (invalid credentials, expired token, etc.) 
+ #[error("Authentication failed: {message}")]
+ Authentication {
+ /// Description of the authentication failure
+ message: String,
+ },
+
+ /// Authorization denied (insufficient permissions)
+ #[error("Authorization denied: {message}")]
+ Authorization {
+ /// Description of the authorization failure
+ message: String,
+ /// The resource that was being accessed (if available)
+ resource: Option<String>,
+ },
+
+ /// Validation error (invalid input parameters)
+ #[error("Validation error: {message}")]
+ Validation {
+ /// Description of the validation failure
+ message: String,
+ /// The field that failed validation (if applicable)
+ field: Option<String>,
+ /// The invalid value (if applicable)
+ value: Option<String>,
+ },
+
+ /// Deployment-specific error
+ #[error(transparent)]
+ Deployment(#[from] DeploymentError),
+
+ /// Resource unavailable (capacity, quota, etc.)
+ #[error("Resource unavailable: {message}")]
+ Resource {
+ /// Description of the resource issue
+ message: String,
+ /// Type of resource (GPU, CPU, memory, etc.)
+ resource_type: Option<String>,
+ },
+
+ /// Storage operation error
+ #[error("Storage error: {message}")]
+ Storage {
+ /// Description of the storage failure
+ message: String,
+ },
+
+ /// Network-related error
+ #[error("Network error: {message}")]
+ Network {
+ /// Description of the network issue
+ message: String,
+ /// Underlying error source (if available)
+ #[source]
+ source: Option<Box<dyn std::error::Error + Send + Sync>>,
+ },
+
+ /// Rate limit exceeded
+ #[error("Rate limit exceeded")]
+ RateLimit {
+ /// Seconds to wait before retrying (if provided by API)
+ retry_after: Option<u64>,
+ },
+
+ /// Source code/file error
+ #[error("Source error: {message}")]
+ Source {
+ /// Description of the source error
+ message: String,
+ /// Path to the source file (if applicable)
+ source_path: Option<String>,
+ },
+
+ /// Low-level API error from basilica-sdk
+ #[error("API error: {0}")]
+ Api(#[from] basilica_sdk::ApiError),
+}
+
+/// Deployment-specific errors with detailed context.
+///
+/// These errors occur during deployment creation, updates, or status monitoring.
+#[derive(Debug, Error)]
+pub enum DeploymentError {
+ /// Deployment not found
+ #[error("Deployment '{instance_name}' not found")]
+ NotFound {
+ /// Name of the deployment that was not found
+ instance_name: String,
+ },
+
+ /// Deployment timed out waiting to become ready
+ #[error("Deployment '{instance_name}' timed out after {timeout_seconds}s (state: {last_state}, replicas: {replicas_ready}/{replicas_desired})")]
+ Timeout {
+ /// Name of the deployment
+ instance_name: String,
+ /// How long we waited before timing out
+ timeout_seconds: u64,
+ /// Last observed state before timeout
+ last_state: String,
+ /// Number of replicas that were ready
+ replicas_ready: u32,
+ /// Number of replicas that were desired
+ replicas_desired: u32,
+ },
+
+ /// Deployment failed (container crash, image pull error, etc.)
+ #[error("Deployment '{instance_name}' failed: {}", reason.as_deref().unwrap_or("unknown reason"))]
+ Failed {
+ /// Name of the deployment
+ instance_name: String,
+ /// Reason for failure (if known)
+ reason: Option<String>,
+ },
+}
+
+/// Result type alias for SDK operations.
+pub type Result<T> = std::result::Result<T, BasilicaError>;
+
+impl BasilicaError {
+ /// Create an authentication error with a custom message.
+ pub fn authentication(message: impl Into<String>) -> Self {
+ Self::Authentication {
+ message: message.into(),
+ }
+ }
+
+ /// Create an authorization error with optional resource context.
+ pub fn authorization(message: impl Into<String>, resource: Option<String>) -> Self {
+ Self::Authorization {
+ message: message.into(),
+ resource,
+ }
+ }
+
+ /// Create a validation error with optional field and value context.
+ pub fn validation(
+ message: impl Into<String>,
+ field: Option<String>,
+ value: Option<String>,
+ ) -> Self {
+ Self::Validation {
+ message: message.into(),
+ field,
+ value,
+ }
+ }
+
+ /// Create a resource unavailable error.
+ pub fn resource(message: impl Into, resource_type: Option) -> Self { + Self::Resource { + message: message.into(), + resource_type, + } + } + + /// Create a storage error. + pub fn storage(message: impl Into) -> Self { + Self::Storage { + message: message.into(), + } + } + + /// Create a network error with an optional source error. + pub fn network( + message: impl Into, + source: Option>, + ) -> Self { + Self::Network { + message: message.into(), + source, + } + } + + /// Create a rate limit error with optional retry-after duration. + pub fn rate_limit(retry_after: Option) -> Self { + Self::RateLimit { retry_after } + } + + /// Create a source error with optional path context. + pub fn source(message: impl Into, source_path: Option) -> Self { + Self::Source { + message: message.into(), + source_path, + } + } + + /// Check if this error is retryable. + /// + /// Returns `true` for transient errors that may succeed if retried. + pub fn is_retryable(&self) -> bool { + match self { + Self::Network { .. } => true, + Self::RateLimit { .. } => true, + Self::Api(api_err) => api_err.is_retryable(), + Self::Deployment(DeploymentError::Timeout { .. }) => false, + _ => false, + } + } + + /// Check if this is a client error (user mistake vs server issue). + pub fn is_client_error(&self) -> bool { + match self { + Self::Authentication { .. } => true, + Self::Authorization { .. } => true, + Self::Validation { .. } => true, + Self::Source { .. } => true, + Self::Api(api_err) => api_err.is_client_error(), + _ => false, + } + } +} + +impl DeploymentError { + /// Create a not found error for a deployment. + pub fn not_found(instance_name: impl Into) -> Self { + Self::NotFound { + instance_name: instance_name.into(), + } + } + + /// Create a timeout error with detailed context. 
+ pub fn timeout( + instance_name: impl Into, + timeout_seconds: u64, + last_state: impl Into, + replicas_ready: u32, + replicas_desired: u32, + ) -> Self { + Self::Timeout { + instance_name: instance_name.into(), + timeout_seconds, + last_state: last_state.into(), + replicas_ready, + replicas_desired, + } + } + + /// Create a failure error with optional reason. + pub fn failed(instance_name: impl Into, reason: Option) -> Self { + Self::Failed { + instance_name: instance_name.into(), + reason, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_authentication_error_display() { + let err = BasilicaError::authentication("Invalid API key"); + assert_eq!(err.to_string(), "Authentication failed: Invalid API key"); + assert!(err.is_client_error()); + assert!(!err.is_retryable()); + } + + #[test] + fn test_authorization_error_display() { + let err = BasilicaError::authorization( + "Insufficient permissions", + Some("deployment/my-app".to_string()), + ); + assert_eq!( + err.to_string(), + "Authorization denied: Insufficient permissions" + ); + assert!(err.is_client_error()); + } + + #[test] + fn test_validation_error_display() { + let err = BasilicaError::validation( + "Invalid port number", + Some("port".to_string()), + Some("-1".to_string()), + ); + assert_eq!(err.to_string(), "Validation error: Invalid port number"); + assert!(err.is_client_error()); + } + + #[test] + fn test_deployment_not_found_error() { + let err = DeploymentError::not_found("my-missing-app"); + assert_eq!(err.to_string(), "Deployment 'my-missing-app' not found"); + } + + #[test] + fn test_deployment_timeout_error() { + let err = DeploymentError::timeout("my-app", 300, "Pending", 0, 3); + assert_eq!( + err.to_string(), + "Deployment 'my-app' timed out after 300s (state: Pending, replicas: 0/3)" + ); + } + + #[test] + fn test_deployment_failed_error() { + let err = DeploymentError::failed("my-app", Some("ImagePullBackOff".to_string())); + assert_eq!( + err.to_string(), + 
"Deployment 'my-app' failed: ImagePullBackOff" + ); + } + + #[test] + fn test_network_error_retryable() { + let err = BasilicaError::network("Connection refused", None); + assert!(err.is_retryable()); + assert!(!err.is_client_error()); + } + + #[test] + fn test_rate_limit_error_retryable() { + let err = BasilicaError::rate_limit(Some(60)); + assert!(err.is_retryable()); + assert_eq!(err.to_string(), "Rate limit exceeded"); + } + + #[test] + fn test_source_error() { + let err = BasilicaError::source("File not found", Some("/path/to/app.py".to_string())); + assert_eq!(err.to_string(), "Source error: File not found"); + assert!(err.is_client_error()); + } + + #[test] + fn test_resource_error() { + let err = BasilicaError::resource("No GPUs available", Some("GPU".to_string())); + assert_eq!(err.to_string(), "Resource unavailable: No GPUs available"); + assert!(!err.is_client_error()); + } + + #[test] + fn test_storage_error() { + let err = BasilicaError::storage("Volume not found"); + assert_eq!(err.to_string(), "Storage error: Volume not found"); + } + + #[test] + fn test_deployment_error_conversion() { + let deployment_err = DeploymentError::not_found("test-app"); + let basilica_err: BasilicaError = deployment_err.into(); + assert!(matches!(basilica_err, BasilicaError::Deployment(_))); + } +} diff --git a/crates/basilica-sdk-rust/src/lib.rs b/crates/basilica-sdk-rust/src/lib.rs new file mode 100644 index 00000000..f7d05330 --- /dev/null +++ b/crates/basilica-sdk-rust/src/lib.rs @@ -0,0 +1,168 @@ +//! # Basilica SDK for Rust +//! +//! High-level Rust SDK for the Basilica GPU cloud platform. +//! +//! This crate provides an ergonomic, high-level API for deploying and managing +//! applications on Basilica. It wraps the low-level `basilica-sdk` crate with +//! convenient abstractions for common use cases. +//! +//! ## Features +//! +//! - **Easy deployment**: Deploy applications with minimal configuration +//! 
- **Inference servers**: One-line deployment of vLLM and SGLang servers +//! - **Auto GPU estimation**: Automatically estimate GPU requirements from model names +//! - **Status tracking**: Wait for deployments with progress callbacks +//! - **Secure cloud**: Rent GPUs from secure cloud providers +//! +//! ## Quick Start +//! +//! ```no_run +//! use basilica_sdk_rust::{BasilicaClient, VllmConfig}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! // Create a client with API key authentication +//! let client = BasilicaClient::builder() +//! .with_api_key("your-api-key") +//! .build()?; +//! +//! // Deploy a vLLM inference server +//! let deployment = client.deploy_vllm(VllmConfig { +//! model: "Qwen/Qwen3-0.6B".to_string(), +//! trust_remote_code: true, +//! ..Default::default() +//! }).await?; +//! +//! println!("Deployment ready at: {}", deployment.url()); +//! +//! // Get status +//! let status = deployment.status().await?; +//! println!("State: {}", status.state); +//! +//! // Get logs +//! let logs = deployment.logs(Some(50)).await?; +//! println!("Logs:\n{}", logs); +//! +//! // Delete when done +//! deployment.delete().await?; +//! +//! Ok(()) +//! } +//! ``` +//! +//! ## Authentication +//! +//! The SDK supports multiple authentication methods: +//! +//! ### API Key (Recommended) +//! +//! ```no_run +//! use basilica_sdk_rust::BasilicaClient; +//! +//! let client = BasilicaClient::builder() +//! .with_api_key("your-api-key") +//! .build() +//! .expect("Failed to create client"); +//! ``` +//! +//! ### Token-based +//! +//! ```no_run +//! use basilica_sdk_rust::BasilicaClient; +//! +//! let client = BasilicaClient::builder() +//! .with_tokens("access-token", "refresh-token") +//! .build() +//! .expect("Failed to create client"); +//! ``` +//! +//! ### File-based (CLI tokens) +//! +//! ```no_run +//! use basilica_sdk_rust::BasilicaClient; +//! +//! let client = BasilicaClient::builder() +//! .with_file_auth() +//! .build() +//! 
.expect("Failed to create client"); +//! ``` +//! +//! ## Deployment Types +//! +//! ### vLLM Inference Server +//! +//! ```no_run +//! use basilica_sdk_rust::{BasilicaClient, VllmConfig}; +//! +//! # async fn example() -> Result<(), Box> { +//! # let client = BasilicaClient::builder().with_api_key("key").build()?; +//! let deployment = client.deploy_vllm(VllmConfig { +//! model: "meta-llama/Llama-3.1-8B-Instruct".to_string(), +//! tensor_parallel_size: Some(2), +//! max_model_len: Some(4096), +//! ..Default::default() +//! }).await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ### SGLang Inference Server +//! +//! ```no_run +//! use basilica_sdk_rust::{BasilicaClient, SglangConfig}; +//! +//! # async fn example() -> Result<(), Box> { +//! # let client = BasilicaClient::builder().with_api_key("key").build()?; +//! let deployment = client.deploy_sglang(SglangConfig { +//! model: "Qwen/Qwen2.5-0.5B-Instruct".to_string(), +//! trust_remote_code: true, +//! ..Default::default() +//! }).await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Custom Container Deployment +//! +//! ```no_run +//! use basilica_sdk_rust::{BasilicaClient, DeployConfig, StorageConfig}; +//! +//! # async fn example() -> Result<(), Box> { +//! # let client = BasilicaClient::builder().with_api_key("key").build()?; +//! let deployment = client.deploy(DeployConfig { +//! name: "my-app".to_string(), +//! image: "nginx:latest".to_string(), +//! port: 80, +//! replicas: 2, +//! storage: StorageConfig::Enabled, +//! ..Default::default() +//! }).await?; +//! # Ok(()) +//! # } +//! 
```
+
+pub mod client;
+pub mod deployment;
+pub mod error;
+pub mod model_size;
+pub mod source;
+
+// Re-export main types for convenience
+pub use client::{
+ BasilicaClient, BasilicaClientBuilder, DeployConfig, SglangConfig, StorageConfig, VllmConfig,
+ DEFAULT_API_URL, DEFAULT_PYTHON_IMAGE,
+};
+pub use deployment::{Deployment, DeploymentStatus, ProgressInfo};
+pub use error::{BasilicaError, DeploymentError, Result};
+pub use model_size::{estimate_gpu_requirements, GpuRequirements};
+pub use source::SourcePackager;
+
+// Re-export commonly used types from low-level SDK
+pub use basilica_sdk::{
+ BalanceResponse, CpuOffering, GpuOffering, GpuRequirementsSpec, HealthCheckConfig,
+ HealthCheckResponse, ListAvailableNodesQuery, ProbeConfig, ResourceRequirements,
+ SshKeyResponse, StorageBackend, StorageSpec, TopologySpreadConfig, UsageHistoryResponse,
+};
+
+/// SDK version
+pub const VERSION: &str = env!("CARGO_PKG_VERSION");
diff --git a/crates/basilica-sdk-rust/src/model_size.rs b/crates/basilica-sdk-rust/src/model_size.rs
new file mode 100644
index 00000000..754c4f3f
--- /dev/null
+++ b/crates/basilica-sdk-rust/src/model_size.rs
@@ -0,0 +1,394 @@
+//! GPU requirements estimation based on model size.
+//!
+//! This module provides utilities for estimating GPU requirements based on
+//! model names and parameter counts. It helps users choose appropriate GPU
+//! configurations for their machine learning models.
+
+use once_cell::sync::Lazy;
+use regex::Regex;
+
+/// Default GPU memory to assume for calculations (in GB).
+const DEFAULT_GPU_MEMORY_GB: u32 = 16;
+
+/// Regex pattern to extract parameter counts from model names (e.g., "7b", "0.5b", "70B").
+static PARAM_PATTERN: Lazy<Regex> = Lazy::new(|| {
+ Regex::new(r"(?i)(\d+\.?\d*)b").expect("Invalid regex pattern for parameter extraction")
+});
+
+/// GPU requirements estimation for a machine learning model.
+///
+/// This struct contains the estimated GPU configuration needed to run a model,
+/// including the number of GPUs, memory requirements, and recommended GPU type.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct GpuRequirements {
+    /// Estimated number of GPUs required.
+    pub gpu_count: u32,
+
+    /// Estimated GPU memory required in gigabytes.
+    pub memory_gb: u32,
+
+    /// Recommended GPU model for this workload.
+    pub recommended_gpu: String,
+}
+
+impl GpuRequirements {
+    /// Create a new GpuRequirements instance.
+    pub fn new(gpu_count: u32, memory_gb: u32, recommended_gpu: impl Into<String>) -> Self {
+        Self {
+            gpu_count,
+            memory_gb,
+            recommended_gpu: recommended_gpu.into(),
+        }
+    }
+}
+
+/// Estimate GPU requirements based on a model name.
+///
+/// This function analyzes the model name to determine GPU requirements.
+/// It uses two approaches:
+///
+/// 1. **Parameter extraction**: If the model name contains a parameter count
+///    (e.g., "llama-7b", "qwen2.5-0.5b"), it uses that to estimate memory.
+///
+/// 2. **Model family matching**: For known model families without explicit
+///    parameter counts, it uses predefined estimates.
+///
+/// # Arguments
+///
+/// * `model` - The model name or identifier (e.g., "meta-llama/Llama-2-7b")
+///
+/// # Returns
+///
+/// A [`GpuRequirements`] struct with estimated GPU configuration.
+///
+/// # Example
+///
+/// ```rust
+/// use basilica_sdk_rust::estimate_gpu_requirements;
+///
+/// let reqs = estimate_gpu_requirements("meta-llama/Llama-2-7b");
+/// println!("Need {} GPU(s) with {}GB memory", reqs.gpu_count, reqs.memory_gb);
+/// println!("Recommended: {}", reqs.recommended_gpu);
+/// ```
+pub fn estimate_gpu_requirements(model: &str) -> GpuRequirements {
+    let model_lower = model.to_lowercase();
+
+    // Try to extract parameter count from model name
+    let memory_gb = if let Some(params) = extract_param_count(&model_lower) {
+        estimate_memory_from_params(params)
+    } else {
+        estimate_from_model_family(&model_lower)
+    };
+
+    let gpu_count = calculate_gpu_count(memory_gb, DEFAULT_GPU_MEMORY_GB);
+    let recommended_gpu = recommend_gpu(memory_gb);
+
+    GpuRequirements {
+        gpu_count,
+        memory_gb,
+        recommended_gpu,
+    }
+}
+
+/// Extract parameter count (in billions) from a model name.
+///
+/// Looks for patterns like "7b", "0.5B", "70B" in the model name.
+///
+/// # Arguments
+///
+/// * `model` - The lowercase model name to search
+///
+/// # Returns
+///
+/// The parameter count as a floating-point number, or `None` if not found.
+fn extract_param_count(model: &str) -> Option<f64> {
+    PARAM_PATTERN
+        .captures(model)
+        .and_then(|caps: regex::Captures<'_>| {
+            caps.get(1)
+                .and_then(|m: regex::Match<'_>| m.as_str().parse::<f64>().ok())
+        })
+}
+
+/// Estimate GPU memory requirements from parameter count.
+///
+/// Uses the formula: `params_billions * 2.0 * 1.2` (2 bytes per param, 20% overhead).
+/// Result is rounded up to the nearest 8GB.
+///
+/// # Arguments
+///
+/// * `params_billions` - Model size in billions of parameters
+///
+/// # Returns
+///
+/// Estimated memory requirement in gigabytes.
+fn estimate_memory_from_params(params_billions: f64) -> u32 {
+    // Base memory calculation:
+    // - 2 bytes per parameter (FP16 inference)
+    // - 1.2x overhead for KV cache, activations, etc. 
+ let base = params_billions * 2.0 * 1.2; + + // Round up to nearest 8GB for practical allocation + let base_u32 = base.ceil() as u32; + base_u32.div_ceil(8) * 8 +} + +/// Estimate GPU memory for known model families without explicit param counts. +/// +/// # Arguments +/// +/// * `model` - The lowercase model name to match +/// +/// # Returns +/// +/// Estimated memory requirement in gigabytes. +fn estimate_from_model_family(model: &str) -> u32 { + // Large models (~70B parameters) + if ["llama-2-70b", "llama-70b", "mixtral", "qwen-72b"] + .iter() + .any(|pattern| model.contains(pattern)) + { + return 160; // ~70B params need ~168GB, round to 160 + } + + // Medium-large models (~13-34B parameters) + if ["llama-2-13b", "llama-13b", "codellama-34b"] + .iter() + .any(|pattern| model.contains(pattern)) + { + return 32; // ~13-34B params + } + + // Medium models (~7B parameters) + if ["llama-2-7b", "llama-7b", "mistral-7b", "qwen-7b"] + .iter() + .any(|pattern| model.contains(pattern)) + { + return 16; // ~7B params + } + + // Small models (<3B parameters) + if [ + "phi-2", + "gemma-2b", + "tinyllama", + "qwen3-0.6b", + "qwen2.5-0.5b", + ] + .iter() + .any(|pattern| model.contains(pattern)) + { + return 8; // Small models + } + + // Default: assume medium-sized model + 16 +} + +/// Calculate the number of GPUs needed for a given memory requirement. +/// +/// # Arguments +/// +/// * `required_memory_gb` - Total memory required in gigabytes +/// * `gpu_memory_gb` - Memory per GPU in gigabytes +/// +/// # Returns +/// +/// Number of GPUs needed (clamped between 1 and 8). +fn calculate_gpu_count(required_memory_gb: u32, gpu_memory_gb: u32) -> u32 { + // Calculate minimum GPUs needed to meet memory requirement + let count = required_memory_gb.div_ceil(gpu_memory_gb); + + // Clamp to reasonable range (1-8 GPUs) + count.clamp(1, 8) +} + +/// Recommend an appropriate GPU model based on memory requirements. 
+/// +/// # Arguments +/// +/// * `memory_gb` - Required GPU memory in gigabytes +/// +/// # Returns +/// +/// Recommended GPU model name. +fn recommend_gpu(memory_gb: u32) -> String { + if memory_gb <= 16 { + "NVIDIA-RTX-A4000".to_string() + } else if memory_gb <= 40 { + "A100-40GB".to_string() + } else if memory_gb <= 80 { + "A100-80GB".to_string() + } else { + "H100".to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_params_7b() { + assert_eq!(extract_param_count("llama-7b"), Some(7.0)); + } + + #[test] + fn test_extract_params_70b() { + assert_eq!(extract_param_count("llama-70b"), Some(70.0)); + } + + #[test] + fn test_extract_params_0_5b() { + assert_eq!(extract_param_count("qwen2.5-0.5b"), Some(0.5)); + } + + #[test] + fn test_extract_params_uppercase() { + assert_eq!(extract_param_count("model-7B"), Some(7.0)); + } + + #[test] + fn test_extract_params_none() { + assert_eq!(extract_param_count("unknown-model"), None); + } + + #[test] + fn test_extract_params_with_path() { + assert_eq!( + extract_param_count("meta-llama/llama-2-7b-chat-hf"), + Some(7.0) + ); + } + + #[test] + fn test_estimate_memory_small_model() { + // 0.5B params: 0.5 * 2 * 1.2 = 1.2GB -> rounds to 8GB + let memory = estimate_memory_from_params(0.5); + assert_eq!(memory, 8); + } + + #[test] + fn test_estimate_memory_7b_model() { + // 7B params: 7 * 2 * 1.2 = 16.8GB -> rounds to 24GB + let memory = estimate_memory_from_params(7.0); + assert_eq!(memory, 24); + } + + #[test] + fn test_estimate_memory_70b_model() { + // 70B params: 70 * 2 * 1.2 = 168GB -> rounds to 168GB + let memory = estimate_memory_from_params(70.0); + assert_eq!(memory, 168); + } + + #[test] + fn test_estimate_gpu_requirements_7b() { + let reqs = estimate_gpu_requirements("meta-llama/Llama-2-7b"); + // 7B -> 24GB memory, needs 2 GPUs with 16GB each + assert!(reqs.gpu_count >= 1); + assert!(reqs.memory_gb >= 8 && reqs.memory_gb <= 32); + } + + #[test] + fn 
test_estimate_gpu_requirements_70b() { + let reqs = estimate_gpu_requirements("meta-llama/Llama-2-70b"); + // 70B -> ~168GB memory, needs multiple GPUs + assert!(reqs.gpu_count > 1); + assert!(reqs.memory_gb > 80); + assert_eq!(reqs.recommended_gpu, "H100"); + } + + #[test] + fn test_estimate_gpu_requirements_small_model() { + let reqs = estimate_gpu_requirements("qwen2.5-0.5b"); + assert_eq!(reqs.gpu_count, 1); + assert!(reqs.memory_gb <= 16); + } + + #[test] + fn test_estimate_gpu_requirements_unknown() { + let reqs = estimate_gpu_requirements("unknown-custom-model"); + // Should use default (medium model assumption) + assert_eq!(reqs.gpu_count, 1); + assert_eq!(reqs.memory_gb, 16); + } + + #[test] + fn test_calculate_gpu_count_single() { + assert_eq!(calculate_gpu_count(8, 16), 1); + assert_eq!(calculate_gpu_count(16, 16), 1); + } + + #[test] + fn test_calculate_gpu_count_multiple() { + assert_eq!(calculate_gpu_count(32, 16), 2); + assert_eq!(calculate_gpu_count(48, 16), 3); + } + + #[test] + fn test_calculate_gpu_count_clamped() { + // Should clamp to max of 8 + assert_eq!(calculate_gpu_count(200, 16), 8); + } + + #[test] + fn test_recommend_gpu_small() { + assert_eq!(recommend_gpu(8), "NVIDIA-RTX-A4000"); + assert_eq!(recommend_gpu(16), "NVIDIA-RTX-A4000"); + } + + #[test] + fn test_recommend_gpu_medium() { + assert_eq!(recommend_gpu(24), "A100-40GB"); + assert_eq!(recommend_gpu(40), "A100-40GB"); + } + + #[test] + fn test_recommend_gpu_large() { + assert_eq!(recommend_gpu(48), "A100-80GB"); + assert_eq!(recommend_gpu(80), "A100-80GB"); + } + + #[test] + fn test_recommend_gpu_very_large() { + assert_eq!(recommend_gpu(96), "H100"); + assert_eq!(recommend_gpu(160), "H100"); + } + + #[test] + fn test_model_family_mixtral() { + // Note: mixtral-8x7b contains "7b" which matches param pattern + // The regex extracts 7B params, giving ~24GB estimate + // This is intentional - the param extraction takes precedence + let reqs = 
estimate_gpu_requirements("mixtral-8x7b-instruct"); + assert!(reqs.memory_gb >= 16); + + // Pure "mixtral" without param count falls back to model family + let reqs_plain = estimate_gpu_requirements("mixtral-moe"); + assert_eq!(reqs_plain.memory_gb, 160); + } + + #[test] + fn test_model_family_phi2() { + let reqs = estimate_gpu_requirements("microsoft/phi-2"); + assert_eq!(reqs.memory_gb, 8); + assert_eq!(reqs.gpu_count, 1); + } + + #[test] + fn test_gpu_requirements_clone() { + let reqs = estimate_gpu_requirements("llama-7b"); + let cloned = reqs.clone(); + assert_eq!(reqs, cloned); + } + + #[test] + fn test_gpu_requirements_debug() { + let reqs = GpuRequirements::new(2, 48, "A100-80GB"); + let debug_str = format!("{:?}", reqs); + assert!(debug_str.contains("gpu_count: 2")); + assert!(debug_str.contains("memory_gb: 48")); + assert!(debug_str.contains("A100-80GB")); + } +} diff --git a/crates/basilica-sdk-rust/src/source.rs b/crates/basilica-sdk-rust/src/source.rs new file mode 100644 index 00000000..3230f320 --- /dev/null +++ b/crates/basilica-sdk-rust/src/source.rs @@ -0,0 +1,503 @@ +//! Source code packaging for container deployments. +//! +//! This module provides [`SourcePackager`] for packaging Python source code +//! into container-ready commands with automatic framework detection. + +use crate::error::{BasilicaError, Result}; +use std::fs; +use std::path::Path; + +/// Packages Python source code for container deployment. +/// +/// `SourcePackager` handles loading source code from files or strings, +/// detecting web frameworks, and generating the appropriate container +/// commands to run the code. 
+///
+/// # Example
+///
+/// ```rust
+/// use basilica_sdk_rust::SourcePackager;
+///
+/// // From inline code
+/// let packager = SourcePackager::from_string(r#"
+/// from fastapi import FastAPI
+/// app = FastAPI()
+///
+/// @app.get("/")
+/// def hello():
+///     return {"message": "Hello, World!"}
+/// "#).unwrap();
+///
+/// assert_eq!(packager.detect_framework(), Some("fastapi"));
+/// ```
+#[derive(Debug)]
+pub struct SourcePackager {
+    /// Original source (file path or inline code).
+    pub source: String,
+
+    /// The actual source code content.
+    pub code: String,
+
+    /// Whether the source was loaded from a file.
+    pub is_file: bool,
+}
+
+impl SourcePackager {
+    /// Default packages for web applications using FastAPI.
+    const WEB_PACKAGES: &'static [&'static str] = &["fastapi", "uvicorn", "pydantic"];
+
+    /// Create a new SourcePackager from a file path or inline code.
+    ///
+    /// This constructor auto-detects whether the input is a file path or
+    /// inline code:
+    /// - If the string ends with `.py`, it's treated as a file path
+    /// - If the string points to an existing file, it's loaded
+    /// - Otherwise, it's treated as inline code
+    ///
+    /// # Arguments
+    ///
+    /// * `source` - Either a file path or inline Python code
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - The file doesn't exist or can't be read
+    /// - The source code is empty
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use basilica_sdk_rust::SourcePackager;
+    ///
+    /// // From inline code (auto-detected)
+    /// let packager = SourcePackager::new("print('hello')").unwrap();
+    /// assert!(!packager.is_file);
+    /// ```
+    pub fn new(source: impl AsRef<str>) -> Result<Self> {
+        let source = source.as_ref().to_string();
+        let (code, is_file) = if Self::is_file_path(&source) {
+            (Self::read_file(&source)?, true)
+        } else {
+            (source.clone(), false)
+        };
+
+        if code.trim().is_empty() {
+            return Err(BasilicaError::Source {
+                message: "Source code is empty".to_string(),
+                source_path: if is_file { Some(source) } else { None },
+            });
+        }
+
+        Ok(Self {
+            source,
+            code,
+            is_file,
+        })
+    }
+
+    /// Create a SourcePackager from an explicit file path.
+    ///
+    /// Unlike [`new`](Self::new), this always treats the input as a file path,
+    /// regardless of its contents or extension.
+    ///
+    /// # Arguments
+    ///
+    /// * `path` - Path to the Python source file
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the file doesn't exist or can't be read.
+    ///
+    /// # Example
+    ///
+    /// ```rust,ignore
+    /// use basilica_sdk_rust::SourcePackager;
+    ///
+    /// let packager = SourcePackager::from_file("./app.py")?;
+    /// assert!(packager.is_file);
+    /// ```
+    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
+        let path = path.as_ref();
+        let path_str = path.display().to_string();
+
+        if !path.exists() {
+            return Err(BasilicaError::Source {
+                message: format!("Source file '{}' not found", path_str),
+                source_path: Some(path_str),
+            });
+        }
+
+        let code = Self::read_file(&path_str)?;
+        Ok(Self {
+            source: path_str,
+            code,
+            is_file: true,
+        })
+    }
+
+    /// Create a SourcePackager from inline code.
+    ///
+    /// This always treats the input as code, never as a file path.
+    ///
+    /// # Arguments
+    ///
+    /// * `code` - Python source code as a string
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the code is empty.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use basilica_sdk_rust::SourcePackager;
+    ///
+    /// let packager = SourcePackager::from_string(r#"
+    /// import flask
+    /// app = flask.Flask(__name__)
+    /// "#).unwrap();
+    ///
+    /// assert!(!packager.is_file);
+    /// assert_eq!(packager.detect_framework(), Some("flask"));
+    /// ```
+    pub fn from_string(code: impl Into<String>) -> Result<Self> {
+        let code = code.into();
+        if code.trim().is_empty() {
+            return Err(BasilicaError::Source {
+                message: "Source code is empty".to_string(),
+                source_path: None,
+            });
+        }
+
+        Ok(Self {
+            source: code.clone(),
+            code,
+            is_file: false,
+        })
+    }
+
+    /// Detect the web framework used in the source code.
+    ///
+    /// Scans the source code for import statements to identify
+    /// the web framework being used.
+    ///
+    /// # Returns
+    ///
+    /// The detected framework name, or `None` if no known framework is found.
+    ///
+    /// Currently detects:
+    /// - `"fastapi"` - FastAPI framework
+    /// - `"flask"` - Flask framework
+    /// - `"django"` - Django framework
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use basilica_sdk_rust::SourcePackager;
+    ///
+    /// let packager = SourcePackager::from_string("from fastapi import FastAPI").unwrap();
+    /// assert_eq!(packager.detect_framework(), Some("fastapi"));
+    ///
+    /// let packager = SourcePackager::from_string("print('hello')").unwrap();
+    /// assert_eq!(packager.detect_framework(), None);
+    /// ```
+    pub fn detect_framework(&self) -> Option<&'static str> {
+        let code_lower = self.code.to_lowercase();
+
+        if code_lower.contains("from fastapi") || code_lower.contains("import fastapi") {
+            return Some("fastapi");
+        }
+        if code_lower.contains("from flask") || code_lower.contains("import flask") {
+            return Some("flask");
+        }
+        if code_lower.contains("from django") || code_lower.contains("import django") {
+            return Some("django");
+        }
+
+        None
+    }
+
+    /// Build the container command to run the Python source.
+    ///
+    /// Generates a shell command that:
+    /// 1. Optionally installs pip packages
+    /// 2. Runs the Python code using a heredoc
+    ///
+    /// If no packages are specified and FastAPI is detected, it automatically
+    /// includes the standard web packages (fastapi, uvicorn, pydantic).
+    ///
+    /// # Arguments
+    ///
+    /// * `pip_packages` - Optional list of pip packages to install
+    ///
+    /// # Returns
+    ///
+    /// A vector of strings representing the command: `["bash", "-c", "..."]`
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use basilica_sdk_rust::SourcePackager;
+    ///
+    /// let packager = SourcePackager::from_string("print('hello')").unwrap();
+    /// let cmd = packager.build_command(None);
+    ///
+    /// assert_eq!(cmd[0], "bash");
+    /// assert_eq!(cmd[1], "-c");
+    /// assert!(cmd[2].contains("python3"));
+    /// ```
+    pub fn build_command(&self, pip_packages: Option<&[String]>) -> Vec<String> {
+        let mut parts = Vec::new();
+
+        // Determine packages to install
+        let packages: Option<Vec<String>> = match pip_packages {
+            Some(p) if !p.is_empty() => Some(p.to_vec()),
+            None => {
+                // Auto-detect packages for FastAPI
+                if self.detect_framework() == Some("fastapi") {
+                    Some(Self::WEB_PACKAGES.iter().map(|s| s.to_string()).collect())
+                } else {
+                    None
+                }
+            }
+            _ => None,
+        };
+
+        // Add pip install command if packages are specified
+        if let Some(pkgs) = packages {
+            if !pkgs.is_empty() {
+                let packages_str = pkgs.join(" ");
+                parts.push(format!("pip install -q {}", packages_str));
+            }
+        }
+
+        // Build the python command with heredoc
+        let heredoc = format!("python3 - <<'PYCODE'\n{}\nPYCODE\n", self.code);
+        parts.push(heredoc);
+
+        // Join commands with && so pip install failure stops execution
+        let full_command = parts.join(" && ");
+
+        vec!["bash".to_string(), "-c".to_string(), full_command]
+    }
+
+    /// Check if a string looks like a file path.
+    fn is_file_path(source: &str) -> bool {
+        // If it ends with .py, treat as file path
+        if source.ends_with(".py") {
+            return true;
+        }
+        // Check if the path exists as a file
+        Path::new(source).is_file()
+    }
+
+    /// Read a file, expanding ~ to home directory.
+    fn read_file(path: &str) -> Result<String> {
+        let expanded = shellexpand::tilde(path).to_string();
+        fs::read_to_string(&expanded).map_err(|e| BasilicaError::Source {
+            message: format!("Failed to read source file '{}': {}", path, e),
+            source_path: Some(path.to_string()),
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::io::Write;
+    use tempfile::NamedTempFile;
+
+    #[test]
+    fn test_from_string() {
+        let packager = SourcePackager::from_string("print('hello')").expect("Should succeed");
+        assert!(!packager.is_file);
+        assert_eq!(packager.code, "print('hello')");
+    }
+
+    #[test]
+    fn test_from_string_empty_fails() {
+        let result = SourcePackager::from_string("");
+        assert!(result.is_err());
+        let err = result.unwrap_err();
+        assert!(matches!(err, BasilicaError::Source { .. 
})); + } + + #[test] + fn test_from_string_whitespace_only_fails() { + let result = SourcePackager::from_string(" \n\t "); + assert!(result.is_err()); + } + + #[test] + fn test_detect_framework_fastapi() { + let packager = + SourcePackager::from_string("from fastapi import FastAPI").expect("Should succeed"); + assert_eq!(packager.detect_framework(), Some("fastapi")); + } + + #[test] + fn test_detect_framework_fastapi_import() { + let packager = SourcePackager::from_string("import fastapi\napp = fastapi.FastAPI()") + .expect("Should succeed"); + assert_eq!(packager.detect_framework(), Some("fastapi")); + } + + #[test] + fn test_detect_framework_flask() { + let packager = + SourcePackager::from_string("from flask import Flask").expect("Should succeed"); + assert_eq!(packager.detect_framework(), Some("flask")); + } + + #[test] + fn test_detect_framework_django() { + let packager = SourcePackager::from_string("from django.http import HttpResponse") + .expect("Should succeed"); + assert_eq!(packager.detect_framework(), Some("django")); + } + + #[test] + fn test_detect_framework_none() { + let packager = SourcePackager::from_string("print('hello world')").expect("Should succeed"); + assert_eq!(packager.detect_framework(), None); + } + + #[test] + fn test_detect_framework_case_insensitive() { + let packager = + SourcePackager::from_string("FROM FASTAPI IMPORT FASTAPI").expect("Should succeed"); + assert_eq!(packager.detect_framework(), Some("fastapi")); + } + + #[test] + fn test_build_command_no_packages() { + let packager = SourcePackager::from_string("print('test')").expect("Should succeed"); + let cmd = packager.build_command(None); + + assert_eq!(cmd.len(), 3); + assert_eq!(cmd[0], "bash"); + assert_eq!(cmd[1], "-c"); + assert!(cmd[2].contains("python3")); + assert!(cmd[2].contains("PYCODE")); + assert!(cmd[2].contains("print('test')")); + } + + #[test] + fn test_build_command_with_packages() { + let packager = 
SourcePackager::from_string("print('test')").expect("Should succeed");
+        let packages = vec!["requests".to_string(), "numpy".to_string()];
+        let cmd = packager.build_command(Some(&packages));
+
+        assert!(cmd[2].contains("pip install -q requests numpy"));
+        assert!(cmd[2].contains("&&"));
+        assert!(cmd[2].contains("python3"));
+    }
+
+    #[test]
+    fn test_build_command_auto_packages_fastapi() {
+        let packager = SourcePackager::from_string("from fastapi import FastAPI\napp = FastAPI()")
+            .expect("Should succeed");
+        let cmd = packager.build_command(None);
+
+        // Should auto-detect FastAPI and include web packages
+        assert!(cmd[2].contains("pip install -q"));
+        assert!(cmd[2].contains("fastapi"));
+        assert!(cmd[2].contains("uvicorn"));
+    }
+
+    #[test]
+    fn test_build_command_empty_packages_no_install() {
+        let packager = SourcePackager::from_string("print('test')").expect("Should succeed");
+        let packages: Vec<String> = vec![];
+        let cmd = packager.build_command(Some(&packages));
+
+        // Empty packages should not add pip install
+        assert!(!cmd[2].contains("pip install"));
+    }
+
+    #[test]
+    fn test_from_file() {
+        let mut file = NamedTempFile::new().expect("Failed to create temp file");
+        writeln!(file, "print('from file')").expect("Failed to write");
+
+        let packager = SourcePackager::from_file(file.path()).expect("Should succeed");
+        assert!(packager.is_file);
+        assert!(packager.code.contains("print('from file')"));
+    }
+
+    #[test]
+    fn test_from_file_not_found() {
+        let result = SourcePackager::from_file("/nonexistent/path/to/app.py");
+        assert!(result.is_err());
+
+        let err = result.unwrap_err();
+        if let BasilicaError::Source {
+            message,
+            source_path,
+        } = err
+        {
+            assert!(message.contains("not found"));
+            assert!(source_path.is_some());
+        } else {
+            panic!("Expected Source error");
+        }
+    }
+
+    #[test]
+    fn test_new_with_inline_code() {
+        let packager = SourcePackager::new("x = 1 + 2").expect("Should succeed");
+        assert!(!packager.is_file);
+        assert_eq!(packager.code, "x 
= 1 + 2"); + } + + #[test] + fn test_new_with_file_path() { + let mut file = NamedTempFile::with_suffix(".py").expect("Failed to create temp file"); + writeln!(file, "# Python file").expect("Failed to write"); + + let packager = SourcePackager::new(file.path().to_str().unwrap()).expect("Should succeed"); + assert!(packager.is_file); + } + + #[test] + fn test_is_file_path_py_extension() { + assert!(SourcePackager::is_file_path("app.py")); + assert!(SourcePackager::is_file_path("/path/to/script.py")); + assert!(!SourcePackager::is_file_path("print('hello')")); + } + + #[test] + fn test_heredoc_escaping() { + // Test that the heredoc handles code with quotes properly + let packager = + SourcePackager::from_string(r#"print("hello 'world'")"#).expect("Should succeed"); + let cmd = packager.build_command(None); + + // The PYCODE delimiter should allow any content + assert!(cmd[2].contains("PYCODE")); + assert!(cmd[2].contains(r#"print("hello 'world'")"#)); + } + + #[test] + fn test_multiline_code() { + let code = r#" +def hello(): + return "Hello, World!" + +if __name__ == "__main__": + print(hello()) +"#; + let packager = SourcePackager::from_string(code).expect("Should succeed"); + assert!(packager.code.contains("def hello():")); + assert!(packager.code.contains("if __name__")); + } + + #[test] + fn test_debug_output() { + let packager = SourcePackager::from_string("x = 1").expect("Should succeed"); + let debug_str = format!("{:?}", packager); + assert!(debug_str.contains("source")); + assert!(debug_str.contains("code")); + assert!(debug_str.contains("is_file")); + } +} diff --git a/crates/basilica-sdk-rust/src/spec.rs b/crates/basilica-sdk-rust/src/spec.rs new file mode 100644 index 00000000..6cd5d03d --- /dev/null +++ b/crates/basilica-sdk-rust/src/spec.rs @@ -0,0 +1,438 @@ +//! Deployment specification types. +//! +//! This module provides [`DeploymentSpec`], an immutable configuration +//! for defining how deployments should be created and configured. 
+
+use crate::Volume;
+use basilica_sdk::HealthCheckConfig;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+/// Immutable specification for a deployment.
+///
+/// `DeploymentSpec` defines all the configuration needed to create a deployment
+/// on the Basilica GPU cloud platform. It includes container settings, resource
+/// requirements, networking, and optional GPU specifications.
+///
+/// # Example
+///
+/// ```rust
+/// use basilica_sdk_rust::DeploymentSpec;
+///
+/// let spec = DeploymentSpec {
+///     name: "my-inference-service".to_string(),
+///     image: "python:3.11-slim".to_string(),
+///     port: 8000,
+///     gpu_count: Some(1),
+///     gpu_models: Some(vec!["NVIDIA-RTX-A4000".to_string()]),
+///     ..Default::default()
+/// };
+/// ```
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DeploymentSpec {
+    /// The unique name for this deployment.
+    ///
+    /// Must be a valid DNS subdomain name (lowercase alphanumeric with hyphens).
+    pub name: String,
+
+    /// Container image to deploy.
+    ///
+    /// Can be a public image (e.g., `python:3.11-slim`) or a private image
+    /// from a container registry.
+    pub image: String,
+
+    /// The port the application listens on inside the container.
+    ///
+    /// This port will be exposed through the deployment's public URL.
+    pub port: u32,
+
+    /// CPU resource allocation (Kubernetes format).
+    ///
+    /// Examples: `"500m"` (half a core), `"2"` (two cores), `"4000m"` (four cores).
+    pub cpu: String,
+
+    /// Memory resource allocation (Kubernetes format).
+    ///
+    /// Examples: `"512Mi"`, `"2Gi"`, `"4096Mi"`.
+    pub memory: String,
+
+    /// Number of GPUs to allocate (if any).
+    ///
+    /// When specified, the deployment will be scheduled on GPU nodes.
+    pub gpu_count: Option<u32>,
+
+    /// Acceptable GPU models for this deployment.
+    ///
+    /// If specified, the deployment will only be scheduled on nodes with
+    /// one of these GPU types. Example: `["NVIDIA-RTX-A4000", "A100-40GB"]`.
+    pub gpu_models: Option<Vec<String>>,
+
+    /// Minimum CUDA version required.
+    ///
+    /// Example: `"12.0"` to require CUDA 12.0 or higher.
+    pub min_cuda_version: Option<String>,
+
+    /// Minimum GPU memory required in gigabytes.
+    ///
+    /// Useful for ensuring your model fits in GPU memory.
+    pub min_gpu_memory_gb: Option<u32>,
+
+    /// Persistent volumes to mount into the container.
+    ///
+    /// Keys are mount paths inside the container, values are volume specifications.
+    pub volumes: Option<HashMap<String, Volume>>,
+
+    /// Environment variables to set in the container.
+    pub env: Option<HashMap<String, String>>,
+
+    /// Python packages to install at container startup.
+    ///
+    /// These will be installed via pip before the application starts.
+    pub pip_packages: Option<Vec<String>>,
+
+    /// Number of replica pods to run.
+    ///
+    /// Higher replica counts provide better availability and throughput.
+    pub replicas: u32,
+
+    /// Time-to-live in seconds after which the deployment auto-terminates.
+    ///
+    /// Useful for temporary or time-limited deployments.
+    pub ttl_seconds: Option<u32>,
+
+    /// Whether the deployment URL is publicly accessible without authentication.
+    pub public: bool,
+
+    /// Timeout in seconds for deployment operations (default: 300).
+    ///
+    /// This controls how long to wait for the deployment to become ready.
+    pub timeout: u64,
+
+    /// Health check configuration for the deployment.
+    ///
+    /// Configures liveness, readiness, and startup probes.
+    pub health_check: Option<HealthCheckConfig>,
+}
+
+impl Default for DeploymentSpec {
+    fn default() -> Self {
+        Self {
+            name: String::new(),
+            image: "python:3.11-slim".to_string(),
+            port: 8000,
+            cpu: "500m".to_string(),
+            memory: "512Mi".to_string(),
+            gpu_count: None,
+            gpu_models: None,
+            min_cuda_version: None,
+            min_gpu_memory_gb: None,
+            volumes: None,
+            env: None,
+            pip_packages: None,
+            replicas: 1,
+            ttl_seconds: None,
+            public: true,
+            timeout: 300,
+            health_check: None,
+        }
+    }
+}
+
+impl DeploymentSpec {
+    /// Create a new deployment spec with the given name and image.
+    ///
+    /// All other fields are set to their defaults.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use basilica_sdk_rust::DeploymentSpec;
+    ///
+    /// let spec = DeploymentSpec::new("my-app", "python:3.11");
+    /// assert_eq!(spec.name, "my-app");
+    /// assert_eq!(spec.image, "python:3.11");
+    /// ```
+    pub fn new(name: impl Into<String>, image: impl Into<String>) -> Self {
+        Self {
+            name: name.into(),
+            image: image.into(),
+            ..Default::default()
+        }
+    }
+
+    /// Create a deployment spec configured for GPU workloads.
+    ///
+    /// # Arguments
+    ///
+    /// * `name` - Deployment name
+    /// * `image` - Container image (typically a CUDA-enabled image)
+    /// * `gpu_count` - Number of GPUs to allocate
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use basilica_sdk_rust::DeploymentSpec;
+    ///
+    /// let spec = DeploymentSpec::with_gpu("llm-service", "nvidia/cuda:12.0-base", 1);
+    /// assert_eq!(spec.gpu_count, Some(1));
+    /// ```
+    pub fn with_gpu(name: impl Into<String>, image: impl Into<String>, gpu_count: u32) -> Self {
+        Self {
+            name: name.into(),
+            image: image.into(),
+            gpu_count: Some(gpu_count),
+            // Increase default resources for GPU workloads
+            cpu: "2".to_string(),
+            memory: "8Gi".to_string(),
+            ..Default::default()
+        }
+    }
+
+    /// Set the number of replicas.
+    pub fn with_replicas(mut self, replicas: u32) -> Self {
+        self.replicas = replicas;
+        self
+    }
+
+    /// Set the port the application listens on.
+    pub fn with_port(mut self, port: u32) -> Self {
+        self.port = port;
+        self
+    }
+
+    /// Set CPU and memory resources.
+    pub fn with_resources(mut self, cpu: impl Into<String>, memory: impl Into<String>) -> Self {
+        self.cpu = cpu.into();
+        self.memory = memory.into();
+        self
+    }
+
+    /// Add environment variables.
+    pub fn with_env(mut self, env: HashMap<String, String>) -> Self {
+        self.env = Some(env);
+        self
+    }
+
+    /// Add a single environment variable.
+    pub fn add_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
+        let env = self.env.get_or_insert_with(HashMap::new);
+        env.insert(key.into(), value.into());
+        self
+    }
+
+    /// Set pip packages to install.
+    pub fn with_pip_packages(mut self, packages: Vec<String>) -> Self {
+        self.pip_packages = Some(packages);
+        self
+    }
+
+    /// Set GPU requirements.
+    pub fn with_gpu_requirements(
+        mut self,
+        gpu_count: u32,
+        gpu_models: Option<Vec<String>>,
+        min_cuda_version: Option<String>,
+        min_gpu_memory_gb: Option<u32>,
+    ) -> Self {
+        self.gpu_count = Some(gpu_count);
+        self.gpu_models = gpu_models;
+        self.min_cuda_version = min_cuda_version;
+        self.min_gpu_memory_gb = min_gpu_memory_gb;
+        self
+    }
+
+    /// Set the deployment as private (requires authentication to access).
+    pub fn as_private(mut self) -> Self {
+        self.public = false;
+        self
+    }
+
+    /// Set a time-to-live for automatic cleanup.
+    pub fn with_ttl(mut self, seconds: u32) -> Self {
+        self.ttl_seconds = Some(seconds);
+        self
+    }
+
+    /// Set the timeout for deployment operations.
+    pub fn with_timeout(mut self, seconds: u64) -> Self {
+        self.timeout = seconds;
+        self
+    }
+
+    /// Add a volume mount.
+    pub fn add_volume(mut self, mount_path: impl Into<String>, volume: Volume) -> Self {
+        let volumes = self.volumes.get_or_insert_with(HashMap::new);
+        volumes.insert(mount_path.into(), volume);
+        self
+    }
+
+    /// Set health check configuration. 
+ pub fn with_health_check(mut self, health_check: HealthCheckConfig) -> Self { + self.health_check = Some(health_check); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_values() { + let spec = DeploymentSpec::default(); + assert_eq!(spec.name, ""); + assert_eq!(spec.image, "python:3.11-slim"); + assert_eq!(spec.port, 8000); + assert_eq!(spec.cpu, "500m"); + assert_eq!(spec.memory, "512Mi"); + assert_eq!(spec.replicas, 1); + assert!(spec.public); + assert_eq!(spec.timeout, 300); + assert!(spec.gpu_count.is_none()); + assert!(spec.env.is_none()); + assert!(spec.volumes.is_none()); + } + + #[test] + fn test_new_constructor() { + let spec = DeploymentSpec::new("test-app", "node:18"); + assert_eq!(spec.name, "test-app"); + assert_eq!(spec.image, "node:18"); + assert_eq!(spec.port, 8000); // default + } + + #[test] + fn test_with_gpu_constructor() { + let spec = DeploymentSpec::with_gpu("gpu-app", "nvidia/cuda:12.0", 2); + assert_eq!(spec.name, "gpu-app"); + assert_eq!(spec.image, "nvidia/cuda:12.0"); + assert_eq!(spec.gpu_count, Some(2)); + assert_eq!(spec.cpu, "2"); + assert_eq!(spec.memory, "8Gi"); + } + + #[test] + fn test_builder_pattern() { + let spec = DeploymentSpec::new("builder-test", "python:3.11") + .with_port(9000) + .with_replicas(3) + .with_resources("1", "2Gi") + .as_private() + .with_ttl(3600) + .with_timeout(600); + + assert_eq!(spec.port, 9000); + assert_eq!(spec.replicas, 3); + assert_eq!(spec.cpu, "1"); + assert_eq!(spec.memory, "2Gi"); + assert!(!spec.public); + assert_eq!(spec.ttl_seconds, Some(3600)); + assert_eq!(spec.timeout, 600); + } + + #[test] + fn test_add_env() { + let spec = DeploymentSpec::new("env-test", "python:3.11") + .add_env("API_KEY", "secret123") + .add_env("DEBUG", "true"); + + let env = spec.env.expect("env should be Some"); + assert_eq!(env.get("API_KEY"), Some(&"secret123".to_string())); + assert_eq!(env.get("DEBUG"), Some(&"true".to_string())); + } + + #[test] + fn test_with_env() { 
+ let mut env = HashMap::new(); + env.insert("KEY1".to_string(), "value1".to_string()); + env.insert("KEY2".to_string(), "value2".to_string()); + + let spec = DeploymentSpec::new("env-test", "python:3.11").with_env(env); + + let spec_env = spec.env.expect("env should be Some"); + assert_eq!(spec_env.len(), 2); + } + + #[test] + fn test_with_pip_packages() { + let spec = DeploymentSpec::new("pip-test", "python:3.11") + .with_pip_packages(vec!["fastapi".to_string(), "uvicorn".to_string()]); + + let packages = spec.pip_packages.expect("pip_packages should be Some"); + assert_eq!(packages.len(), 2); + assert!(packages.contains(&"fastapi".to_string())); + } + + #[test] + fn test_with_gpu_requirements() { + let spec = DeploymentSpec::new("gpu-req-test", "nvidia/cuda:12.0") + .with_gpu_requirements( + 2, + Some(vec!["A100-40GB".to_string(), "A100-80GB".to_string()]), + Some("12.0".to_string()), + Some(40), + ); + + assert_eq!(spec.gpu_count, Some(2)); + assert_eq!(spec.gpu_models.as_ref().map(|v| v.len()), Some(2)); + assert_eq!(spec.min_cuda_version, Some("12.0".to_string())); + assert_eq!(spec.min_gpu_memory_gb, Some(40)); + } + + #[test] + fn test_add_volume() { + let vol = Volume::new("data-vol"); + let spec = DeploymentSpec::new("vol-test", "python:3.11") + .add_volume("/data", vol); + + let volumes = spec.volumes.expect("volumes should be Some"); + assert!(volumes.contains_key("/data")); + assert!(volumes.get("/data").unwrap().create_if_missing); + } + + #[test] + fn test_serialization() { + let spec = DeploymentSpec::new("ser-test", "python:3.11") + .with_port(9000) + .with_replicas(2); + + let json = serde_json::to_string(&spec).expect("Serialization failed"); + assert!(json.contains("\"name\":\"ser-test\"")); + assert!(json.contains("\"port\":9000")); + assert!(json.contains("\"replicas\":2")); + } + + #[test] + fn test_deserialization() { + let json = r#"{ + "name": "deser-test", + "image": "node:18", + "port": 3000, + "cpu": "1", + "memory": "1Gi", + 
"replicas": 2, + "public": false, + "timeout": 120 + }"#; + + let spec: DeploymentSpec = serde_json::from_str(json).expect("Deserialization failed"); + assert_eq!(spec.name, "deser-test"); + assert_eq!(spec.image, "node:18"); + assert_eq!(spec.port, 3000); + assert_eq!(spec.cpu, "1"); + assert_eq!(spec.memory, "1Gi"); + assert_eq!(spec.replicas, 2); + assert!(!spec.public); + assert_eq!(spec.timeout, 120); + } + + #[test] + fn test_clone() { + let spec = DeploymentSpec::new("clone-test", "python:3.11") + .add_env("KEY", "VALUE"); + let cloned = spec.clone(); + + assert_eq!(spec.name, cloned.name); + assert_eq!(spec.env, cloned.env); + } +} diff --git a/crates/basilica-sdk-rust/src/volume.rs b/crates/basilica-sdk-rust/src/volume.rs new file mode 100644 index 00000000..c30279f8 --- /dev/null +++ b/crates/basilica-sdk-rust/src/volume.rs @@ -0,0 +1,179 @@ +//! Volume types for persistent storage in deployments. +//! +//! This module provides the [`Volume`] type for defining persistent storage +//! that can be mounted into deployment containers. + +use serde::{Deserialize, Serialize}; + +/// Persistent storage volume that can be mounted into deployments. +/// +/// Volumes allow data to persist across container restarts and can be +/// shared between deployments in the same region. +/// +/// # Example +/// +/// ```rust +/// use basilica_sdk_rust::Volume; +/// +/// // Reference an existing volume +/// let existing = Volume::from_name("my-data-vol", false); +/// +/// // Create a volume if it doesn't exist +/// let auto_create = Volume::from_name("new-volume", true); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Volume { + /// The unique name of the volume. + /// + /// Volume names must be unique within a user's namespace. + pub name: String, + + /// Whether to create the volume if it doesn't exist. + /// + /// If `true`, the volume will be automatically created during deployment + /// if no volume with this name exists. 
If `false`, deployment will fail + /// if the volume doesn't already exist. + pub create_if_missing: bool, +} + +impl Volume { + /// Reference or create a named volume. + /// + /// # Arguments + /// + /// * `name` - The unique name for the volume + /// * `create_if_missing` - If `true`, the volume will be created if it doesn't exist + /// + /// # Example + /// + /// ```rust + /// use basilica_sdk_rust::Volume; + /// + /// // Create a volume that must already exist + /// let vol = Volume::from_name("existing-data", false); + /// + /// // Create a volume that will be auto-created if needed + /// let vol = Volume::from_name("auto-created-data", true); + /// ``` + pub fn from_name(name: impl Into, create_if_missing: bool) -> Self { + Self { + name: name.into(), + create_if_missing, + } + } + + /// Create a volume reference that must already exist. + /// + /// This is a convenience method equivalent to `Volume::from_name(name, false)`. + /// + /// # Example + /// + /// ```rust + /// use basilica_sdk_rust::Volume; + /// + /// let vol = Volume::existing("my-data"); + /// assert!(!vol.create_if_missing); + /// ``` + pub fn existing(name: impl Into) -> Self { + Self::from_name(name, false) + } + + /// Create a volume reference that will be auto-created if it doesn't exist. + /// + /// This is a convenience method equivalent to `Volume::from_name(name, true)`. 
+ /// + /// # Example + /// + /// ```rust + /// use basilica_sdk_rust::Volume; + /// + /// let vol = Volume::new("my-new-volume"); + /// assert!(vol.create_if_missing); + /// ``` + pub fn new(name: impl Into) -> Self { + Self::from_name(name, true) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_name_create_if_missing_true() { + let vol = Volume::from_name("test-vol", true); + assert_eq!(vol.name, "test-vol"); + assert!(vol.create_if_missing); + } + + #[test] + fn test_from_name_create_if_missing_false() { + let vol = Volume::from_name("test-vol", false); + assert_eq!(vol.name, "test-vol"); + assert!(!vol.create_if_missing); + } + + #[test] + fn test_existing_convenience_method() { + let vol = Volume::existing("existing-vol"); + assert_eq!(vol.name, "existing-vol"); + assert!(!vol.create_if_missing); + } + + #[test] + fn test_new_convenience_method() { + let vol = Volume::new("new-vol"); + assert_eq!(vol.name, "new-vol"); + assert!(vol.create_if_missing); + } + + #[test] + fn test_from_name_with_string() { + let name = String::from("string-vol"); + let vol = Volume::from_name(name, true); + assert_eq!(vol.name, "string-vol"); + } + + #[test] + fn test_clone() { + let vol = Volume::from_name("clone-test", true); + let cloned = vol.clone(); + assert_eq!(vol, cloned); + } + + #[test] + fn test_debug() { + let vol = Volume::from_name("debug-test", false); + let debug_str = format!("{:?}", vol); + assert!(debug_str.contains("debug-test")); + assert!(debug_str.contains("create_if_missing: false")); + } + + #[test] + fn test_serialization() { + let vol = Volume::from_name("ser-test", true); + let json = serde_json::to_string(&vol).expect("Serialization failed"); + assert!(json.contains("\"name\":\"ser-test\"")); + assert!(json.contains("\"create_if_missing\":true")); + } + + #[test] + fn test_deserialization() { + let json = r#"{"name":"deser-test","create_if_missing":false}"#; + let vol: Volume = 
serde_json::from_str(json).expect("Deserialization failed"); + assert_eq!(vol.name, "deser-test"); + assert!(!vol.create_if_missing); + } + + #[test] + fn test_equality() { + let vol1 = Volume::from_name("eq-test", true); + let vol2 = Volume::from_name("eq-test", true); + let vol3 = Volume::from_name("eq-test", false); + let vol4 = Volume::from_name("different", true); + + assert_eq!(vol1, vol2); + assert_ne!(vol1, vol3); + assert_ne!(vol1, vol4); + } +}