diff --git a/Cargo.lock b/Cargo.lock index 6fa41bd09..081dd860e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -738,28 +738,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "async-trait" version = "0.1.89" @@ -1674,12 +1652,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "const-oid" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - [[package]] name = "const-oid" version = "0.10.1" @@ -1825,6 +1797,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32c" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +dependencies = [ + "rustc_version", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -2128,25 +2109,14 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "const-oid 0.9.6", - "pem-rfc7468 0.7.0", - "zeroize", -] - [[package]] name = "der" version = "0.8.0-rc.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9d8dd2f26c86b27a2a8ea2767ec7f9df7a89516e4794e54ac01ee618dda3aa4" dependencies = [ - "const-oid 0.10.1", - "pem-rfc7468 1.0.0-rc.3", + "const-oid", + "pem-rfc7468", "zeroize", ] @@ -2309,7 +2279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dac89f8a64533a9b0eaa73a68e424db0fb1fd6271c74cc0125336a05f090568d" dependencies = [ "block-buffer 0.11.0-rc.5", - "const-oid 0.10.1", + "const-oid", "crypto-common 0.2.0-rc.4", ] @@ -2461,7 +2431,7 @@ version = "3.0.0-rc.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ef49c0b20c0ad088893ad2a790a29c06a012b3f05bcfc66661fd22a94b32129" dependencies = [ - "pkcs8 0.11.0-rc.7", + "pkcs8", "serde", "signature 3.0.0-rc.4", ] @@ -3297,9 +3267,9 @@ dependencies = [ [[package]] name = "google-cloud-auth" -version = "0.17.2" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e57a13fbacc5e9c41ded3ad8d0373175a6b7a6ad430d99e89d314ac121b7ab06" +checksum = "1112c453c2e155b3e683204ffff52bcc6d6495d04b68d9e90cd24161270c5058" dependencies = [ "async-trait", "base64 0.21.7", @@ -3307,7 +3277,7 @@ dependencies = [ "google-cloud-token", "home", "jsonwebtoken", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "thiserror 1.0.69", @@ -3317,48 +3287,204 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "google-cloud-auth" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34f8aadacd3195fc3b08f2a5d582f2401c60d9f1598574acfcfb6228de25db29" 
+dependencies = [ + "async-trait", + "base64 0.22.1", + "bytes", + "google-cloud-gax", + "http 1.3.1", + "reqwest 0.12.28", + "rustc_version", + "rustls 0.23.35", + "rustls-pki-types", + "serde", + "serde_json", + "thiserror 2.0.17", + "time", + "tokio", +] + +[[package]] +name = "google-cloud-gax" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b218292363f2e2d6ab8d6da4118acf91cc044439c442d2d6809b581e0728b377" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures", + "google-cloud-rpc", + "google-cloud-wkt", + "http 1.3.1", + "pin-project", + "rand 0.9.2", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", +] + +[[package]] +name = "google-cloud-gax-internal" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78125fa0347492177131d30c010e57ddce9bba1504c33be135f5853a9105c277" +dependencies = [ + "bytes", + "futures", + "google-cloud-auth 1.4.0", + "google-cloud-gax", + "google-cloud-rpc", + "google-cloud-wkt", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "opentelemetry-semantic-conventions", + "percent-encoding", + "pin-project", + "prost 0.14.3", + "prost-types", + "reqwest 0.12.28", + "rustc_version", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", + "tokio-stream", + "tonic 0.14.2", + "tonic-prost", + "tower", + "tracing", +] + +[[package]] +name = "google-cloud-iam-v1" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498a68e2a958e8aa9938f7db2c7147aad1b5a0ff2cd47c5ba4e10cb0dcb5bfc5" +dependencies = [ + "async-trait", + "bytes", + "google-cloud-gax", + "google-cloud-gax-internal", + "google-cloud-type", + "google-cloud-wkt", + "lazy_static", + "reqwest 0.12.28", + "serde", + "serde_json", + "serde_with", + "tracing", +] + +[[package]] +name = "google-cloud-longrunning" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c80938e704401a47fdf36b51ec10e1a99b1ec22793d607afd0e67c7b675b8b3" +dependencies = [ + "async-trait", + "bytes", + "google-cloud-gax", + "google-cloud-gax-internal", + "google-cloud-rpc", + "google-cloud-wkt", + "lazy_static", + "reqwest 0.12.28", + "serde", + "serde_json", + "serde_with", + "tracing", +] + +[[package]] +name = "google-cloud-lro" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49747b7b684b804a2d1040c2cdb21238b3d568a41ab9e36c423554509112f61d" +dependencies = [ + "google-cloud-gax", + "google-cloud-longrunning", + "google-cloud-rpc", + "google-cloud-wkt", + "serde", + "tokio", +] + [[package]] name = "google-cloud-metadata" version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d901aeb453fd80e51d64df4ee005014f6cf39f2d736dd64f7239c132d9d39a6a" dependencies = [ - "reqwest 0.12.24", + "reqwest 0.12.28", "thiserror 1.0.69", "tokio", ] +[[package]] +name = "google-cloud-rpc" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd10e97751ca894f9dad6be69fcef1cb72f5bc187329e0254817778fc8235030" +dependencies = [ + "bytes", + "google-cloud-wkt", + "serde", + "serde_json", + "serde_with", +] + [[package]] name = "google-cloud-storage" -version = "0.24.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34a73d9e94d35665909050f02e035d8bdc82e419241b1b027ebf1ea51dc8a470" +checksum = "6abde5d51a4728f47b8f7781d7bf86ab51e310b42ec7c7c96578f1d03da938e4" 
dependencies = [ - "anyhow", - "async-stream", "async-trait", - "base64 0.21.7", + "base64 0.22.1", "bytes", - "futures-util", - "google-cloud-auth", - "google-cloud-metadata", - "google-cloud-token", + "chrono", + "crc32c", + "futures", + "google-cloud-auth 1.4.0", + "google-cloud-gax", + "google-cloud-gax-internal", + "google-cloud-iam-v1", + "google-cloud-longrunning", + "google-cloud-lro", + "google-cloud-rpc", + "google-cloud-type", + "google-cloud-wkt", "hex", - "once_cell", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.7.0", + "lazy_static", + "md5", + "mime", "percent-encoding", - "pkcs8 0.10.2", - "regex", - "reqwest 0.12.24", - "reqwest-middleware 0.4.2", - "ring", + "pin-project", + "prost 0.14.3", + "prost-types", + "reqwest 0.12.28", "serde", "serde_json", + "serde_with", "sha2 0.10.9", - "thiserror 1.0.69", - "time", + "thiserror 2.0.17", "tokio", + "tokio-stream", + "tonic 0.14.2", "tracing", "url", + "uuid", ] [[package]] @@ -3370,6 +3496,35 @@ dependencies = [ "async-trait", ] +[[package]] +name = "google-cloud-type" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9390ac2f3f9882ff42956b25ea65b9f546c8dd44c131726d75a96bf744ec75f6" +dependencies = [ + "bytes", + "google-cloud-wkt", + "serde", + "serde_json", + "serde_with", +] + +[[package]] +name = "google-cloud-wkt" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6f270e404be7ce76a3260abe0c3c71492ab2599ccd877f3253f3dd552f48cc9" +dependencies = [ + "base64 0.22.1", + "bytes", + "serde", + "serde_json", + "serde_with", + "thiserror 2.0.17", + "time", + "url", +] + [[package]] name = "governor" version = "0.6.3" @@ -3646,7 +3801,7 @@ dependencies = [ "num_cpus", "rand 0.8.5", "regex", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "sha1 0.10.6", @@ -3943,6 +4098,19 @@ dependencies = [ "webpki-roots 1.0.3", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper 1.7.0", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -4389,10 +4557,10 @@ dependencies = [ "netwatch", "pin-project", "pkarr", - "pkcs8 0.11.0-rc.7", + "pkcs8", "portmapper", "rand 0.9.2", - "reqwest 0.12.24", + "reqwest 0.12.28", "rustls 0.23.35", "rustls-pki-types", "rustls-platform-verifier 0.5.3", @@ -4649,7 +4817,7 @@ dependencies = [ "pkarr", "postcard", "rand 0.9.2", - "reqwest 0.12.24", + "reqwest 0.12.28", "rustls 0.23.35", "rustls-pki-types", "serde", @@ -5274,6 +5442,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "md5" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" + [[package]] name = "memchr" version = "2.7.6" @@ -6147,7 +6321,7 @@ dependencies = [ "bytes", "http 1.3.1", "opentelemetry 0.28.0", - "reqwest 0.12.24", + "reqwest 0.12.28", "tracing", ] @@ -6164,8 +6338,8 @@ dependencies = [ "opentelemetry-http", "opentelemetry-proto", "opentelemetry_sdk", - "prost", - "reqwest 0.12.24", + "prost 0.13.5", + "reqwest 0.12.28", "thiserror 2.0.17", "tracing", ] @@ -6178,10 +6352,16 @@ checksum = "56f8870d3024727e99212eb3bb1762ec16e255e3e6f58eeb3dc8db1aa226746d" dependencies = [ "opentelemetry 0.28.0", "opentelemetry_sdk", - "prost", - "tonic", + "prost 0.13.5", + "tonic 
0.12.3", ] +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e62e29dfe041afb8ed2a6c9737ab57db4907285d999ef8ad3a59092a36bdc846" + [[package]] name = "opentelemetry_sdk" version = "0.28.0" @@ -6344,15 +6524,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - [[package]] name = "pem-rfc7468" version = "1.0.0-rc.3" @@ -6438,7 +6609,7 @@ dependencies = [ "log", "lru 0.13.0", "ntimestamp", - "reqwest 0.12.24", + "reqwest 0.12.28", "self_cell", "serde", "sha1_smol", @@ -6450,24 +6621,14 @@ dependencies = [ "wasm-bindgen-futures", ] -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der 0.7.10", - "spki 0.7.3", -] - [[package]] name = "pkcs8" version = "0.11.0-rc.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93eac55f10aceed84769df670ea4a32d2ffad7399400d41ee1c13b1cd8e1b478" dependencies = [ - "der 0.8.0-rc.9", - "spki 0.8.0-rc.4", + "der", + "spki", ] [[package]] @@ -6804,7 +6965,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.13.5", +] + +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive 0.14.3", ] [[package]] @@ -6820,6 +6991,28 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost 0.14.3", +] + [[package]] name = "psyche-centralized-client" version = "0.2.0" @@ -6827,8 +7020,10 @@ dependencies = [ "anyhow", "async-trait", "bytemuck", + "bytes", "clap", "clap-markdown", + "google-cloud-storage", "hex", "hf-hub", "psyche-centralized-shared", @@ -6977,6 +7172,7 @@ name = "psyche-coordinator" version = "0.2.0" dependencies = [ "anchor-lang", + "anyhow", "async-trait", "bytemuck", "cfg_eval", @@ -7015,9 +7211,12 @@ dependencies = [ "anyhow", "async-trait", "bytemuck", + "bytes", "chrono", "clap", "futures", + "google-cloud-auth 0.16.0", + "google-cloud-gax", "google-cloud-storage", "hf-hub", "memmap2 0.9.8", @@ -7032,7 +7231,7 @@ dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "rayon", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "static-web-server", @@ -7044,6 +7243,7 @@ dependencies = [ "tokio-util 0.7.16", "tracing", "ts-rs", + "urlencoding", ] [[package]] @@ -7280,6 +7480,8 @@ dependencies = [ "async-trait", "clap", "clap-markdown", + "google-cloud-storage", + "hf-hub", "psyche-client", "psyche-coordinator", 
"psyche-core", @@ -7957,9 +8159,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.24" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", @@ -8020,21 +8222,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "reqwest-middleware" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57f17d28a6e6acfe1733fe24bcd30774d13bffa4b8a22535b4c8c98423088d4e" -dependencies = [ - "anyhow", - "async-trait", - "http 1.3.1", - "reqwest 0.12.24", - "serde", - "thiserror 1.0.69", - "tower-service", -] - [[package]] name = "resolv-conf" version = "0.7.5" @@ -8237,9 +8424,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.12.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -8572,15 +8759,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -10039,7 +10226,7 @@ dependencies = [ "indicatif", "log", "reqwest 0.11.27", - "reqwest-middleware 0.2.5", + "reqwest-middleware", "semver", "serde", "serde_derive", @@ -10064,7 +10251,7 @@ dependencies = [ "bs58", "jsonrpc-core", "reqwest 0.11.27", - "reqwest-middleware 0.2.5", + "reqwest-middleware", "semver", "serde", "serde_derive", @@ -10910,16 +11097,6 @@ dependencies = [ "lock_api", ] -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der 0.7.10", -] - [[package]] name = "spki" version = "0.8.0-rc.4" @@ -10927,7 +11104,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8baeff88f34ed0691978ec34440140e1572b68c7dd4a495fd14a3dc1944daa80" dependencies = [ "base64ct", - "der 0.8.0-rc.9", + "der", ] [[package]] @@ -12300,13 +12477,51 @@ dependencies = [ "http-body-util", "percent-encoding", "pin-project", - "prost", + "prost 0.13.5", "tokio-stream", "tower-layer", "tower-service", "tracing", ] +[[package]] +name = "tonic" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" +dependencies = [ + "base64 0.22.1", + "bytes", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "rustls-native-certs", + "sync_wrapper 1.0.2", + "tokio", + "tokio-rustls 0.26.4", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-prost" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66bd50ad6ce1252d87ef024b3d64fe4c3cf54a86fb9ef4c631fdd0ded7aeaa67" +dependencies = [ + 
"bytes", + "prost 0.14.3", + "tonic 0.14.2", +] + [[package]] name = "torch-sys" version = "0.22.0" @@ -12326,18 +12541,22 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", + "indexmap 2.11.4", "pin-project-lite", + "slab", "sync_wrapper 1.0.2", "tokio", + "tokio-util 0.7.16", "tower-layer", "tower-service", + "tracing", ] [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags 2.9.4", "bytes", @@ -12365,9 +12584,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -12377,9 +12596,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -12388,9 +12607,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", @@ -12832,7 +13051,7 @@ dependencies = [ "env_logger 0.11.8", "graphql_client", "impl_from_tuple", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "thiserror 1.0.69", @@ -13973,6 +14192,12 @@ dependencies = [ "zstd 0.11.2+zstd.1.5.2", ] +[[package]] +name = "zmij" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" + [[package]] name = "zstd" version = "0.11.2+zstd.1.5.2" diff --git a/Cargo.toml b/Cargo.toml index 7b3968ade..f7a99c2b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ indicatif = "0.17.5" tokenizers = { version = "0.20.0", default-features = false, features = [ "onig", ] } +google-cloud-storage = "1.6.0" tch = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } torch-sys = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } pyo3-tch = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } diff --git a/architectures/centralized/client/Cargo.toml b/architectures/centralized/client/Cargo.toml index fbc6ca39a..7d7f51d80 100644 --- a/architectures/centralized/client/Cargo.toml +++ b/architectures/centralized/client/Cargo.toml @@ -25,6 +25,8 @@ time.workspace = true bytemuck.workspace = true clap-markdown.workspace = true hex = "0.4.3" +bytes.workspace = true +google-cloud-storage.workspace = true psyche-python-extension-impl = { workspace = true, optional = true } [features] diff --git 
a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 8f6a9bcca..cf4f00d69 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -1,13 +1,13 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; -use hf_hub::Repo; +use google_cloud_storage::client::{Storage, StorageControl}; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; -use psyche_client::HubUploadInfo; -use psyche_client::UploadInfo; use psyche_client::{ Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, }; -use psyche_coordinator::{Coordinator, HealthChecks, model}; +use psyche_client::{GcsUploadInfo, HubUploadInfo, UploadInfo}; +use psyche_coordinator::model::Checkpoint; +use psyche_coordinator::{Coordinator, HealthChecks}; use psyche_metrics::ClientMetrics; use psyche_network::{ AuthenticatableIdentity, EndpointId, NetworkTUIState, NetworkTui, SecretKey, TcpClient, @@ -31,7 +31,7 @@ pub type TabsData = ::Data; pub enum ToSend { Witness(Box), HealthCheck(HealthChecks), - Checkpoint(model::Checkpoint), + Checkpoint(Checkpoint), } struct Backend { @@ -69,7 +69,7 @@ impl WatcherBackend for Backend { Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> { + async fn send_checkpoint(&mut self, checkpoint: Checkpoint) -> Result<()> { self.tx.send(ToSend::Checkpoint(checkpoint))?; Ok(()) } @@ -84,6 +84,7 @@ pub struct App { server_conn: TcpClient, metrics: Arc, + skip_upload_check: bool, } pub async fn build_app( @@ -91,6 +92,7 @@ pub async fn build_app( server_addr: String, tx_tui_state: Option>, p: TrainArgs, + is_test: bool, ) -> Result<( App, allowlist::AllowDynamic, @@ -162,6 +164,7 @@ pub async fn build_app( server_conn, run_id: p.run_id, metrics, + skip_upload_check: is_test, }; Ok((app, allowlist, p2p, state_options)) } @@ -173,22 +176,59 @@ impl App { p2p: NC, state_options: RunInitConfig, ) -> Result<()> { - // sanity checks - if let Some(checkpoint_config) = &state_options.checkpoint_config { - if let Some(UploadInfo::Hub(HubUploadInfo { - hub_repo, - hub_token, - })) = &checkpoint_config.upload_info - { - let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(hub_token.clone())) - .build()?; - let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model)); - if !repo_api.is_writable().await { + // Sanity checks using the checkpoint config from state_options, not the zeroed coordinator state. + // The coordinator_state is only populated after receiving the first ServerToClientMessage::Coordinator. 
+ if !self.skip_upload_check { + let upload_info = match &state_options.checkpoint_config { + config if config.skip_upload => Some(UploadInfo::Dummy()), + config => { + // Use HF_TOKEN from checkpoint_config for Hub uploads + if let Some(ref hub_token) = config.hub_token { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: String::new(), // Will be validated when actual checkpoint is received + hub_token: hub_token.clone(), + })) + } else { + // Check if GCS credentials are available by attempting to create a client + match Storage::builder().build().await { + Ok(_) => Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: String::new(), // Will be validated when actual checkpoint is received + gcs_prefix: None, + })), + Err(_) => None, + } + } + } + }; + + match upload_info { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: _, + hub_token, + })) => { + let _api = hf_hub::api::tokio::ApiBuilder::new() + .with_token(Some(hub_token.clone())) + .build()?; + } + Some(UploadInfo::Gcs(_gcs_info)) => { + let _storage = Storage::builder() + .build() + .await + .map_err(|e| anyhow::anyhow!("Failed to create GCS client: {}", e))?; + + let _storage_control = + StorageControl::builder().build().await.map_err(|e| { + anyhow::anyhow!("Failed to create GCS control client: {}", e) + })?; + // GCS credentials are valid - actual bucket writability will be checked during checkpoint + } + Some(UploadInfo::Dummy()) => { + // In test mode or skip_upload mode, we skip upload checks + } + None => { anyhow::bail!( - "Checkpoint upload repo {} is not writable with the passed API key.", - hub_repo - ) + "No upload credentials found for checkpointing. Set HF_TOKEN for HuggingFace Hub or configure GCS credentials." + ); } } } diff --git a/architectures/centralized/client/src/main.rs b/architectures/centralized/client/src/main.rs index fad5d3817..4a0a7d213 100644 --- a/architectures/centralized/client/src/main.rs +++ b/architectures/centralized/client/src/main.rs @@ -105,7 +105,7 @@ async fn async_main() -> Result<()> { )?; let (mut app, allowlist, p2p, state_options) = - build_app(cancel, server_addr, tx_tui_state, args) + build_app(cancel, server_addr, tx_tui_state, args, false) .await .unwrap(); diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 402146817..e203189b1 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -1,7 +1,7 @@ use anyhow::{Result, anyhow, bail}; use async_trait::async_trait; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; -use psyche_coordinator::model::{self, Checkpoint, LLM, LLMTrainingDataLocation, Model}; +use psyche_coordinator::model::{Checkpoint, LLM, LLMTrainingDataLocation, Model}; use psyche_coordinator::{ Client, ClientState, Coordinator, CoordinatorError, HealthChecks, Round, RunState, SOLANA_MAX_NUM_CLIENTS, TickResult, @@ -81,7 +81,7 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend { bail!("Server does not send health checks"); } - async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: Checkpoint) -> Result<()> { bail!("Server does not send checkpoints"); } } @@ -402,6 +402,9 @@ impl App { Self::get_timestamp(), rand::rng().next_u64(), ), + OpportunisticData::CooldownStep(witness) => { + self.coordinator.cooldown_witness(&from, witness) + } } { warn!("Error when processing witness: {error}"); }; diff --git 
a/architectures/centralized/testing/src/client.rs b/architectures/centralized/testing/src/client.rs index 41c156e6c..7f0350c90 100644 --- a/architectures/centralized/testing/src/client.rs +++ b/architectures/centralized/testing/src/client.rs @@ -33,6 +33,7 @@ impl Client { client_app_params.server_addr, None, client_app_params.train_args, + true, ) .await .unwrap(); @@ -57,6 +58,7 @@ impl Client { client_app_params.server_addr, None, client_app_params.train_args, + true, ) .await .unwrap(); diff --git a/architectures/centralized/testing/src/lib.rs b/architectures/centralized/testing/src/lib.rs index 3ff0d7a21..58b9429a4 100644 --- a/architectures/centralized/testing/src/lib.rs +++ b/architectures/centralized/testing/src/lib.rs @@ -6,4 +6,4 @@ pub mod test_utils; pub const WARMUP_TIME: u64 = 60; pub const MAX_ROUND_TRAIN_TIME: u64 = 5; pub const ROUND_WITNESS_TIME: u64 = 2; -pub const COOLDOWN_TIME: u64 = 3; +pub const COOLDOWN_TIME: u64 = 5; diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index 1ec700b4e..ae835da36 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -127,6 +127,7 @@ pub fn dummy_client_app_params_with_training_delay( run_id: &str, training_delay_secs: u64, ) -> AppParams { + std::env::set_var("HF_TOKEN", "dummy_token"); AppParams { cancel: CancellationToken::default(), server_addr: format!("localhost:{server_port}").to_string(), @@ -141,6 +142,7 @@ pub fn dummy_client_app_params_with_training_delay( "--max-concurrent-parameter-requests", "10", "--hub-max-concurrent-downloads", "1", "--dummy-training-delay-secs", training_delay_secs.to_string().as_str(), + "--skip-checkpoint-upload", ]) .train_args, } diff --git a/architectures/decentralized/justfile b/architectures/decentralized/justfile index 77ba7d4ed..ad0b89a14 100644 --- a/architectures/decentralized/justfile +++ b/architectures/decentralized/justfile @@ -5,6 +5,8 @@ set working-directory := '../../' # In case a recipe is not found here, it will fallback to the root justfile. 
AUTHORIZER := env_var_or_default("AUTHORIZER", "11111111111111111111111111111111") +HF_TOKEN := env_var_or_default("HF_TOKEN", "") +GOOGLE_APPLICATION_CREDENTIALS := env_var_or_default("GOOGLE_APPLICATION_CREDENTIALS", "") set fallback := true @@ -37,10 +39,16 @@ setup-solana-localnet-permissioned-light-test-run-treasurer run_id="test" *args= RUN_ID={{ run_id }} CONFIG_FILE=./config/solana-test/light-config.toml ./scripts/deploy-solana-test.sh --treasurer {{ args }} start-training-localnet-client run_id="test" *args='': - AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="false" RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} start-training-localnet-light-client run_id="test" *args='': - AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="false" RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} + +start-training-localnet-light-client-checkpoint run_id="test" *args='': + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="true" RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} + +start-training-localnet-client-checkpoint run_id="test" *args='': + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="true" RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} OTLP_METRICS_URL := "http://localhost:4318/v1/metrics" OTLP_LOGS_URL := "http://localhost:4318/v1/logs" diff --git a/architectures/decentralized/solana-client/Cargo.toml b/architectures/decentralized/solana-client/Cargo.toml index 752e3af4f..46a0d2d1e 100644 --- a/architectures/decentralized/solana-client/Cargo.toml +++ b/architectures/decentralized/solana-client/Cargo.toml @@ -29,6 +29,8 @@ time.workspace = true tokio.workspace = true tokio-util.workspace = true tracing.workspace = true +google-cloud-storage.workspace = true +hf-hub.workspace = true psyche-python-extension-impl = { workspace = true, optional = true } [features] diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 36a529bbb..e4e384286 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -1,4 +1,6 @@ use crate::network_identity::NetworkIdentity; +use google_cloud_storage::client::StorageControl; +use hf_hub::Repo; use psyche_solana_rpc::SolanaBackend; use anchor_client::{ @@ -11,9 +13,13 @@ use anchor_client::{ }; use anyhow::{Result, anyhow}; use psyche_client::{ - Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, + Client, ClientTUI, ClientTUIState, GcsUploadInfo, HubUploadInfo, NC, RunInitConfig, TrainArgs, + UploadInfo, read_identity_secret_key, +}; +use psyche_coordinator::{ + ClientState, Coordinator, CoordinatorError, RunState, + model::{self, GcsRepo, HubRepo, LLM, Model}, }; -use psyche_coordinator::{ClientState, Coordinator, CoordinatorError, RunState}; use psyche_core::sha256; use psyche_metrics::ClientMetrics; @@ -226,6 +232,77 @@ impl App { let mut joined_run_this_epoch = None; 
let mut ever_joined_run = false; + // sanity checks + let Model::LLM(LLM { checkpoint, .. }) = start_coordinator_state.model; + let upload_info = match checkpoint { + model::Checkpoint::Hub(HubRepo { repo_id, revision }) + | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (&repo_id).into(), + hub_token: (&revision.unwrap_or_default()).into(), + })) + } + model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (&bucket).into(), + gcs_prefix: Some((&prefix.unwrap_or_default()).into()), + })) + } + _ => None, + }; + match upload_info { + Some(UploadInfo::Hub(hub_info)) => { + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_token(Some(hub_info.hub_token)) + .build()?; + let repo_api = api.repo(Repo::new( + hub_info.hub_repo.clone(), + hf_hub::RepoType::Model, + )); + if !repo_api.is_writable().await { + anyhow::bail!( + "Checkpoint upload repo {} is not writable with the passed API key.", + hub_info.hub_repo + ) + } + } + Some(UploadInfo::Gcs(gcs_info)) => { + let client = StorageControl::builder().build().await?; + + let permissions_to_test = vec![ + "storage.objects.list", + "storage.objects.get", + "storage.objects.create", + "storage.objects.delete", + ]; + + let resource = format!("projects/_/buckets/{}", gcs_info.gcs_bucket); + let perms_vec: Vec = + permissions_to_test.iter().map(|s| s.to_string()).collect(); + let response = client + .test_iam_permissions() + .set_resource(&resource) + .set_permissions(perms_vec) + .send() + .await?; + + let correct_permissions = permissions_to_test + .into_iter() + .all(|p| response.permissions.contains(&p.to_string())); + if !correct_permissions { + anyhow::bail!( + "GCS bucket {} does not have the required permissions for checkpoint upload make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly and have the correct permissions to the bucket.", + gcs_info.gcs_bucket + ) + } + } + Some(UploadInfo::Dummy()) => { + // In test mode, we skip upload checks + } + None => {} + } + // if we're already in "WaitingForMembers" we won't get an update saying that // (subscription is on change), so check if it's in that state right at boot // and join the run if so diff --git a/architectures/decentralized/solana-common/src/backend.rs b/architectures/decentralized/solana-common/src/backend.rs index 47a2b7dc9..d77854aeb 100644 --- a/architectures/decentralized/solana-common/src/backend.rs +++ b/architectures/decentralized/solana-common/src/backend.rs @@ -19,7 +19,7 @@ use anchor_client::{ }; use anyhow::{Context, Result, anyhow}; use futures_util::StreamExt; -use psyche_coordinator::model::{self, Checkpoint}; +use psyche_coordinator::model::Checkpoint; use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks}; use psyche_core::IntegrationTestLogMarker; use psyche_watcher::{Backend as WatcherBackend, OpportunisticData}; @@ -309,6 +309,12 @@ impl SolanaBackend { &user, witness, ), + OpportunisticData::CooldownStep(witness) => instructions::coordinator_cooldown_witness( + &coordinator_instance, + &coordinator_account, + &user, + witness, + ), }; self.spawn_scheduled_send("Witness", &[instruction], &[]); } @@ -605,7 +611,7 @@ impl WatcherBackend for SolanaBackendRunner Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> { + async fn send_checkpoint(&mut self, checkpoint: Checkpoint) -> Result<()> { self.backend 
.send_checkpoint(self.instance, self.account, checkpoint); Ok(()) diff --git a/architectures/decentralized/solana-common/src/instructions.rs b/architectures/decentralized/solana-common/src/instructions.rs index b535d54f4..0e8be8d85 100644 --- a/architectures/decentralized/solana-common/src/instructions.rs +++ b/architectures/decentralized/solana-common/src/instructions.rs @@ -179,6 +179,28 @@ pub fn coordinator_warmup_witness( ) } +pub fn coordinator_cooldown_witness( + coordinator_instance: &Pubkey, + coordinator_account: &Pubkey, + user: &Pubkey, + witness: psyche_coordinator::Witness, +) -> Instruction { + anchor_instruction( + psyche_solana_coordinator::ID, + psyche_solana_coordinator::accounts::PermissionlessCoordinatorAccounts { + user: *user, + coordinator_instance: *coordinator_instance, + coordinator_account: *coordinator_account, + }, + psyche_solana_coordinator::instruction::CooldownWitness { + proof: witness.proof, + participant_bloom: witness.participant_bloom, + broadcast_bloom: witness.broadcast_bloom, + broadcast_merkle: witness.broadcast_merkle, + }, + ) +} + pub fn coordinator_health_check( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, diff --git a/architectures/decentralized/solana-coordinator/Cargo.lock b/architectures/decentralized/solana-coordinator/Cargo.lock index 72a38df53..ed360578c 100644 --- a/architectures/decentralized/solana-coordinator/Cargo.lock +++ b/architectures/decentralized/solana-coordinator/Cargo.lock @@ -1605,6 +1605,7 @@ name = "psyche-coordinator" version = "0.2.0" dependencies = [ "anchor-lang", + "anyhow", "async-trait", "bytemuck", "cfg_eval", diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 8d9ddb977..79eadddef 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -233,6 +233,20 @@ impl CoordinatorInstanceState { self.tick() } + pub fn cooldown_witness( + &mut self, + payer: &Pubkey, + witness: Witness, + ) -> Result<()> { + let id = self.clients_state.find_signer(payer)?; + + self.coordinator + .cooldown_witness(id, witness) + .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; + + self.tick() + } + pub fn warmup_witness( &mut self, payer: &Pubkey, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 0a041a6e9..e7fd2bee9 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -268,7 +268,6 @@ pub mod psyche_solana_coordinator { ) } - #[allow(unused_variables)] // for the metadata field. adding a _ prefix results in anchor's IDL not matching the actual types. lol. 
pub fn warmup_witness( ctx: Context, proof: WitnessProof, @@ -289,6 +288,26 @@ pub mod psyche_solana_coordinator { ) } + pub fn cooldown_witness( + ctx: Context, + proof: WitnessProof, + participant_bloom: WitnessBloom, + broadcast_bloom: WitnessBloom, + broadcast_merkle: MerkleRoot, + ) -> Result<()> { + let mut account = ctx.accounts.coordinator_account.load_mut()?; + account.increment_nonce(); + account.state.cooldown_witness( + ctx.accounts.user.key, + Witness { + proof, + participant_bloom, + broadcast_bloom, + broadcast_merkle, + }, + ) + } + pub fn health_check( ctx: Context, id: ClientId, diff --git a/architectures/decentralized/solana-treasurer/Cargo.lock b/architectures/decentralized/solana-treasurer/Cargo.lock index 5d56eb74d..06fc99a96 100644 --- a/architectures/decentralized/solana-treasurer/Cargo.lock +++ b/architectures/decentralized/solana-treasurer/Cargo.lock @@ -1605,6 +1605,7 @@ name = "psyche-coordinator" version = "0.1.0" dependencies = [ "anchor-lang", + "anyhow", "async-trait", "bytemuck", "cfg_eval", diff --git a/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml b/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml index 71ab21789..1b4f00c8f 100644 --- a/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml +++ b/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml @@ -24,7 +24,7 @@ max_seq_len = 2048 cold_start_warmup_steps = 0 [model.LLM.data_location] Server = "127.0.0.1:20001" -[model.LLM.checkpoint.Hub] +[model.LLM.checkpoint.Dummy] repo_id = "emozilla/llama2-20m-init" [model.LLM.lr_schedule.Cosine] base_lr = 4.0e-4 diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index c275feea3..0daeab279 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -1,7 +1,7 @@ [config] warmup_time = 50 -cooldown_time = 30 epoch_time = 60 +cooldown_time = 30 max_round_train_time = 15 round_witness_time = 1 min_clients = 1 diff --git a/psyche-book/src/enduser/join-run.md b/psyche-book/src/enduser/join-run.md index b79678bcd..f56df1be7 100644 --- a/psyche-book/src/enduser/join-run.md +++ b/psyche-book/src/enduser/join-run.md @@ -67,6 +67,11 @@ WS_RPC=wss://your-primary-rpc-provider.com # Required: Which run id to join RUN_ID=your_run_id_here +# Required: access token to write model states in the storage. +# Depending on the config this will be a HuggingFace token or Google cloud storage access file path, only one of them should be needed +HF_TOKEN=HuggingFace_write_token_for_repo +GOOGLE_APPLICATION_CREDENTIALS=path_to_credentials_file_with_access_to_bucket + # Recommended: Fallback RPC Endpoints (for reliability) RPC_2=https://your-backup-rpc-provider.com WS_RPC_2=wss://your-backup-rpc-provider.com diff --git a/psyche-book/src/enduser/run-config.md b/psyche-book/src/enduser/run-config.md index 8440dd0a7..f5c2f6635 100644 --- a/psyche-book/src/enduser/run-config.md +++ b/psyche-book/src/enduser/run-config.md @@ -16,7 +16,7 @@ Here's a sample config with some of its options documented. # maximum time, in seconds, to let nodes download the model from a checkpoint / other nodes warmup_time = 30 -# time, in seconds, to let nodes bring the model from the GPU to disk, and to opt to join the next round. +# time, in seconds, to let nodes bring the model from the GPU to disk, upload the model to the remote storage and to opt to join the next round. cooldown_time = 30 # time, in seconds, that an epoch will last. 
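The upload that happens during that cooldown window uses the credentials documented in join-run.md above (`HF_TOKEN` or `GOOGLE_APPLICATION_CREDENTIALS`). As a minimal sketch only, assuming the `UploadInfo`, `HubUploadInfo`, and `GcsUploadInfo` types exported by `psyche_client` and the `google-cloud-storage` client builder already used in this change, here is one way a client could pick an upload backend from those variables; the helper name and placeholder arguments are illustrative, not the actual implementation:

```rust
use google_cloud_storage::client::Storage;
use psyche_client::{GcsUploadInfo, HubUploadInfo, UploadInfo};

// Illustrative only: prefer HuggingFace when HF_TOKEN is set, otherwise fall back
// to GCS if application-default credentials (GOOGLE_APPLICATION_CREDENTIALS) work.
async fn pick_upload_backend(hub_repo: String, gcs_bucket: String) -> Option<UploadInfo> {
    if let Ok(hub_token) = std::env::var("HF_TOKEN") {
        return Some(UploadInfo::Hub(HubUploadInfo { hub_repo, hub_token }));
    }
    match Storage::builder().build().await {
        // Building a client succeeds only if usable GCS credentials are present.
        Ok(_client) => Some(UploadInfo::Gcs(GcsUploadInfo {
            gcs_bucket,
            gcs_prefix: None,
        })),
        Err(_) => None,
    }
}
```

In this PR the decentralized client derives which backend to use from the run's checkpoint configuration; these environment variables only supply the credentials.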
@@ -74,6 +74,10 @@ max_seq_len = 2048 # Repo where the model is located in HugggingFace, will be used to download the model at the beginning of training. repo_id = "emozilla/llama2-20m-init" +# Google Cloud Storage is also supported +[model.LLM.checkpoint.Gcs] +bucket = "bucket_name" + [model.LLM.data_location.Http] # Token size in bytes, can be "TwoBytes" or "FourBytes" token_size_in_bytes = "TwoBytes" diff --git a/psyche-book/src/explain/general-workflow.md b/psyche-book/src/explain/general-workflow.md index 4ac8bfa63..d3fc61662 100644 --- a/psyche-book/src/explain/general-workflow.md +++ b/psyche-book/src/explain/general-workflow.md @@ -95,18 +95,27 @@ Any clients that have failed [health checks](#health-checks) will also be remove ### Cooldown phase (state: Cooldown) -The _Cooldown_ phase is the last phase of an epoch, during which the Coordinator waits the _Cooldown_ period to elapse. At this point the clients will begin to do a new checkpoint of the model, this is saving the state of the model at that time to a external storage, such as a Hugging Face. +The **Cooldown** phase is the last phase of an epoch. At this point, clients begin creating a new checkpoint of the model. This means saving the current state of the model to external storage, such as Hugging Face or a bucket in Google Cloud Storage (GCS). -When the _Cooldown_ phase begins, the Coordinator also resets the current model checkpoint state to `Checkpoint::P2P`, indicating that new joiners should download the latest copy of the model from the other participants and not from the usual checkpoint. +At the beginning of this state, the run elects a subset of clients that will be designated as **checkpointers**. All clients are potential checkpointers: one third of the total clients in the run will be elected pseudo-randomly at this stage. If a client is elected, it will start uploading the model state to the storage declared in the run configuration by the run owner. -Upon exiting the _Cooldown_ phase, the Coordinator transitions to the next epoch, saving the previous epoch state, and moving back to the _WaitingForMembers_ phase. All the clients that were participating in the previous epoch automatically join to the new epoch unless they exit manually. +The client that finishes uploading the model sends a transaction to the coordinator, called the **opportunistic cooldown**, indicating that the entire model was uploaded successfully. + +There are two ways the coordinator can transition from this state to the next one: + +- As soon as the first opportunistic cooldown transaction arrives, the coordinator moves to the next state and cancels all upload tasks from the remaining clients, since it already knows that at least one checkpointer has uploaded the complete model correctly. +- If no transaction is received, there is a maximum cooldown time defined in the run configuration. If this time is reached, the coordinator will move to the next state even if no new checkpoint was produced. + +When the _Cooldown_ phase begins, the coordinator also resets the current model checkpoint state to `Checkpoint::P2P`, indicating that new joiners should download the latest copy of the model from other participants rather than from the usual checkpoint storage. + +Upon exiting the _Cooldown_ phase, the coordinator transitions to the next epoch, saving the previous epoch state and moving back to the _WaitingForMembers_ phase. All clients that participated in the previous epoch automatically join the new epoch unless they exit manually. 
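The election routine itself is not shown in this diff. Purely to make the "one third, pseudo-random" idea above concrete, here is a hypothetical sketch in which every client derives the same checkpointer set from a seed it can already compute from shared coordinator state; the seed source, function name, and index-based client identifiers are all assumptions:

```rust
use rand::{seq::index::sample, SeedableRng};
use rand_chacha::ChaCha8Rng;

/// Hypothetical sketch: deterministically elect one third of the clients as
/// checkpointers. Seeding the RNG from a value all clients share means every
/// client computes the same set without extra coordination.
fn elect_checkpointers(num_clients: usize, shared_seed: u64) -> Vec<usize> {
    if num_clients == 0 {
        return Vec::new();
    }
    let count = (num_clients / 3).max(1);
    let mut rng = ChaCha8Rng::seed_from_u64(shared_seed);
    sample(&mut rng, num_clients, count).into_vec()
}
```

In the flow described above, whichever elected checkpointer finishes its upload first sends the opportunistic cooldown transaction that ends the phase.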
### It all comes together Here's an overview of how the state of the run can change depending on the situation: ```mermaid -%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'35px'}}}%% +%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'45px'}}}%% flowchart LR WFM((Waiting For Members)) W((Warmup)) @@ -119,6 +128,8 @@ flowchart LR d{Witness quorum reached} e{Max training time passed} f{End of the epoch reached} + g{Client checkpoints} + h{Max cooldown time passed} WFM --> a a -->|Yes| W @@ -135,7 +146,10 @@ flowchart LR WI --> f f -->|Yes| CD f -->|No| T - CD --> WFM + CD --> g + g -->|Yes| WFM + g -->|No| h + h -->|Yes| WFM ``` And this is how it fits with the real clients and how they interact in each of the stages. The committee in this case is the structure that contains all the witness data for the round. diff --git a/psyche-book/src/explain/index.md b/psyche-book/src/explain/index.md index 16d6eb6aa..a2e6baa27 100644 --- a/psyche-book/src/explain/index.md +++ b/psyche-book/src/explain/index.md @@ -62,7 +62,7 @@ These three phases constitute a **round** of training and will be looping until At the start of an **epoch**, all clients have a window of time to join the run by requesting to be added by coordinator, and then connecting to the other participating clients. This state will be known as the _Waiting for Members_ phase. -Once a minimum threshold of clients has been met, the run will transition to the _Warmup_ phase and begin a countdown to allow connected clients to update their copy of the model. To obtain a copy of the model, the Coordinator will either direct clients to a checkpoint uploaded somewhere like HuggingFace and they will have to download it from there or direct clients to [download the model from other clients](./model-sharing.md) via the p2p network. In the first epoch, all clients will download the model from HuggingFace and after that every new epoch, clients will download the model from other clients via the p2p network. +Once a minimum threshold of clients has been met, the run will transition to the _Warmup_ phase and begin a countdown to allow connected clients to update their copy of the model. To obtain a copy of the model, the Coordinator will either direct clients to a checkpoint uploaded somewhere like HuggingFace or Google Cloud Storage, from which they will download it, or direct clients to [download the model from other clients](./model-sharing.md) via the p2p network. In the first epoch, all clients will download the model from the external storage, and in every subsequent epoch, clients will download the model from other clients via the p2p network. After the _Warmup_ phase ends, it will enter the _Training_ phase. @@ -84,7 +84,7 @@ At the start of each round, one or more clients are randomly selected as witness These bloom filters are sent to the coordinator, which then combines them into a provable consensus of which results to apply to the model. 
+Once a witness quorum is reached, the coordinator advances to the _Training_ phase to allow all clients a brief window to download every training result of the previous round, clients are assigned new data, and the process repeats. After a fixed amount of time, a _Cooldown_ round occurs, marking the end of an **epoch**. At this stage, one third of the clients are randomly selected as checkpointers, and all of them start uploading the state of the model to external storage. There is a maximum amount of time to remain in this state, which is configurable in the run creation process that we'll explore in the other sections. ## The witness/train loop visualized diff --git a/psyche-book/src/explain/model-sharing.md b/psyche-book/src/explain/model-sharing.md index ba61ebb12..5bc0bb8cd 100644 --- a/psyche-book/src/explain/model-sharing.md +++ b/psyche-book/src/explain/model-sharing.md @@ -6,15 +6,19 @@ At the beginning of a run, all clients must download the model parameters, token Each client will then modify their copy of the model by receiving new training results from other clients and applying them. This keeps everyone's copy of model identical within an **epoch** without an additional full synchronization step. -When a new client joins a run that has already progressed past its first epoch, it would not be correct for the client to download the original model from HuggingFace, as the model parameters would have already been updated during training. Instead, the new client must acquire a copy of the model from the peers who have been actively training it. +When a new client joins a run that has already progressed past its first epoch, it would not be correct for the client to download the original model from the external storage, as the model parameters would have already been updated during training. Instead, the new client must acquire a copy of the model from the peers who have been actively training it. This synchronization process occurs during the _Warmup_ phase, while the coordinator waits to begin the next _Training_ phase. -To address this, we **checkpoint** the model at the end of an **epoch**, where clients save and share the entire model for new peers to join. There are two checkpointing variants: HuggingFace based and P2P based. +To address this, we **checkpoint** the model at the end of an **epoch**, where clients save and share the entire model for new peers to join. There are three checkpointing variants: HuggingFace based, Google Cloud Storage based, and P2P based. ## HuggingFace checkpoint -In this approach, a client or a set of clients can optionally run as **checkpointers** if they declare a checkpoint URL when joining the run. These clients upload their copy of updated model to HuggingFace after each epoch, and send the URL for this checkpoint to the coordinator. When a new client joins the run, it retrieves the checkpoint URL from the coordinator, and connects to HuggingFace to download the latest copy of the model parameters and configuration files. +In this approach, a client or a set of clients will be elected randomly as **checkpointers**. These clients upload their copy of the updated model to HuggingFace during the Cooldown state at the end of the epoch. The model will be uploaded to the HuggingFace repository that is declared in the run configuration by the run owner. When a new client joins the run, it connects to HuggingFace to download the latest copy of the model parameters and configuration files.
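For reference, this change adds an early client-side sanity check of that repository; a trimmed sketch of it, assuming the `hf-hub` build used by this workspace (which exposes an async `is_writable` on the repo handle):

```rust
use hf_hub::{api::tokio::ApiBuilder, Repo, RepoType};

// Condensed from the check added in the Solana client's app.rs: fail fast if the
// configured checkpoint repo cannot be written with the provided token.
async fn check_hub_writable(hub_repo: &str, hub_token: &str) -> anyhow::Result<()> {
    let api = ApiBuilder::new()
        .with_token(Some(hub_token.to_string()))
        .build()?;
    let repo_api = api.repo(Repo::new(hub_repo.to_string(), RepoType::Model));
    if !repo_api.is_writable().await {
        anyhow::bail!("Checkpoint upload repo {hub_repo} is not writable with the passed API key.");
    }
    Ok(())
}
```

The GCS variant performs the analogous probe via the `StorageControl` client's `test_iam_permissions`, visible earlier in this diff.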
+ +## Google Cloud Storage checkpoint + +Very similar to the previous approach, but based on a Google Cloud Storage bucket. Every elected checkpointer will upload the model at the end of an epoch. The bucket name is declared by the run owner in the initial configuration. When a client joins the run, it connects to GCS and downloads the model parameters and configuration files. ## P2P checkpoint diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh index 55699e591..41c1654be 100755 --- a/scripts/train-solana-test.sh +++ b/scripts/train-solana-test.sh @@ -20,15 +20,23 @@ elif [[ -z "${WALLET_FILE:-}" ]]; then trap "echo 'Cleaning up ephemeral wallet file...'; rm -f '${WALLET_FILE}'" EXIT fi -RPC=${RPC:-"http://127.0.0.1:8899"} -WS_RPC=${WS_RPC:-"ws://127.0.0.1:8900"} +RPC=${RPC:-"http://localhost:8899"} +WS_RPC=${WS_RPC:-"ws://localhost:8900"} RUN_ID=${RUN_ID:-"test"} AUTHORIZER=${AUTHORIZER:-"11111111111111111111111111111111"} +if [[ "$CHECKPOINT" == true ]]; then + echo -e "\n[+] Starting Solana training with checkpointing enabled..." +else + echo -e "\n[+] Starting Solana training without checkpointing..." +fi + # presets for a DGX or an HGX DP=${DP:-"8"} TP=${TP:-"1"} BATCH_SIZE=${BATCH_SIZE:-"1"} +HF_TOKEN=${HF_TOKEN:-""} +GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-""} # fine if this fails solana airdrop 10 "$(solana-keygen pubkey ${WALLET_FILE})" --url "${RPC}" || true @@ -36,7 +44,7 @@ solana airdrop 10 "$(solana-keygen pubkey ${WALLET_FILE})" --url "${RPC}" || tru export RUST_LOG="info,psyche=debug" if [[ "$OTLP_METRICS_URL" == "" ]]; then - cargo run --release --bin psyche-solana-client -- \ + HF_TOKEN=${HF_TOKEN} GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \ train \ --wallet-private-key-path ${WALLET_FILE} \ --rpc ${RPC} \ @@ -47,9 +55,10 @@ if [[ "$OTLP_METRICS_URL" == "" ]]; then --micro-batch-size ${BATCH_SIZE} \ --authorizer ${AUTHORIZER} \ --logs "console" \ + $([[ "$CHECKPOINT" == "true" ]] || echo "--skip-checkpoint-upload") \ "$@" else - cargo run --release --bin psyche-solana-client -- \ + HF_TOKEN=${HF_TOKEN} GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \ train \ --wallet-private-key-path ${WALLET_FILE} \ --rpc ${RPC} \ @@ -62,5 +71,6 @@ else --authorizer ${AUTHORIZER} \ --oltp-metrics-url "http://localhost:4318/v1/metrics" \ --oltp-logs-url "http://localhost:4318/v1/logs" \ + $([[ "$CHECKPOINT" == "true" ]] || echo "--skip-checkpoint-upload") \ "$@" fi diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index a4ef145f0..c79bdd5c6 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -1,14 +1,13 @@ use crate::{CheckpointConfig, WandBInfo}; -use crate::UploadInfo; use anyhow::{Result, anyhow, bail}; use clap::Args; -use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; use psyche_eval::tasktype_from_name; use psyche_modeling::Devices; use psyche_network::{DiscoveryMode, RelayKind, SecretKey}; use psyche_tui::LogOutput; use std::{path::PathBuf, time::Duration}; +use tracing::info; pub fn read_identity_secret_key( identity_secret_key_path: Option<&PathBuf>, @@ -141,20 +140,8 @@ pub struct TrainArgs { pub prompt_task: bool, /// If provided, every model parameters update will be saved in this directory after each epoch. 
- #[clap(long, env)] - pub checkpoint_dir: Option, - - /// Path to the Hugging Face repository containing model data and configuration. - #[clap(long, env)] - pub hub_repo: Option, - - /// Name of the GCS bucket containing model data and configuration. - #[clap(long, env)] - pub gcs_bucket: Option, - - /// Prefix within the GCS bucket for model data and configuration. - #[clap(long, env)] - pub gcs_prefix: Option, + #[clap(long, env, default_value_os_t = default_checkpoint_dir())] + pub checkpoint_dir: PathBuf, #[clap(long, env, default_value_t = 3)] pub hub_max_concurrent_downloads: usize, @@ -204,6 +191,10 @@ pub struct TrainArgs { #[clap(long, default_value_t = 3, env)] pub keep_steps: u32, + + /// Skip saving and uploading checkpoints (for testing). + #[clap(long, default_value_t = false, env, hide = true)] + pub skip_checkpoint_upload: bool, } impl TrainArgs { @@ -232,74 +223,23 @@ impl TrainArgs { Ok(wandb_info) } - pub fn checkpoint_config(&self) -> Result> { - let hub_read_token = std::env::var("HF_TOKEN").ok(); - - if self.hub_repo.is_some() && self.gcs_bucket.is_some() { - bail!("Use either GCS or HF hub for checkpoint uploads, not both."); - } - - let checkpoint_dir = match &self.checkpoint_dir { - Some(dir) => dir, - None => { - if self.hub_repo.is_some() || self.gcs_bucket.is_some() { - bail!( - "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!" - ); - } - return Ok(None); - } - }; - - let upload_info = self.build_upload_info(&hub_read_token)?; + pub fn checkpoint_config(&self) -> Result { + let hub_token = std::env::var("HF_TOKEN").ok(); - if upload_info.is_some() && self.keep_steps == 0 { + if self.keep_steps == 0 { bail!( "keep_steps must be >= 1 for checkpoint uploads (got {})", self.keep_steps ); } - Ok(Some(CheckpointConfig { - checkpoint_dir: checkpoint_dir.clone(), - upload_info, + Ok(CheckpointConfig { + checkpoint_dir: self.checkpoint_dir.clone(), delete_old_steps: self.delete_old_steps, keep_steps: self.keep_steps, - })) - } - - fn build_upload_info(&self, hub_token: &Option) -> Result> { - if let Some(repo) = &self.hub_repo { - return self.build_hub_upload_info(repo, hub_token); - } - - if let Some(bucket) = &self.gcs_bucket { - return self.build_gcs_upload_info(bucket); - } - - Ok(None) - } - - fn build_hub_upload_info( - &self, - repo: &str, - token: &Option, - ) -> Result> { - let token = token.as_ref().ok_or_else(|| { - anyhow::anyhow!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.") - })?; - - Ok(Some(UploadInfo::Hub(HubUploadInfo { - hub_repo: repo.to_string(), - hub_token: token.to_string(), - }))) - } - - fn build_gcs_upload_info(&self, bucket: &str) -> Result> { - Ok(Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: bucket.to_string(), - gcs_prefix: self.gcs_prefix.clone(), - }))) + hub_token, + skip_upload: self.skip_checkpoint_upload, + }) } pub fn eval_tasks(&self) -> Result> { @@ -329,6 +269,13 @@ impl TrainArgs { } } +fn default_checkpoint_dir() -> PathBuf { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); + let final_dir = PathBuf::from(home).join(".cache/psyche/local_checkpoints"); + info!("Default checkpoint directory set to {:?}", final_dir); + final_dir +} + pub fn prepare_environment() { psyche_modeling::set_suggested_env_vars(); diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index 6c01868b3..7b378eab8 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -89,7 +89,6 @@ impl + 'sta // From Run let (tx_witness, mut rx_witness) = 
mpsc::unbounded_channel(); let (tx_health_check, mut rx_health_check) = mpsc::unbounded_channel(); - let (tx_checkpoint, mut rx_checkpoint) = mpsc::unbounded_channel(); let (tx_model, mut rx_model) = mpsc::unbounded_channel(); let (tx_distro_result, mut rx_distro_result) = mpsc::unbounded_channel(); let (tx_request_download, mut rx_request_download) = mpsc::unbounded_channel(); @@ -112,7 +111,6 @@ impl + 'sta metrics: metrics.clone(), tx_witness, tx_health_check, - tx_checkpoint, tx_model, tx_parameters_req, tx_config, @@ -135,7 +133,6 @@ impl + 'sta let mut retry_check_interval = interval(DOWNLOAD_RETRY_CHECK_INTERVAL); let mut opportunistic_witness_interval = interval(OPPROTUNISTIC_WITNESS_INTERVAL); let mut check_connection_interval = interval(CHECK_CONNECTION_INTERVAL); - let mut wait_for_checkpoint = false; let mut last_gossip_connection_time = SystemTime::now(); debug!("Starting client loop"); @@ -143,9 +140,6 @@ impl + 'sta select! { _ = cancel.cancelled() => { info!("Got request to cancel main client loop"); - if run.doing_checkpoint() { - wait_for_checkpoint = true; - } break; } @@ -531,9 +525,6 @@ impl + 'sta Some(health_check) = rx_health_check.recv() => { watcher.backend_mut().send_health_check(health_check).await?; } - Some(checkpoint) = rx_checkpoint.recv() => { - watcher.backend_mut().send_checkpoint(checkpoint).await?; - } Some(model) = rx_model.recv() => { sharable_model.update_parameters(model)?; }, @@ -678,30 +669,6 @@ impl + 'sta let p2p_shutdown = p2p.shutdown(); - if wait_for_checkpoint { - info!("Waiting for all pending checkpoints to finish"); - - // Keep waiting for checkpoints while there are uploads pending - let mut checkpoint_check_interval = interval(Duration::from_secs(10)); - while run.doing_checkpoint() { - tokio::select! 
{ - checkpoint = rx_checkpoint.recv() => { - if let Some(checkpoint) = checkpoint { - info!("Checkpoint upload completed, sending to Solana"); - watcher.backend_mut().send_checkpoint(checkpoint).await?; - } else { - // Channel closed, no more checkpoints coming - break; - } - } - _ = checkpoint_check_interval.tick() => { - } - } - } - - info!("All checkpoints finished, exiting main client loop"); - } - p2p_shutdown .await .map_err(|e| anyhow!("Error shutting down p2p: {}", e)) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index f756343f9..2ab0c79f7 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,10 +1,12 @@ use crate::UploadInfo; use psyche_coordinator::{ - Coordinator, - model::{self}, + CheckpointerSelection, Coordinator, + model::{self, HubRepo, LLM, Model}, }; use psyche_core::NodeIdentity; -use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs, upload_to_hub}; +use psyche_data_provider::{ + GcsManifestMetadata, GcsUploadInfo, HubUploadInfo, UploadError, upload_to_gcs, upload_to_hub, +}; #[cfg(feature = "python")] use psyche_modeling::CausalLM; use psyche_modeling::{ @@ -14,7 +16,10 @@ use std::{ cmp::Reverse, collections::{BinaryHeap, HashMap}, path::PathBuf, - sync::Arc, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, }; use tch::Tensor; use thiserror::Error; @@ -39,12 +44,14 @@ pub enum CooldownError { #[error("error while checkpointing: {0}")] Checkpoint(#[from] CheckpointError), + + #[error("error in cooldown step: {0}")] + CoordinatorError(#[from] psyche_coordinator::CoordinatorError), } pub struct CooldownStepMetadata { - tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, - checkpoint_info: Option, + checkpoint_info: CheckpointConfig, checkpoint_extra_files: Vec, model_task_runner: ModelTaskRunner, @@ -59,14 +66,12 @@ pub struct CooldownStepMetadata { impl CooldownStepMetadata { pub fn new( - tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, - checkpoint_info: Option, + checkpoint_info: CheckpointConfig, checkpoint_extra_files: Vec, model_task_runner: ModelTaskRunner, ) -> Self { Self { - tx_checkpoint, tx_model, checkpoint_info, checkpoint_extra_files, @@ -130,6 +135,7 @@ impl CooldownStepMetadata { &self, mut trainers: Vec, state: &Coordinator, + client_index: u64, ) -> Result { let Some(mut trainer) = trainers.pop() else { return Err(CooldownError::NoTrainers); @@ -140,59 +146,104 @@ impl CooldownStepMetadata { let epoch = state.progress.epoch as u32; let checkpoint_extra_files = self.checkpoint_extra_files.clone(); let checkpoint_info = self.checkpoint_info.clone(); - let tx_checkpoint = self.tx_checkpoint.clone(); + let Model::LLM(LLM { checkpoint, .. 
}) = state.model; let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); + let checkpointer_selection = CheckpointerSelection::from_coordinator(state, 0)?; + let is_checkpointer = checkpointer_selection + .is_checkpointer(client_index, state.epoch_state.clients.len() as u64); + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let checkpoint_completed = Arc::new(AtomicBool::new(false)); + + let checkpointing_and_evals: JoinHandle> = + tokio::task::spawn({ + let cancellation_token = cancellation_token.clone(); + let checkpoint_completed = checkpoint_completed.clone(); + async move { + info!("Extracting full model..."); + let (variables, trainer) = + tokio::task::spawn_blocking::<_, Result<_, CheckpointError>>(|| { + let variables = trainer.extract()?; + info!("Model extracted; {} parameters", variables.len()); + Ok((variables, trainer)) + }) + .await + .map_err(|_| CheckpointError::ExtractThreadCrashed)??; + + let variables_clone: HashMap = variables + .iter() + .map(|(name, var)| (name.clone(), var.shallow_clone())) + .collect(); + + // for p2p model sharing we use the native trainer shape + tx_model + .send(variables_clone) + .map_err(|_| CheckpointError::SendCheckpoint)?; + + // convert from internal shape to serialized shape (e.g. torchtitan to hf) + let (variables, trainer) = match trainer { + #[cfg(feature = "python")] + Trainer::PythonDistributed(_) => { + info!("Converting distributed trainer variables for checkpointing..."); + tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) + .await + .map_err(|_| CheckpointError::ExtractThreadCrashed)? + } + _ => (variables, trainer), + }; + + trainers.push(trainer); + let evals = model_task_runner.start(trainers); + if !is_checkpointer { + info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); + return Ok(evals); + } - let checkpointing_and_evals: CheckpointAndEvalsHandle = tokio::task::spawn( - async move { - info!("Extracting full model..."); - let (variables, trainer) = - tokio::task::spawn_blocking::<_, Result<_, CheckpointError>>(|| { - let variables = trainer.extract()?; - info!("Model extracted; {} parameters", variables.len()); - Ok((variables, trainer)) - }) - .await - .map_err(|_| CheckpointError::ExtractThreadCrashed)??; - - let variables_clone: HashMap = variables - .iter() - .map(|(name, var)| (name.clone(), var.shallow_clone())) - .collect(); - - // for p2p model sharing we use the native trainer shape - tx_model - .send(variables_clone) - .map_err(|_| CheckpointError::SendCheckpoint)?; - - // convert from internal shape to serialized shape (e.g. torchtitan to hf) - let (variables, trainer) = match trainer { - #[cfg(feature = "python")] - Trainer::PythonDistributed(_) => { - info!("Converting distributed trainer variables for checkpointing..."); - tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) - .await - .map_err(|_| CheckpointError::ExtractThreadCrashed)? 
+ let CheckpointConfig { + checkpoint_dir, + delete_old_steps, + keep_steps, + hub_token, + skip_upload, + } = checkpoint_info; + + // When skip_upload is true (testing), skip all checkpoint saving + if skip_upload { + info!("Skipping checkpoint save and upload (skip_upload flag is set)"); + checkpoint_completed.store(true, Ordering::SeqCst); + return Ok(evals); } - _ => (variables, trainer), - }; - - trainers.push(trainer); - let evals = model_task_runner.start(trainers); - - let Some(CheckpointConfig { - upload_info, - checkpoint_dir, - delete_old_steps, - keep_steps, - }) = checkpoint_info - else { - return Ok((evals, None)); - }; - - let upload_handle = tokio::task::spawn(async move { + + let upload_info = match checkpoint { + model::Checkpoint::Hub(HubRepo { + repo_id, + revision: _, + }) + | model::Checkpoint::P2P(HubRepo { + repo_id, + revision: _, + }) => { + if let Some(token) = hub_token { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (&repo_id).into(), + hub_token: token, + })) + } else { + warn!("HF_TOKEN env not provided, skipping upload to HuggingFace Hub"); + None + } + } + model::Checkpoint::Gcs(model::GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (&bucket).into(), + gcs_prefix: prefix.as_ref().map(|p| p.into()), + })) + } + _ => None, + }; + let path = checkpoint_dir.join(format!("{run_id}-step{step}")); let local = save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; @@ -202,14 +253,16 @@ impl CooldownStepMetadata { epoch, run_id: run_id.clone(), }; - upload_checkpoint( - upload_info, - manifest_metadata, - local.clone(), - step as u64, - tx_checkpoint, - ) - .await?; + let result = upload_checkpoint(upload_info, manifest_metadata, local.clone(), step as u64, cancellation_token.clone()) + .await; + if let Err(err) = result { + error!("Error uploading checkpoint: {}", err); + } else { + checkpoint_completed.store(true, Ordering::SeqCst); + } + } else { + // No upload configured, but local save succeeded + checkpoint_completed.store(true, Ordering::SeqCst); } cleanup_dirs( @@ -222,16 +275,15 @@ impl CooldownStepMetadata { ) .await; - Ok(()) - }); - - Ok((evals, Some(upload_handle))) - } - .instrument(info_span!("checkpointing")), - ); + Ok(evals) + } + .instrument(info_span!("checkpointing")) + }); Ok(CooldownStep { checkpointing_and_evals, + cancellation_token, + checkpoint_completed, }) } } @@ -265,50 +317,50 @@ async fn upload_checkpoint( manifest_metadata: GcsManifestMetadata, local: Vec, step: u64, - tx_checkpoint: mpsc::UnboundedSender, + cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), CheckpointError> { match upload_info { UploadInfo::Gcs(gcs_info) => { - upload_to_gcs(gcs_info, manifest_metadata, local, step, tx_checkpoint) + upload_to_gcs(gcs_info, manifest_metadata, local, step, cancellation_token) .await .map_err(CheckpointError::UploadError) } - UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, tx_checkpoint) + UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, cancellation_token) .await .map_err(CheckpointError::UploadError), + UploadInfo::Dummy() => { + info!("Dummy upload info provided; skipping upload"); + Ok(()) + } } } -type CheckpointAndEvalsHandle = JoinHandle< - Result< - ( - RunningEvals, - Option>>, - ), - CheckpointError, - >, ->; - #[derive(Debug)] pub struct CooldownStep { - checkpointing_and_evals: CheckpointAndEvalsHandle, + checkpointing_and_evals: JoinHandle>, + 
cancellation_token: tokio_util::sync::CancellationToken, + checkpoint_completed: Arc, } impl CooldownStep { - pub async fn finish( - self, - ) -> Result< - ( - RunningEvals, - Option>>, - ), - CooldownError, - > { - let (running_evals, upload_handle) = self + pub async fn finish(self) -> Result { + let running_evals = self .checkpointing_and_evals .await .map_err(|_| CooldownError::CheckpointThreadCrashed)??; - Ok((running_evals, upload_handle)) + Ok(running_evals) + } + + pub fn cancel(&self) { + self.cancellation_token.cancel(); + } + + pub fn is_finished(&self) -> bool { + self.checkpointing_and_evals.is_finished() + } + + pub fn checkpoint_complete(&self) -> bool { + self.checkpoint_completed.load(Ordering::SeqCst) } } diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 7a945f74b..d2c3a042e 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -69,7 +69,7 @@ pub struct RunInitConfig { pub write_gradients_dir: Option, // checkpointing - pub checkpoint_config: Option, + pub checkpoint_config: CheckpointConfig, // configurable dummy training time (in seconds) for this client - relevant just for testing pub dummy_training_delay_secs: Option, @@ -155,7 +155,6 @@ pub struct RunInitConfigAndIO { pub tx_health_check: UnboundedSender>, pub tx_witness: UnboundedSender, - pub tx_checkpoint: UnboundedSender, pub tx_model: UnboundedSender>, pub tx_parameters_req: UnboundedSender<(Vec, OneshotModelParameterSender)>, pub tx_config: UnboundedSender<(String, String)>, @@ -177,7 +176,6 @@ impl RunInitConfigAndIO RunInitConfigAndIO::RepoFiles(repo_files), + PretrainedSource::::RepoFiles(repo_files.to_vec()), tokenizer, checkpoint_extra_files, ) @@ -864,7 +862,6 @@ impl RunInitConfigAndIO, sent_warmup_finished: bool, sent_warmup_witness: bool, + sent_cooldown_witness: bool, coordinator_state: Coordinator, - - // Handles for HuggingFace uploads running in background - pending_upload_handles: - Vec>>, } #[derive(Error, Debug)] @@ -165,14 +163,19 @@ impl StepStateMachine Result<(), OpportunisticWitnessError> { - if let Some(committee_info) = &self.current_round.committee_info { + if self.current_round.committee_info.is_some() + && !matches!( + self.coordinator_state.run_state, + RunState::Warmup | RunState::Cooldown + ) + { // trace!("Checking for opprotunistic witness with committee info"); + let committee_info = self.current_round.committee_info.as_ref().unwrap(); if let ActiveStep::Training(step) = &self.active_step { let all_prev_round_batches_are_trained = self .previous_round @@ -338,6 +341,60 @@ impl StepStateMachine StepStateMachine StepStateMachine { let trainers = witnessing.finish().await?.stop_evals().await?; - // check here - self.cleanup_completed_uploads(); - ActiveStep::Cooldown(self.cooldown.start(trainers, &state)?) + ActiveStep::Cooldown(self.cooldown.start(trainers, &state, client_index)?) } // cooldown is done, we consider waiting for members and warmup to be basically the same (ActiveStep::Cooldown(cooldown), RunState::WaitingForMembers) | (ActiveStep::Cooldown(cooldown), RunState::Warmup) | (ActiveStep::Cooldown(cooldown), RunState::Paused) => { - let (trainers, upload_handle) = cooldown.finish().await?; - if let Some(handle) = upload_handle { - self.pending_upload_handles.push(handle); - } + // If we reach state it means at least one of the clients has successfully uploaded the model checkpoint. + // We can cancel any of the other uploads in progress. 
+ cooldown.cancel(); + + let trainers = cooldown.finish().await?; + self.sent_cooldown_witness = false; ActiveStep::Warmup(self.warmup.start( trainers, &mut self.previous_round, @@ -882,11 +936,6 @@ impl StepStateMachine RunManager { } Ok(()) } - - pub fn doing_checkpoint(&self) -> bool { - match &self.0 { - InitStage::Running(step_state_machine) => { - let has_pending_uploads = step_state_machine - .pending_upload_handles - .iter() - .any(|handle| !handle.is_finished()); - - has_pending_uploads - } - _ => false, - } - } } impl From<&RunManager> diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 29734f1a0..085211cb7 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -13,14 +13,29 @@ use tokio::task::JoinHandle; pub enum UploadInfo { Hub(HubUploadInfo), Gcs(GcsUploadInfo), + Dummy(), } #[derive(Debug, Clone)] pub struct CheckpointConfig { - pub upload_info: Option, pub checkpoint_dir: PathBuf, pub delete_old_steps: bool, pub keep_steps: u32, + pub hub_token: Option, + /// Skip saving and uploading checkpoints (for testing). + pub skip_upload: bool, +} + +impl CheckpointConfig { + pub fn dummy() -> Self { + Self { + checkpoint_dir: PathBuf::from("./checkpoints"), + delete_old_steps: false, + keep_steps: 1, + hub_token: None, + skip_upload: false, + } + } } #[derive(Debug)] diff --git a/shared/coordinator/Cargo.toml b/shared/coordinator/Cargo.toml index f7cdecc81..024696555 100644 --- a/shared/coordinator/Cargo.toml +++ b/shared/coordinator/Cargo.toml @@ -9,6 +9,7 @@ async-trait.workspace = true anchor-lang.workspace = true bytemuck.workspace = true serde_with.workspace = true +anyhow.workspace = true serde.workspace = true cfg_eval = "0.1.2" ts-rs.workspace = true diff --git a/shared/coordinator/src/checkpointer.rs b/shared/coordinator/src/checkpointer.rs new file mode 100644 index 000000000..9cbd83383 --- /dev/null +++ b/shared/coordinator/src/checkpointer.rs @@ -0,0 +1,63 @@ +use std::cmp::max; + +use crate::{Coordinator, CoordinatorError, coordinator::SOLANA_MAX_NUM_CHECKPOINTERS}; +use psyche_core::{NodeIdentity, compute_shuffled_index, sha256, sha256v}; + +use super::types::salts; + +#[derive(Clone)] +pub struct CheckpointerSelection { + checkpointers: u64, + seed: [u8; 32], +} + +impl CheckpointerSelection { + pub fn new(checkpointers: u64, seed: [u8; 32]) -> Self { + Self { + checkpointers, + seed, + } + } + + pub fn from_coordinator( + coordinator: &Coordinator, + offset: isize, + ) -> Result { + let round = get_round_by_offset(coordinator, offset)?; + let seed = sha256(&round.random_seed.to_le_bytes()); + + let checkpointers = max( + (coordinator.epoch_state.clients.len() / 3).min(SOLANA_MAX_NUM_CHECKPOINTERS), + 1, + ) as u64; + Ok(Self { + checkpointers, + seed, + }) + } + + pub fn is_checkpointer(&self, client_index: u64, total_clients: u64) -> bool { + let final_seed = compute_salted_seed(&self.seed, salts::COOLDOWN); + let index = compute_shuffled_index(client_index, total_clients, &final_seed); + index < self.checkpointers + } +} + +pub(crate) fn get_round_by_offset( + coordinator: &Coordinator, + offset: isize, +) -> Result<&crate::Round, CoordinatorError> { + match offset { + -2 => coordinator.previous_previous_round(), + -1 => coordinator.previous_round(), + 0 => coordinator.current_round(), + _ => return Err(CoordinatorError::NoActiveRound), + } + .ok_or(CoordinatorError::NoActiveRound) +} + +pub(crate) fn compute_salted_seed(seed: &[u8; 32], salt: &str) -> [u8; 32] { + let mut result = [0u8; 32]; + 
result.copy_from_slice(&sha256v(&[&sha256(seed), salt.as_bytes()])); + result +} diff --git a/shared/coordinator/src/committee.rs b/shared/coordinator/src/committee.rs new file mode 100644 index 000000000..e7e94dd84 --- /dev/null +++ b/shared/coordinator/src/committee.rs @@ -0,0 +1,169 @@ +use crate::{Client, Coordinator, CoordinatorError, SOLANA_MAX_NUM_WITNESSES}; +use psyche_core::{NodeIdentity, compute_shuffled_index, sha256}; + +use super::checkpointer::get_round_by_offset; +use super::types::{Committee, CommitteeProof, WitnessProof, salts}; + +#[derive(Clone)] +pub struct CommitteeSelection { + pub(crate) tie_breaker_nodes: u64, + pub(crate) verifier_nodes: u64, + pub(crate) total_nodes: u64, + pub(crate) witness_nodes: u64, + pub(crate) seed: [u8; 32], +} + +impl CommitteeSelection { + pub fn new( + tie_breaker_nodes: usize, + witness_nodes: usize, + verification_percent: u8, + total_nodes: usize, + seed: u64, + ) -> Result { + Self::validate_params( + tie_breaker_nodes, + witness_nodes, + verification_percent, + total_nodes, + )?; + + let free_nodes = total_nodes - tie_breaker_nodes; + let verifier_nodes = (free_nodes * verification_percent as usize) / 100; + let seed = sha256(&seed.to_le_bytes()); + + Ok(Self { + tie_breaker_nodes: tie_breaker_nodes as u64, + verifier_nodes: verifier_nodes as u64, + total_nodes: total_nodes as u64, + witness_nodes: witness_nodes as u64, + seed, + }) + } + + fn validate_params( + tie_breaker_nodes: usize, + witness_nodes: usize, + verification_percent: u8, + total_nodes: usize, + ) -> Result<(), CoordinatorError> { + if total_nodes >= u64::MAX as usize { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + if total_nodes < tie_breaker_nodes { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + if witness_nodes != 0 && total_nodes < witness_nodes { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + if verification_percent > 100 { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + Ok(()) + } + + pub fn from_coordinator( + coordinator: &Coordinator, + offset: isize, + ) -> Result { + let round = get_round_by_offset(coordinator, offset)?; + Self::new( + round.tie_breaker_tasks as usize, + coordinator.config.witness_nodes as usize, + coordinator.config.verification_percent, + round.clients_len as usize, + round.random_seed, + ) + } + + pub fn get_witness(&self, index: u64) -> WitnessProof { + let position = self.compute_shuffled_index(index, salts::WITNESS); + let witness = self.is_witness_at_position(position); + WitnessProof { + witness: witness.into(), + position, + index, + } + } + + pub fn verify_witness(&self, proof: &WitnessProof) -> bool { + let position = self.compute_shuffled_index(proof.index, salts::WITNESS); + proof.position == position && proof.witness == self.is_witness_at_position(position).into() + } + + pub fn verify_witness_for_client( + &self, + client_id: &T, + proof: &WitnessProof, + clients: &[Client], + ) -> bool { + Self::verify_client(client_id, proof.index, clients) && self.verify_witness(proof) + } + + fn is_witness_at_position(&self, position: u64) -> bool { + match self.witness_nodes { + 0 => position < SOLANA_MAX_NUM_WITNESSES as u64, + witness_nodes => position < witness_nodes, + } + } + + pub fn get_committee(&self, index: u64) -> CommitteeProof { + let position = self.compute_shuffled_index(index, salts::COMMITTEE); + let committee = self.get_committee_from_position(position); + CommitteeProof { + committee, + position, + index, + } + } + + pub fn 
get_committee_from_position(&self, position: u64) -> Committee { + if position < self.tie_breaker_nodes { + Committee::TieBreaker + } else if position < self.tie_breaker_nodes + self.verifier_nodes { + Committee::Verifier + } else { + Committee::Trainer + } + } + + pub fn verify_committee(&self, proof: &CommitteeProof) -> bool { + let position = self.compute_shuffled_index(proof.index, salts::COMMITTEE); + proof.position == position && proof.committee == self.get_committee_from_position(position) + } + + pub fn verify_committee_for_client( + &self, + client_id: &T, + proof: &CommitteeProof, + clients: &[Client], + ) -> bool { + Self::verify_client(client_id, proof.index, clients) && self.verify_committee(proof) + } + + fn verify_client(client_id: &T, index: u64, clients: &[Client]) -> bool { + clients.get(index as usize).map(|c| &c.id) == Some(client_id) + } + + fn compute_shuffled_index(&self, index: u64, salt: &str) -> u64 { + let mut seed = [0u8; 32]; + seed.copy_from_slice(&psyche_core::sha256v(&[&self.seed, salt.as_bytes()])); + compute_shuffled_index(index, self.total_nodes, &seed) + } + + pub fn get_seed(&self) -> [u8; 32] { + self.seed + } + + pub fn get_num_tie_breaker_nodes(&self) -> u64 { + self.tie_breaker_nodes + } + + pub fn get_num_verifier_nodes(&self) -> u64 { + self.verifier_nodes + } + + pub fn get_num_trainer_nodes(&self) -> u64 { + self.total_nodes - self.tie_breaker_nodes - self.verifier_nodes + } +} diff --git a/shared/coordinator/src/committee_selection.rs b/shared/coordinator/src/committee_selection.rs deleted file mode 100644 index 1edc3c0f5..000000000 --- a/shared/coordinator/src/committee_selection.rs +++ /dev/null @@ -1,424 +0,0 @@ -use crate::{Client, Coordinator, CoordinatorError, SOLANA_MAX_NUM_WITNESSES}; - -use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; -use bytemuck::Zeroable; -use psyche_core::{NodeIdentity, SmallBoolean, compute_shuffled_index, sha256, sha256v}; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; - -pub const COMMITTEE_SALT: &str = "committee"; -pub const WITNESS_SALT: &str = "witness"; - -#[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, -)] -#[repr(C)] -pub enum Committee { - #[default] - TieBreaker, - Verifier, - Trainer, -} - -#[derive(Clone)] -pub struct CommitteeSelection { - tie_breaker_nodes: u64, - verifier_nodes: u64, - total_nodes: u64, - witness_nodes: u64, - seed: [u8; 32], -} - -#[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, -)] -#[repr(C)] -pub struct CommitteeProof { - pub committee: Committee, - pub position: u64, - pub index: u64, -} - -#[derive( - Clone, - Copy, - Debug, - PartialEq, - Zeroable, - Default, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - TS, -)] -#[repr(C)] -pub struct WitnessProof { - // position in virtual shuffle, as determined by seed - pub position: u64, - // index into epoch_state.clients of sender - pub index: u64, - // assertion of witness membership or non-membership - pub witness: SmallBoolean, -} - -impl CommitteeSelection { - pub fn new( - tie_breaker_nodes: usize, - witness_nodes: usize, - verification_percent: u8, - total_nodes: usize, - seed: u64, - ) -> Result { - if total_nodes >= u64::MAX as usize { - return Err(CoordinatorError::InvalidCommitteeSelection); - } - - if total_nodes < tie_breaker_nodes { - return 
Err(CoordinatorError::InvalidCommitteeSelection); - } - - if witness_nodes != 0 && total_nodes < witness_nodes { - return Err(CoordinatorError::InvalidCommitteeSelection); - } - - if verification_percent > 100 { - return Err(CoordinatorError::InvalidCommitteeSelection); - } - - let free_nodes = total_nodes - tie_breaker_nodes; - let verifier_nodes = (free_nodes * verification_percent as usize) / 100; - - let seed = sha256(&seed.to_le_bytes()); - - Ok(Self { - tie_breaker_nodes: tie_breaker_nodes as u64, - verifier_nodes: verifier_nodes as u64, - total_nodes: total_nodes as u64, - witness_nodes: witness_nodes as u64, - seed, - }) - } - - pub fn from_coordinator( - coordinator: &Coordinator, - offset: isize, - ) -> Result { - let round = match offset { - -2 => coordinator.previous_previous_round(), - -1 => coordinator.previous_round(), - 0 => coordinator.current_round(), - _ => { - return Err(CoordinatorError::NoActiveRound); - } - } - .ok_or(CoordinatorError::NoActiveRound)?; - Self::new( - round.tie_breaker_tasks as usize, - coordinator.config.witness_nodes as usize, - coordinator.config.verification_percent, - round.clients_len as usize, - round.random_seed, - ) - } - - pub fn get_witness(&self, index: u64) -> WitnessProof { - let position = self.compute_shuffled_index(index, WITNESS_SALT); - let witness = self.get_witness_from_position(position); - WitnessProof { - witness: witness.into(), - position, - index, - } - } - - pub fn get_committee(&self, index: u64) -> CommitteeProof { - let position = self.compute_shuffled_index(index, COMMITTEE_SALT); - let committee = self.get_committee_from_position(position); - CommitteeProof { - committee, - position, - index, - } - } - - pub fn get_committee_from_position(&self, committee_position: u64) -> Committee { - if committee_position < self.tie_breaker_nodes { - Committee::TieBreaker - } else if committee_position < self.tie_breaker_nodes + self.verifier_nodes { - Committee::Verifier - } else { - Committee::Trainer - } - } - - fn get_witness_from_position(&self, witness_position: u64) -> bool { - match self.witness_nodes { - 0 => witness_position < SOLANA_MAX_NUM_WITNESSES as u64, - witness_nodes => witness_position < witness_nodes, - } - } - - pub fn verify_committee_for_client( - &self, - client_id: &T, - proof: &CommitteeProof, - clients: &[Client], - ) -> bool { - Self::verify_client(client_id, proof.index, clients) && self.verify_committee(proof) - } - - pub fn verify_witness_for_client( - &self, - client_id: &T, - proof: &WitnessProof, - clients: &[Client], - ) -> bool { - Self::verify_client(client_id, proof.index, clients) && self.verify_witness(proof) - } - - fn verify_client(client_id: &T, index: u64, clients: &[Client]) -> bool { - clients.get(index as usize).map(|c| &c.id) == Some(client_id) - } - - fn verify_committee(&self, proof: &CommitteeProof) -> bool { - let position = self.compute_shuffled_index(proof.index, COMMITTEE_SALT); - proof.position == position && proof.committee == self.get_committee_from_position(position) - } - - fn verify_witness(&self, proof: &WitnessProof) -> bool { - let position = self.compute_shuffled_index(proof.index, WITNESS_SALT); - proof.position == position - && proof.witness == self.get_witness_from_position(position).into() - } - - fn compute_shuffled_index(&self, index: u64, salt: &str) -> u64 { - let mut seed = [0u8; 32]; - seed.copy_from_slice(&sha256v(&[&self.seed, salt.as_bytes()])); - - compute_shuffled_index(index, self.total_nodes, &seed) - } - - pub fn get_seed(&self) -> [u8; 32] { - 
self.seed - } - - pub fn get_num_tie_breaker_nodes(&self) -> u64 { - self.tie_breaker_nodes - } - - pub fn get_num_verifier_nodes(&self) -> u64 { - self.verifier_nodes - } - - pub fn get_num_trainer_nodes(&self) -> u64 { - self.total_nodes - self.tie_breaker_nodes - self.verifier_nodes - } -} - -impl std::fmt::Display for Committee { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Committee::TieBreaker => write!(f, "Tie breaker"), - Committee::Verifier => write!(f, "Verifier"), - Committee::Trainer => write!(f, "Trainer"), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_new_committee_selection() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - assert_eq!(cs.tie_breaker_nodes, 10); - assert_eq!(cs.witness_nodes, 20); - assert_eq!(cs.verifier_nodes, 27); // (100 - 10) * 30% = 27 - assert_eq!(cs.total_nodes, 100); - } - - #[test] - fn test_get_committee() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - - // Test for all possible indexes - for i in 0..100 { - let proof = cs.get_committee(i); - assert!(proof.position < 100); - - // Verify that the committee matches the position - match proof.committee { - Committee::TieBreaker => assert!(proof.position < 10), - Committee::Verifier => assert!(proof.position >= 10 && proof.position < 37), - Committee::Trainer => assert!(proof.position >= 37), - } - } - } - - #[test] - fn test_get_witness() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - - // Test for all possible indexes - for i in 0..100 { - let proof = cs.get_witness(i); - assert!(proof.position < 100); - - // Verify that the witness status matches the position - if proof.witness.is_true() { - assert!(proof.position < 20); - } else { - assert!(proof.position >= 20); - } - } - } - - #[test] - fn test_verify_committee() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - - for i in 0..100 { - let proof = cs.get_committee(i); - assert!(cs.verify_committee(&proof)); - - // Test with incorrect proof - let incorrect_proof = CommitteeProof { - committee: Committee::Verifier, - position: 99, - index: i, - }; - assert!(!cs.verify_committee(&incorrect_proof)); - } - } - - #[test] - fn test_verify_witness() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - - for i in 0..100 { - let proof = cs.get_witness(i); - assert!(cs.verify_witness(&proof)); - - // Test with incorrect proof - let incorrect_proof = WitnessProof { - witness: !proof.witness, - position: 99, - index: i, - }; - assert!(!cs.verify_witness(&incorrect_proof)); - } - } - - #[test] - fn test_committee_distribution() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - let mut tie_breaker_count = 0; - let mut verifier_count = 0; - let mut trainer_count = 0; - - for i in 0..100 { - match cs.get_committee(i).committee { - Committee::TieBreaker => tie_breaker_count += 1, - Committee::Verifier => verifier_count += 1, - Committee::Trainer => trainer_count += 1, - } - } - - assert_eq!(tie_breaker_count, 10); - assert_eq!(verifier_count, 27); - assert_eq!(trainer_count, 63); - } - - #[test] - fn test_witness_distribution() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - let mut witness_count = 0; - - for i in 0..100 { - if cs.get_witness(i).witness.is_true() { - witness_count += 1; - } - } - - assert_eq!(witness_count, 20); - } - - #[test] - fn test_get_num_nodes() { - let cs = CommitteeSelection::new(10, 5, 20, 100, 
12345).unwrap(); - assert_eq!(cs.get_num_tie_breaker_nodes(), 10); - assert_eq!(cs.get_num_verifier_nodes(), 18); - assert_eq!(cs.get_num_trainer_nodes(), 72); - } - - #[test] - fn test_seed_consistency() { - let cs1 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); - let cs2 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); - assert_eq!(cs1.get_seed(), cs2.get_seed()); - } - - #[test] - fn test_invalid_total_nodes() { - assert!(CommitteeSelection::new(10, 5, 20, 9, 12345).is_err()); - } - - #[test] - fn test_invalid_committee_selections() { - // verification_percent > 100 - assert!(CommitteeSelection::new(10, 5, 101, 100, 12345).is_err()); - // total_nodes < tie_breaker_nodes - assert!(CommitteeSelection::new(10, 5, 101, 5, 12345).is_err()); - // total_nodes < witness_nodes - assert!(CommitteeSelection::new(10, 50, 101, 11, 12345).is_err()); - // total_nodes >= u64::MAX - assert!(CommitteeSelection::new(10, 50, 101, u64::MAX as usize, 12345).is_err()); - } - - #[test] - fn test_edge_case_all_tie_breakers() { - let cs = CommitteeSelection::new(100, 5, 20, 100, 12345).unwrap(); - for i in 0..100 { - let committee = cs.get_committee(i).committee; - assert_eq!(committee, Committee::TieBreaker); - } - } - - #[test] - fn test_edge_case_no_verifiers() { - let cs = CommitteeSelection::new(10, 5, 0, 100, 12345).unwrap(); - let mut tie_breaker_count = 0; - let mut trainer_count = 0; - for i in 0..100 { - let committee = cs.get_committee(i).committee; - match committee { - Committee::TieBreaker => tie_breaker_count += 1, - Committee::Trainer => trainer_count += 1, - _ => panic!("Unexpected committee type"), - } - } - assert_eq!(tie_breaker_count, 10); - assert_eq!(trainer_count, 90); - } -} diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index c726655e2..66fd77beb 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,5 +1,5 @@ use crate::{ - Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, + CheckpointerSelection, Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, model::{Checkpoint, Model}, }; @@ -14,6 +14,7 @@ pub const SOLANA_MAX_STRING_LEN: usize = 64; pub const SOLANA_MAX_URL_STRING_LEN: usize = 192; pub const SOLANA_MAX_NUM_CLIENTS: usize = 256; pub const SOLANA_MAX_NUM_WITNESSES: usize = 32; +pub const SOLANA_MAX_NUM_CHECKPOINTERS: usize = 16; // run_id must be at most 32 bytes because of PDA constraints pub const SOLANA_RUN_ID_MAX_LEN: usize = 32; @@ -281,6 +282,7 @@ pub struct CoordinatorEpochState { pub start_timestamp: u64, pub first_round: SmallBoolean, pub cold_start_epoch: SmallBoolean, + pub checkpointed: bool, } #[derive( @@ -416,6 +418,7 @@ impl Default for CoordinatorEpochState { start_step: Default::default(), last_step: Default::default(), start_timestamp: Default::default(), + checkpointed: false, } } } @@ -508,6 +511,37 @@ impl Coordinator { Ok(()) } + pub fn cooldown_witness( + &mut self, + from: &T, + witness: Witness, + ) -> std::result::Result<(), CoordinatorError> { + if self.halted() { + return Err(CoordinatorError::Halted); + } + + if !matches!(self.run_state, RunState::Cooldown) { + return Ok(()); + } + + // Verify the sender matches the witness index to prevent spoofing + let index = witness.proof.index as usize; + if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { + return Err(CoordinatorError::InvalidWitness); + } + + let checkpointer_selection = 
CheckpointerSelection::from_coordinator(self, 0)?; + if !checkpointer_selection + .is_checkpointer(witness.proof.index, self.epoch_state.clients.len() as u64) + { + return Err(CoordinatorError::InvalidWitness); + } + + self.epoch_state.checkpointed = true; + + Ok(()) + } + pub fn witness( &mut self, from: &T, @@ -604,16 +638,28 @@ impl Coordinator { return Err(CoordinatorError::InvalidCommitteeProof); } - // TODO: In the case of more than one checkpointer, this will overwrite the checkpoint - // with the last checkpointed one. We could instead have a vector of checkpoints to have - // more download options. + if self.halted() { + return Err(CoordinatorError::Halted); + } + + if !matches!(self.run_state, RunState::Cooldown) { + return Err(CoordinatorError::InvalidRunState); + } + + let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; + if !checkpointer_selection + .is_checkpointer(index as u64, self.epoch_state.clients.len() as u64) + { + return Err(CoordinatorError::InvalidWitness); + } + let Model::LLM(llm) = &mut self.model; match (&llm.checkpoint, checkpoint_repo) { // If current is P2P, wrap the new checkpoint in P2P - (Checkpoint::P2P(_), Checkpoint::Hub(hub_repo)) => { + (Checkpoint::P2P(_) | Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => { llm.checkpoint = Checkpoint::P2P(hub_repo); } - (Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => { + (Checkpoint::P2P(_) | Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => { llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); } // If current is Hub, only accept Hub updates @@ -624,16 +670,12 @@ impl Coordinator { (Checkpoint::Gcs(_), Checkpoint::Gcs(gcs_repo)) => { llm.checkpoint = Checkpoint::Gcs(gcs_repo); } - (Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => { - llm.checkpoint = Checkpoint::P2P(hub_repo); - } - (Checkpoint::P2P(_), Checkpoint::Gcs(gcs_repo)) => { - llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); - } // Ignore other combinations _ => {} } + self.epoch_state.checkpointed = true; + Ok(()) } @@ -1052,7 +1094,9 @@ impl Coordinator { &mut self, unix_timestamp: u64, ) -> std::result::Result { - if self.check_timeout(unix_timestamp, self.config.cooldown_time) { + if self.check_timeout(unix_timestamp, self.config.cooldown_time) + || self.epoch_state.checkpointed + { let last_round_batch_size = self.get_target_global_batch_size(self.current_round()); self.progress.epoch_start_data_index = self.current_round_unchecked().data_index + last_round_batch_size as u64; diff --git a/shared/coordinator/src/lib.rs b/shared/coordinator/src/lib.rs index bef26863e..e366ee5cf 100644 --- a/shared/coordinator/src/lib.rs +++ b/shared/coordinator/src/lib.rs @@ -1,15 +1,19 @@ #![allow(unexpected_cfgs)] +mod checkpointer; mod commitment; -mod committee_selection; +mod committee; mod coordinator; mod data_selection; pub mod model; +mod types; +#[cfg(test)] +mod tests; + +pub use checkpointer::CheckpointerSelection; pub use commitment::Commitment; -pub use committee_selection::{ - COMMITTEE_SALT, Committee, CommitteeProof, CommitteeSelection, WITNESS_SALT, WitnessProof, -}; +pub use committee::CommitteeSelection; pub use coordinator::{ BLOOM_FALSE_RATE, Client, ClientState, Coordinator, CoordinatorConfig, CoordinatorEpochState, CoordinatorError, CoordinatorProgress, HealthChecks, MAX_TOKENS_TO_SEND, NUM_STORED_ROUNDS, @@ -20,3 +24,4 @@ pub use coordinator::{ pub use data_selection::{ assign_data_for_state, get_batch_ids_for_node, get_batch_ids_for_round, get_data_index_for_step, }; +pub use types::{Committee, 
CommitteeProof, WitnessProof, salts}; diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index 3176f276e..388fc51c4 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use crate::{SOLANA_MAX_STRING_LEN, coordinator::SOLANA_MAX_URL_STRING_LEN}; use anchor_lang::{ @@ -239,6 +241,17 @@ impl HubRepo { } } +impl FromStr for HubRepo { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + Ok(HubRepo { + repo_id: FixedString::from_str_truncated(s), + revision: None, + }) + } +} + #[derive( Clone, Debug, @@ -280,10 +293,10 @@ impl GcsRepo { #[repr(C)] pub enum Checkpoint { Ephemeral, - Dummy(HubRepo), + Dummy(HubRepo), // Used for testing Hub(HubRepo), - P2P(HubRepo), Gcs(GcsRepo), + P2P(HubRepo), P2PGcs(GcsRepo), } @@ -293,12 +306,16 @@ impl std::fmt::Display for Checkpoint { Checkpoint::Dummy(_hub_repo) => write!(f, "Dummy"), Checkpoint::Ephemeral => write!(f, "Ephemeral"), Checkpoint::Hub(hub_repo) => write!(f, "{}", &hub_repo.repo_id), + Checkpoint::Gcs(gcs_repo) => match &gcs_repo.prefix { + Some(prefix) => write!(f, "gs://{}/{}", &gcs_repo.bucket, prefix), + None => write!(f, "gs://{}", &gcs_repo.bucket), + }, Checkpoint::P2P(hub_repo) => { write!(f, "P2P - Hub repo: {}", &hub_repo.repo_id) } - Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => match &gcs_repo.prefix { - Some(prefix) => write!(f, "gs://{}/{}", &gcs_repo.bucket, prefix), - None => write!(f, "gs://{}", &gcs_repo.bucket), + Checkpoint::P2PGcs(gcs_repo) => match &gcs_repo.prefix { + Some(prefix) => write!(f, "P2P - gs://{}/{}", &gcs_repo.bucket, prefix), + None => write!(f, "P2P - gs://{}", &gcs_repo.bucket), }, } } @@ -338,11 +355,9 @@ impl Model { let bad_checkpoint = match llm.checkpoint { Checkpoint::Dummy(_hub_repo) => false, Checkpoint::Ephemeral => true, + Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => true, // P2P is internal state, not configurable Checkpoint::Hub(hub_repo) => hub_repo.repo_id.is_empty(), - Checkpoint::P2P(hub_repo) => hub_repo.repo_id.is_empty(), - Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => { - gcs_repo.bucket.is_empty() - } + Checkpoint::Gcs(gcs_repo) => gcs_repo.bucket.is_empty(), }; if bad_checkpoint { diff --git a/shared/coordinator/src/tests.rs b/shared/coordinator/src/tests.rs new file mode 100644 index 000000000..7c7e75548 --- /dev/null +++ b/shared/coordinator/src/tests.rs @@ -0,0 +1,174 @@ +use super::*; + +#[test] +fn test_new_committee_selection() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + assert_eq!(cs.tie_breaker_nodes, 10); + assert_eq!(cs.witness_nodes, 20); + assert_eq!(cs.verifier_nodes, 27); // (100 - 10) * 30% = 27 + assert_eq!(cs.total_nodes, 100); +} + +#[test] +fn test_get_committee() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + // Test for all possible indexes + for i in 0..100 { + let proof = cs.get_committee(i); + assert!(proof.position < 100); + + // Verify that the committee matches the position + match proof.committee { + Committee::TieBreaker => assert!(proof.position < 10), + Committee::Verifier => assert!(proof.position >= 10 && proof.position < 37), + Committee::Trainer => assert!(proof.position >= 37), + } + } +} + +#[test] +fn test_get_witness() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + // Test for all possible indexes + for i in 0..100 { + let proof = cs.get_witness(i); + assert!(proof.position < 100); + + // Verify that the witness status 
matches the position + if proof.witness.is_true() { + assert!(proof.position < 20); + } else { + assert!(proof.position >= 20); + } + } +} + +#[test] +fn test_verify_committee() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + for i in 0..100 { + let proof = cs.get_committee(i); + assert!(cs.verify_committee(&proof)); + + // Test with incorrect proof + let incorrect_proof = CommitteeProof { + committee: Committee::Verifier, + position: 99, + index: i, + }; + assert!(!cs.verify_committee(&incorrect_proof)); + } +} + +#[test] +fn test_verify_witness() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + for i in 0..100 { + let proof = cs.get_witness(i); + assert!(cs.verify_witness(&proof)); + + // Test with incorrect proof + let incorrect_proof = WitnessProof { + witness: !proof.witness, + position: 99, + index: i, + }; + assert!(!cs.verify_witness(&incorrect_proof)); + } +} + +#[test] +fn test_committee_distribution() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + let mut tie_breaker_count = 0; + let mut verifier_count = 0; + let mut trainer_count = 0; + + for i in 0..100 { + match cs.get_committee(i).committee { + Committee::TieBreaker => tie_breaker_count += 1, + Committee::Verifier => verifier_count += 1, + Committee::Trainer => trainer_count += 1, + } + } + + assert_eq!(tie_breaker_count, 10); + assert_eq!(verifier_count, 27); + assert_eq!(trainer_count, 63); +} + +#[test] +fn test_witness_distribution() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + let mut witness_count = 0; + + for i in 0..100 { + if cs.get_witness(i).witness.is_true() { + witness_count += 1; + } + } + + assert_eq!(witness_count, 20); +} + +#[test] +fn test_get_num_nodes() { + let cs = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); + assert_eq!(cs.get_num_tie_breaker_nodes(), 10); + assert_eq!(cs.get_num_verifier_nodes(), 18); + assert_eq!(cs.get_num_trainer_nodes(), 72); +} + +#[test] +fn test_seed_consistency() { + let cs1 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); + let cs2 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); + assert_eq!(cs1.get_seed(), cs2.get_seed()); +} + +#[test] +fn test_invalid_total_nodes() { + assert!(CommitteeSelection::new(10, 5, 20, 9, 12345).is_err()); +} + +#[test] +fn test_invalid_comittee_selections() { + // verification_percent > 100 + assert!(CommitteeSelection::new(10, 5, 101, 100, 12345).is_err()); + // total_nodes < tie_breaker_nodes + assert!(CommitteeSelection::new(10, 5, 101, 5, 12345).is_err()); + // total_nodes < witness_nodes + assert!(CommitteeSelection::new(10, 50, 101, 11, 12345).is_err()); + // total_nodes >= u64::MAX + assert!(CommitteeSelection::new(10, 50, 101, u64::MAX as usize, 12345).is_err()); +} + +#[test] +fn test_edge_case_all_tie_breakers() { + let cs = CommitteeSelection::new(100, 5, 20, 100, 12345).unwrap(); + for i in 0..100 { + let committee = cs.get_committee(i).committee; + assert_eq!(committee, Committee::TieBreaker); + } +} + +#[test] +fn test_edge_case_no_verifiers() { + let cs = CommitteeSelection::new(10, 5, 0, 100, 12345).unwrap(); + let mut tie_breaker_count = 0; + let mut trainer_count = 0; + for i in 0..100 { + let committee = cs.get_committee(i).committee; + match committee { + Committee::TieBreaker => tie_breaker_count += 1, + Committee::Trainer => trainer_count += 1, + _ => panic!("Unexpected committee type"), + } + } + assert_eq!(tie_breaker_count, 10); + assert_eq!(trainer_count, 90); +} diff --git 
a/shared/coordinator/src/types.rs b/shared/coordinator/src/types.rs new file mode 100644 index 000000000..ac49815d7 --- /dev/null +++ b/shared/coordinator/src/types.rs @@ -0,0 +1,85 @@ +use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; +use bytemuck::Zeroable; +use psyche_core::SmallBoolean; +use serde::{Deserialize, Serialize}; +use ts_rs::TS; + +/// Salt constants for deterministic shuffling +pub mod salts { + pub const COMMITTEE: &str = "committee"; + pub const WITNESS: &str = "witness"; + pub const COOLDOWN: &str = "cooldown"; +} + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Zeroable, + AnchorDeserialize, + AnchorSerialize, + Serialize, + Deserialize, +)] +#[repr(C)] +pub enum Committee { + #[default] + TieBreaker, + Verifier, + Trainer, +} + +impl std::fmt::Display for Committee { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Committee::TieBreaker => write!(f, "Tie breaker"), + Committee::Verifier => write!(f, "Verifier"), + Committee::Trainer => write!(f, "Trainer"), + } + } +} + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Zeroable, + AnchorDeserialize, + AnchorSerialize, + Serialize, + Deserialize, +)] +#[repr(C)] +pub struct CommitteeProof { + pub committee: Committee, + pub position: u64, + pub index: u64, +} + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Zeroable, + AnchorDeserialize, + AnchorSerialize, + Serialize, + Deserialize, + InitSpace, + TS, +)] +#[repr(C)] +pub struct WitnessProof { + /// Position in virtual shuffle, as determined by seed + pub position: u64, + /// Index into epoch_state.clients of sender + pub index: u64, + /// Assertion of witness membership or non-membership + pub witness: SmallBoolean, +} diff --git a/shared/data-provider/Cargo.toml b/shared/data-provider/Cargo.toml index 7ad69e616..de836ef24 100644 --- a/shared/data-provider/Cargo.toml +++ b/shared/data-provider/Cargo.toml @@ -25,8 +25,12 @@ serde.workspace = true thiserror.workspace = true postcard.workspace = true bytemuck.workspace = true +google-cloud-storage.workspace = true reqwest = "0.12.12" -google-cloud-storage = "0.24.0" +bytes = "1" +google-cloud-auth = "0.16" +google-cloud-gax = "1.4.0" +urlencoding = "2.1.3" chrono = { version = "0.4", features = ["serde"] } serde_json.workspace = true ts-rs.workspace = true diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 67cc9ab9d..2943cde9f 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use bytemuck::Zeroable; use futures::future::try_join_all; use parquet::data_type::AsBytes; -use psyche_coordinator::{Coordinator, HealthChecks, model}; +use psyche_coordinator::{Coordinator, HealthChecks, model::Checkpoint}; use psyche_core::{BatchId, NodeIdentity}; use psyche_data_provider::{ DataProviderTcpClient, DataProviderTcpServer, LengthKnownDataProvider, TokenizedData, @@ -37,7 +37,7 @@ impl WatcherBackend for DummyBackend { bail!("Data provider does not send health check"); } - async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> anyhow::Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: Checkpoint) -> anyhow::Result<()> { bail!("Data provider does not send checkpoints"); } } diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index b84bc5f9a..9de7bb76b 100644 --- a/shared/data-provider/src/errors.rs +++ 
b/shared/data-provider/src/errors.rs @@ -1,3 +1,4 @@ +use hf_hub::api::tokio::CommitError; use std::path::PathBuf; use thiserror::Error; @@ -9,45 +10,36 @@ pub enum UploadError { #[error("file {0} doesn't have a valid utf-8 representation")] InvalidFilename(PathBuf), - #[error("failed to send checkpoint notification")] - SendCheckpoint, - - // Hub-specific errors - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] hf_hub::api::tokio::ApiError), - - #[error("failed to commit files: {0}")] - Commit(#[from] hf_hub::api::tokio::CommitError), - - // GCS-specific errors #[error("GCS authentication failed: {0}")] - GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + GcsAuth(String), - #[error("GCS operation failed: {0}")] - GcsStorage(#[from] google_cloud_storage::http::Error), - - // Common errors #[error("IO error: {0}")] Io(#[from] std::io::Error), + #[error("GCS error: {0}")] + Gcs(String), + + #[error("HuggingFace Hub API error: {0}")] + HubApi(#[from] hf_hub::api::tokio::ApiError), + + #[error("HuggingFace Hub commit error: {0}")] + HubCommit(#[from] CommitError), + #[error("JSON error: {0}")] Json(#[from] serde_json::Error), } #[derive(Error, Debug)] pub enum DownloadError { - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] hf_hub::api::tokio::ApiError), - #[error("GCS authentication failed: {0}")] - GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), - - #[error("GCS operation failed: {0}")] - GcsStorage(#[from] google_cloud_storage::http::Error), + GcsAuth(String), #[error("IO error: {0}")] Io(#[from] std::io::Error), + #[error("GCS error: {0}")] + Gcs(String), + #[error("JSON error: {0}")] Json(#[from] serde_json::Error), } diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index 71f29e414..01e9fec08 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -1,19 +1,11 @@ use crate::errors::{DownloadError, UploadError}; use chrono::{DateTime, Utc}; -use google_cloud_storage::client::{Client, ClientConfig}; -use google_cloud_storage::http::objects::upload::Media; -use google_cloud_storage::http::objects::upload::UploadObjectRequest; -use google_cloud_storage::http::objects::upload::UploadType; -use google_cloud_storage::http::objects::{ - download::Range, get::GetObjectRequest, list::ListObjectsRequest, -}; -use psyche_coordinator::model::{self, GcsRepo}; -use psyche_core::FixedString; +use google_cloud_gax::paginator::ItemPaginator; +use google_cloud_storage::client::{Storage, StorageControl}; use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use tokio::runtime::Runtime; -use tokio::sync::mpsc; -use tracing::info; +use tracing::{debug, info}; /// Checkpoint manifest.json uploaded to GCS alongside safetensors files. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -110,48 +102,43 @@ pub async fn download_model_from_gcs_async( bucket: &str, prefix: Option<&str>, ) -> Result, DownloadError> { - // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous - let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { - info!("Using authenticated GCS client"); - ClientConfig::default().with_auth().await? 
- } else { - info!("Using anonymous GCS client"); - ClientConfig::default().anonymous() - }; - let client = Client::new(config); + // Automatically handles authentication via GOOGLE_APPLICATION_CREDENTIALS + let storage = Storage::builder() + .build() + .await + .map_err(|e| DownloadError::Gcs(e.to_string()))?; + + let storage_control = StorageControl::builder() + .build() + .await + .map_err(|e| DownloadError::Gcs(e.to_string()))?; let manifest_object_path = match prefix { Some(p) => format!("{}/manifest.json", p), None => "manifest.json".to_string(), }; - // Get manifest metadata to obtain generation number - let manifest_metadata = client - .get_object(&GetObjectRequest { - bucket: bucket.to_owned(), - object: manifest_object_path.clone(), - ..Default::default() - }) + // Try to get manifest - first check if it exists + let bucket_resource_name = format!("projects/_/buckets/{}", bucket); + let manifest_result = storage + .read_object(&bucket_resource_name, &manifest_object_path) + .send() .await; - match manifest_metadata { - Ok(object_meta) => { - let manifest_generation = object_meta.generation; - - // Download manifest content - let manifest_data = client - .download_object( - &GetObjectRequest { - bucket: bucket.to_owned(), - object: manifest_object_path, - ..Default::default() - }, - &Range::default(), - ) - .await?; + match manifest_result { + Ok(mut read_response) => { + // Read manifest content + let mut manifest_data = Vec::new(); + while let Some(chunk_result) = read_response.next().await { + let chunk = chunk_result.map_err(|e| DownloadError::Gcs(e.to_string()))?; + manifest_data.extend_from_slice(&chunk); + } let manifest: GcsCheckpointManifest = serde_json::from_slice(&manifest_data)?; + // Use step as generation proxy (1.5.x doesn't expose generation in same way) + let manifest_generation = manifest.metadata.step as i64; + info!( "Found manifest: step {}, epoch {}, generation {}", manifest.metadata.step, manifest.metadata.epoch, manifest_generation @@ -171,12 +158,19 @@ pub async fn download_model_from_gcs_async( cache_dir ); std::fs::create_dir_all(&cache_dir)?; - download_files_from_manifest(&client, bucket, prefix, &cache_dir, &manifest).await? + download_files_from_manifest(&storage, bucket, prefix, &cache_dir, &manifest) + .await? 
}; // Download config files (json, py) - skips if already cached - let config_files = - download_files_no_manifest(&client, bucket, prefix, &cache_dir, &[".json", ".py"]) - .await?; + let config_files = download_files_no_manifest( + &storage_control, + &storage, + bucket, + prefix, + &cache_dir, + &[".json", ".py"], + ) + .await?; files.extend(config_files); Ok(files) } @@ -185,19 +179,28 @@ pub async fn download_model_from_gcs_async( info!("No manifest found, downloading model without manifest"); let cache_dir = get_cache_dir_no_manifest(bucket, prefix); std::fs::create_dir_all(&cache_dir)?; - download_files_no_manifest(&client, bucket, prefix, &cache_dir, &MODEL_EXTENSIONS).await + download_files_no_manifest( + &storage_control, + &storage, + bucket, + prefix, + &cache_dir, + &MODEL_EXTENSIONS, + ) + .await } } } async fn download_files_from_manifest( - client: &Client, + storage: &Storage, bucket: &str, prefix: Option<&str>, cache_dir: &Path, manifest: &GcsCheckpointManifest, ) -> Result<Vec<PathBuf>, DownloadError> { let mut downloaded_files = Vec::new(); + let bucket_resource_name = format!("projects/_/buckets/{}", bucket); for file_entry in &manifest.files { let object_name = match prefix { @@ -217,17 +220,17 @@ async fn download_files_from_manifest( bucket, object_name, file_entry.generation ); - let data = client - .download_object( - &GetObjectRequest { - bucket: bucket.to_owned(), - object: object_name, - generation: Some(file_entry.generation), - ..Default::default() - }, - &Range::default(), - ) - .await?; + let mut read_response = storage + .read_object(&bucket_resource_name, &object_name) + .send() + .await + .map_err(|e| DownloadError::Gcs(e.to_string()))?; + + let mut data = Vec::new(); + while let Some(chunk_result) = read_response.next().await { + let chunk = chunk_result.map_err(|e| DownloadError::Gcs(e.to_string()))?; + data.extend_from_slice(&chunk); + } std::fs::write(&local_path, &data)?; info!("Downloaded: {} ({} bytes)", file_entry.filename, data.len()); @@ -240,34 +243,34 @@ async fn download_files_from_manifest( /// Download model files by listing the bucket. Skips files that already exist in cache. /// Used for initial model download (no manifest) and to fetch config files (json, py) after manifest download. async fn download_files_no_manifest( - client: &Client, + storage_control: &StorageControl, + storage: &Storage, bucket: &str, prefix: Option<&str>, cache_dir: &Path, extensions: &[&str], ) -> Result<Vec<PathBuf>, DownloadError> { let mut all_objects = vec![]; - let mut page_token: Option<String> = None; - - loop { - let results = client - .list_objects(&ListObjectsRequest { - bucket: bucket.to_owned(), - prefix: prefix.map(|s| s.to_owned()), - page_token: page_token.clone(), - ..Default::default() - }) - .await?; - for obj in results.items.iter().flatten() { - if extensions.iter().any(|ext| obj.name.ends_with(ext)) { - all_objects.push(obj.name.clone()); - } - } + let parent_name = format!("projects/_/buckets/{}", bucket); + debug!( + "Listing objects in GCS bucket: {}, parent: {}", + bucket, parent_name + ); + let mut list_request = storage_control.list_objects().set_parent(parent_name); + if let Some(p) = prefix { + list_request = list_request.set_prefix(p.to_string()); + } - match results.next_page_token { - Some(token) => page_token = Some(token), - None => break, + let mut stream = list_request.by_item(); + while let Some(obj) = stream + .next() + .await + .transpose() + .map_err(|e| DownloadError::Gcs(e.to_string()))?
+ { + if extensions.iter().any(|ext| obj.name.ends_with(ext)) { + all_objects.push(obj.name); } } @@ -293,22 +296,21 @@ async fn download_files_no_manifest( info!("Downloading: gs://{}/{}", bucket, object_name); - let data = client - .download_object( - &GetObjectRequest { - bucket: bucket.to_owned(), - object: object_name.clone(), - ..Default::default() - }, - &Range::default(), - ) - .await?; + let bucket_resource_name = format!("projects/_/buckets/{}", bucket); + let mut read_response = storage + .read_object(&bucket_resource_name, &object_name) + .send() + .await + .map_err(|e| DownloadError::Gcs(e.to_string()))?; + + let mut data = Vec::new(); + while let Some(chunk_result) = read_response.next().await { + let chunk = chunk_result.map_err(|e| DownloadError::Gcs(e.to_string()))?; + data.extend_from_slice(&chunk); + } - // Write to cache std::fs::write(&local_path, &data)?; - info!("Downloaded: {} ({} bytes)", filename, data.len()); - downloaded_files.push(local_path); } @@ -328,120 +330,102 @@ pub async fn upload_to_gcs( manifest_metadata: GcsManifestMetadata, local: Vec<PathBuf>, step: u64, - tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>, + cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), UploadError> { - let GcsUploadInfo { - gcs_bucket, - gcs_prefix, - } = gcs_info; - - let GcsManifestMetadata { epoch, run_id } = manifest_metadata; - - info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); - - let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { - info!("Using authenticated GCS client"); - ClientConfig::default().with_auth().await? - } else { - info!("Using anonymous GCS client"); - ClientConfig::default().anonymous() - }; - let client = Client::new(config); + let storage = Storage::builder() + .build() + .await + .map_err(|e| UploadError::Gcs(e.to_string()))?; let mut manifest = GcsCheckpointManifest { metadata: ManifestMetadata { timestamp: Utc::now(), - epoch, + epoch: manifest_metadata.epoch, step: step as u32, - run_id, + run_id: manifest_metadata.run_id, }, files: Vec::new(), }; - for path in local { + for path in local + .iter() + .filter(|p| p.extension() == Some("safetensors".as_ref())) + { + if cancellation_token.is_cancelled() { + info!("Upload cancelled before uploading {}", path.display()); + return Ok(()); + } + let file_name = path .file_name() - .ok_or_else(|| UploadError::NotAFile(path.clone()))? - .to_str() + .and_then(|n| n.to_str()) .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; - - // Only upload safetensors files - if !file_name.ends_with(".safetensors") { - continue; - } - - let object_name = match &gcs_prefix { - Some(p) => format!("{}/{}", p, file_name), + let object_name = match &gcs_info.gcs_prefix { + Some(p) => format!("{}/{}", p.trim_end_matches('/'), file_name), None => file_name.to_string(), }; + let bucket_resource_name = format!("projects/_/buckets/{}", gcs_info.gcs_bucket); - let size = std::fs::metadata(&path)?.len(); - let data = tokio::fs::read(&path).await?; - - let upload_type = UploadType::Simple(Media::new(object_name.clone())); - let uploaded = client - .upload_object( - &UploadObjectRequest { - bucket: gcs_bucket.clone(), - ..Default::default() - }, - data, - &upload_type, - ) - .await?; + let data_vec = tokio::fs::read(&path).await?; + let size = data_vec.len() as u64; + let data = bytes::Bytes::from(data_vec); + + let upload_future = storage + .write_object(&bucket_resource_name, &object_name, data) + .send_unbuffered(); + + let uploaded_file = tokio::select!
{ + biased; + + _ = cancellation_token.cancelled() => { + info!("Upload cancelled during upload of {}", path.display()); + return Ok(()); + } + result = upload_future => { + result.map_err(|e| UploadError::Gcs(e.to_string()))? + } + }; info!( - bucket = gcs_bucket, + bucket = gcs_info.gcs_bucket, object = object_name, - size = uploaded.size, - generation = uploaded.generation, - "Uploaded file to GCS" + size = uploaded_file.size, + "Successfully uploaded file to GCS" ); manifest.files.push(ManifestFileEntry { filename: file_name.to_string(), - generation: uploaded.generation, + generation: uploaded_file.generation, size_bytes: size, }); } // Upload the manifest file - let manifest_path = match &gcs_prefix { + let manifest_path = match &gcs_info.gcs_prefix { Some(p) => format!("{}/manifest.json", p), None => "manifest.json".to_string(), }; let manifest_json = serde_json::to_string_pretty(&manifest)?; + let manifest_bytes = bytes::Bytes::from(manifest_json.into_bytes()); - let upload_type = UploadType::Simple(Media::new(manifest_path.clone())); - client - .upload_object( - &UploadObjectRequest { - bucket: gcs_bucket.clone(), - ..Default::default() - }, - manifest_json.into_bytes(), - &upload_type, - ) - .await?; + let bucket_resource_name = format!("projects/_/buckets/{}", gcs_info.gcs_bucket); + storage + .write_object(&bucket_resource_name, &manifest_path, manifest_bytes) + .send_unbuffered() + .await + .map_err(|e| UploadError::Gcs(e.to_string()))?; info!( - bucket = gcs_bucket, + bucket = gcs_info.gcs_bucket, object = manifest_path, "Uploaded manifest to GCS" ); info!( "Upload to GCS complete at gs://{}/{}", - gcs_bucket, - gcs_prefix.as_deref().unwrap_or("") + gcs_info.gcs_bucket, + gcs_info.gcs_prefix.as_deref().unwrap_or("") ); - tx_checkpoint - .send(model::Checkpoint::Gcs(GcsRepo { - bucket: FixedString::from_str_truncated(&gcs_bucket), - prefix: gcs_prefix.map(|p| FixedString::from_str_truncated(&p)), - })) - .map_err(|_| UploadError::SendCheckpoint)?; - Ok(()) } diff --git a/shared/data-provider/src/http.rs b/shared/data-provider/src/http.rs index 5417f8601..dd13169b4 100644 --- a/shared/data-provider/src/http.rs +++ b/shared/data-provider/src/http.rs @@ -2,7 +2,8 @@ use std::{str::FromStr, time::Duration}; use anyhow::{Context, Result, anyhow, bail}; use futures::future::join_all; -use google_cloud_storage::http::objects::list::ListObjectsRequest; +use google_cloud_gax::paginator::ItemPaginator; +use google_cloud_storage::client::StorageControl; use psyche_coordinator::model::HttpTrainingDataLocation; use psyche_core::{BatchId, Shuffle, TokenSize}; use rand::seq::SliceRandom; @@ -10,7 +11,7 @@ use rand_chacha::ChaCha8Rng; use rand_chacha::rand_core::SeedableRng; use reqwest::IntoUrl; use tokio::task::JoinHandle; -use tracing::{info, trace}; +use tracing::{debug, info, trace}; use crate::{ TokenizedData, @@ -85,50 +86,58 @@ impl FileURLs { Ok(Self(urls_with_sizes)) } - pub async fn from_gcp_bucket(bucket_name: &str, directory: Option<String>) -> Result<Self> { - let config = google_cloud_storage::client::ClientConfig::default().anonymous(); - let client = google_cloud_storage::client::Client::new(config); - let mut data_files_matching_directory = { - let mut all_results = vec![]; - // the outer option is if we should continue looping - // the inner option is if we have a "next page token" - let mut next_page_token: Option<Option<String>> = Some(None); - - while let Some(maybe_next_page_token) = next_page_token { - let this_results = client - .list_objects(&ListObjectsRequest { - bucket:
bucket_name.to_owned(), - prefix: directory.clone(), - page_token: maybe_next_page_token, - ..Default::default() - }) - .await?; - all_results.extend(this_results.items.iter().flatten().filter_map(|obj| { - let file_ext = obj.name.split('.').next_back()?; - if !DATA_FILE_EXTENSIONS.contains(&file_ext) { - return None; - } - - Some( - obj.media_link - .parse::<reqwest::Url>() - .map(|full_url| (full_url, obj.size as u64)) - .map_err(anyhow::Error::from), - ) - })); - - // if we have a token, Some(Some(String)), - // if not, None - next_page_token = this_results.next_page_token.map(Some) - } - all_results + pub async fn from_gcp_bucket( + bucket_name: &str, + directory: Option<String>, + ) -> anyhow::Result<Self> { + debug!( + "http: from_gcp_bucket: bucket_name={}, directory={:?}", + bucket_name, directory + ); + let storage_control = StorageControl::builder().build().await?; + + let mut builder = storage_control + .list_objects() + .set_parent(format!("projects/_/buckets/{}", bucket_name)); + if let Some(p) = directory { + builder = builder.set_prefix(p); } - .into_iter() - .collect::<Result<Vec<_>>>()?; - data_files_matching_directory.sort_by(|a, b| a.0.cmp(&b.0)); + let mut items = builder.by_item(); + let mut all_results = vec![]; + + // transpose turns Option<Result<T>> into Result<Option<T>> so `?` can apply + while let Some(obj) = items.next().await.transpose()? { + // Only process those files with extensions we care about + let file_ext = obj.name.split('.').next_back().unwrap_or(""); + if !DATA_FILE_EXTENSIONS.contains(&file_ext) { + continue; + } + + let full_url = { + // Transforms spaces, etc. into %20 and other url-friendly encodings + let encoded_name = urlencoding::encode(&obj.name); + + // Just in case we have the whole "projects/_/buckets/bucket-name" prefix remove it + let bucket_name_only = obj + .bucket + .strip_prefix("projects/_/buckets/") + .unwrap_or(&obj.bucket); + + format!("https://www.googleapis.com/storage/v1/b/{bucket_name_only}/o/{encoded_name}?alt=media") + .parse::<reqwest::Url>() + .map_err(anyhow::Error::from)?
+ }; + debug!( + "Constructed full url: {:?} for object: {} with size {}", + full_url, obj.name, obj.size + ); + all_results.push((full_url, obj.size as u64)); + } - Ok(Self(data_files_matching_directory)) + // We sort here to return in deterministic order + all_results.sort_by(|a, b| a.0.cmp(&b.0)); + Ok(Self(all_results)) } pub async fn from_location(location: &HttpTrainingDataLocation) -> Result<Self> { diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index 13a575b84..1cc15f2ab 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,16 +1,10 @@ use crate::errors::UploadError; -use crate::hub::model::HubRepo; +use futures::future::try_join_all; use hf_hub::{ Cache, Repo, RepoType, - api::{ - Siblings, - tokio::{ApiError, UploadSource}, - }, + api::{Siblings, tokio::ApiError}, }; -use psyche_coordinator::model; -use psyche_core::FixedString; use std::{path::PathBuf, time::Instant}; -use tokio::sync::mpsc; use tracing::{error, info}; const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -58,26 +52,22 @@ async fn download_repo_async( .collect::<Vec<_>>(); let mut ret: Vec<PathBuf> = Vec::new(); for chunk in siblings.chunks(max_concurrent_downloads.unwrap_or(siblings.len())) { - let futures = chunk - .iter() - .map(|x| async { - let start_time = Instant::now(); - tracing::debug!(filename = x.rfilename, "Starting file download from hub"); - let res = api.get(&x.rfilename).await; - if res.is_ok() { - let duration_secs = (Instant::now() - start_time).as_secs_f32(); - tracing::info!( - filename = x.rfilename, - duration_secs = duration_secs, - "Finished downloading file from hub" - ); - } - res - }) - .collect::<Vec<_>>(); - for future in futures { - ret.push(future.await?); - } + let futures = chunk.iter().map(|x| async { + let start_time = Instant::now(); + tracing::debug!(filename = x.rfilename, "Starting file download from hub"); + let res = api.get(&x.rfilename).await; + if res.is_ok() { + let duration_secs = (Instant::now() - start_time).as_secs_f32(); + tracing::info!( + filename = x.rfilename, + duration_secs = duration_secs, + "Finished downloading file from hub" + ); + } + res + }); + let chunk_results = try_join_all(futures).await?; + ret.extend(chunk_results); } Ok(ret) } @@ -203,63 +193,72 @@ pub async fn upload_to_hub( hub_info: HubUploadInfo, local: Vec<PathBuf>, step: u64, - tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>, + cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), UploadError> { let HubUploadInfo { hub_repo, hub_token, } = hub_info; - info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); + if cancellation_token.is_cancelled() { + return Ok(()); + } + + // Collect all safetensors files to upload in a single commit + let files_to_upload: Vec<_> = local + .iter() + .filter(|p| p.extension() == Some("safetensors".as_ref())) + .map(|path| -> Result<_, UploadError> { + let file_name = path + .file_name() + .ok_or_else(|| UploadError::NotAFile(path.clone()))? + .to_str() + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?
+ .to_string(); + Ok((path.clone().into(), file_name)) + }) + .collect::<Result<Vec<_>, _>>()?; + + if files_to_upload.is_empty() { + info!(repo = hub_repo, "No safetensors files to upload"); + return Ok(()); + } + + let file_names: Vec<_> = files_to_upload + .iter() + .map(|(_, name)| name.clone()) + .collect(); + info!( + repo = hub_repo, + file_count = files_to_upload.len(), + "Uploading checkpoint to HuggingFace" + ); let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(hub_token.clone())) + .with_token(Some(hub_token)) .build()?; let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - let files: Result<Vec<_>, _> = local - .into_iter() - .map(|path| { - path.file_name() - .ok_or(UploadError::NotAFile(path.clone())) - .and_then(|name| { - name.to_str() - .ok_or(UploadError::InvalidFilename(path.clone())) - .map(|s| s.to_string()) - }) - .map(|name| (path.into(), name)) - }) - .collect(); - - let files = files?; - - let commit_info = api_repo - .upload_files(files, Some(format!("step {step}")), None, false) - .await - .map_err(|e| { - error!( - repo = hub_repo, - error = ?e, - "Failed to upload files to HuggingFace" - ); - e - })?; + let upload_future = + api_repo.upload_files(files_to_upload, Some(format!("step {step}")), None, false); - let revision = commit_info.oid; + tokio::select! { + biased; - info!( - repo = hub_repo, - revision = revision, - "Upload to HuggingFace complete" - ); + _ = cancellation_token.cancelled() => { + info!(repo = hub_repo, "Upload to HuggingFace cancelled"); + return Ok(()); + } + result = upload_future => { + result.map_err(|e| { + error!(repo = hub_repo, error = ?e, "Failed to upload files"); + e + })?; + } + } - tx_checkpoint - .send(model::Checkpoint::Hub(HubRepo { - repo_id: FixedString::from_str_truncated(&hub_repo), - revision: Some(FixedString::from_str_truncated(&revision)), - })) - .map_err(|_| UploadError::SendCheckpoint)?; + info!(repo = hub_repo, files = ?file_names, "Upload to HuggingFace complete"); Ok(()) } diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 50bf317bb..f19c6af4d 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use psyche_coordinator::{Coordinator, HealthChecks, Witness, WitnessMetadata, model}; +use psyche_coordinator::{Coordinator, HealthChecks, Witness, WitnessMetadata}; use psyche_core::NodeIdentity; use serde::{Deserialize, Serialize}; @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; pub enum OpportunisticData { WitnessStep(Witness, WitnessMetadata), WarmupStep(Witness), + CooldownStep(Witness), } impl OpportunisticData { @@ -15,6 +16,7 @@ impl OpportunisticData { match self { OpportunisticData::WitnessStep(..) => "witness", OpportunisticData::WarmupStep(..) => "warmup", + OpportunisticData::CooldownStep(..)
=> "cooldown", } } } @@ -27,5 +29,8 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result>; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; - async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()>; + async fn send_checkpoint( + &mut self, + checkpoint: psyche_coordinator::model::Checkpoint, + ) -> Result<()>; } diff --git a/tools/rust-tools/run-manager/src/commands/run/update_config.rs b/tools/rust-tools/run-manager/src/commands/run/update_config.rs index 641577307..80abc1fac 100644 --- a/tools/rust-tools/run-manager/src/commands/run/update_config.rs +++ b/tools/rust-tools/run-manager/src/commands/run/update_config.rs @@ -95,6 +95,7 @@ impl Command for CommandUpdateConfig { Checkpoint::P2P(hub_repo) | Checkpoint::Dummy(hub_repo) => { llm.checkpoint = Checkpoint::Hub(hub_repo) } + Checkpoint::P2PGcs(gcs_repo) => llm.checkpoint = Checkpoint::Gcs(gcs_repo), _ => {} } Some(Model::LLM(llm)) diff --git a/tools/rust-tools/run-manager/src/docker/manager.rs b/tools/rust-tools/run-manager/src/docker/manager.rs index cf20f480a..596946585 100644 --- a/tools/rust-tools/run-manager/src/docker/manager.rs +++ b/tools/rust-tools/run-manager/src/docker/manager.rs @@ -2,7 +2,7 @@ use anchor_client::solana_sdk::pubkey::Pubkey; use anyhow::{Context, Result, anyhow, bail}; use std::fs; use std::io::{BufRead, BufReader}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; use tokio::signal; use tracing::{error, info, warn}; @@ -163,8 +163,17 @@ impl RunManager { .arg(&self.env_file); if let Some(dir) = &self.scratch_dir { + let scratch_credentials_path = format!("{dir}/application_default_credentials.json"); + if !Path::new(&scratch_credentials_path).exists() { + bail!("GCS credentials were not found in scratch dir"); + } + cmd.arg("--mount") - .arg(format!("type=bind,src={dir},dst=/scratch")); + .arg(format!("type=bind,src={dir},dst=/scratch")) + .arg("--env") + .arg( + "GOOGLE_APPLICATION_CREDENTIALS=/scratch/application_default_credentials.json", + ); } if let Some(Entrypoint { entrypoint, .. }) = entrypoint { diff --git a/website/backend/src/coordinatorChainLoop.ts b/website/backend/src/coordinatorChainLoop.ts index 21aa04dc2..687e407da 100644 --- a/website/backend/src/coordinatorChainLoop.ts +++ b/website/backend/src/coordinatorChainLoop.ts @@ -342,6 +342,17 @@ export async function startWatchCoordinatorChainLoop( }) break } + case 'cooldown_witness': { + const runPdaAddr = i.accounts[1].toString() + const coordinatorAddr = i.accounts[2].toString() + runUpdates.getAndTouchCurrentRun({ + runPdaAddr, + coordinatorAddr, + decoded, + tx, + }) + break + } case 'update_client_version': { const runPdaAddr = i.accounts[1].toString() const coordinatorAddr = i.accounts[2].toString()