diff --git a/Cargo.lock b/Cargo.lock index 211b781..8d888d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,12 +99,83 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + [[package]] name = "anyhow" version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "arc-swap" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d03449bb8ca2cc2ef70869af31463d1ae5ccc8fa3e334b307203fbf815207e" +dependencies = [ + "rustversion", +] + +[[package]] +name = "argon2" +version = "0.5.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash", +] + [[package]] name = "ark-bls12-377" version = "0.4.0" @@ -579,6 +650,117 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c34dda4df7017c8db52132f0f8a2e0f8161649d15723ed63fc00c82d0f2081a" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "axum-macros", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + 
"tower-service", + "tracing", +] + +[[package]] +name = "axum-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", +] + +[[package]] +name = "axum-server" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1ab4a3ec9ea8a657c72d99a03a824af695bd0fb5ec639ccbd9cd3543b41a5f9" +dependencies = [ + "arc-swap", + "bytes", + "fs-err 3.2.2", + "http", + "http-body", + "hyper", + "hyper-util", + "pin-project-lite", + "rustls", + "rustls-pemfile", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.76" @@ -591,7 +773,7 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-link", + "windows-link 0.2.1", ] [[package]] @@ -636,6 +818,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43d193de1f7487df1914d3a568b772458861d33f9c54249612cc2893d6915054" dependencies = [ "bitcoin_hashes", + "serde", + "unicode-normalization", ] [[package]] @@ -671,17 +855,32 @@ name = "bittensor-rs" version = "0.1.0" dependencies = [ "anyhow", + "argon2", "ark-serialize 0.4.2", "ark-std 0.4.0", "async-trait", + "axum", + "axum-server", + "base64", + "bip39", + "bytes", "chrono", + "clap", + "comfy-table", + "console", + "crypto_secretbox", + "dialoguer", + "dirs", "futures", "hex", + "http", + "indicatif", "num-traits", "parity-scale-codec", "rand 0.9.2", "rand_chacha 0.3.1", "regex", + "reqwest", "scale-decode", "scale-encode", "scale-info", @@ -691,11 +890,18 @@ dependencies = [ "sp-core", "sp-runtime", "subxt", + "tempfile", "thiserror 2.0.17", "tle", "tokio", + "tower", + "tower-http", "tracing", + "tracing-appender", + "tracing-subscriber", + "uuid", "w3f-bls 0.1.3", + "zeroize", ] [[package]] @@ -819,11 +1025,13 @@ checksum = 
"d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.44" +version = "1.2.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37521ac7aabe3d13122dc382493e20c9416f299d2ccd5b3a5340a2570cdeb0f3" +checksum = "47b26a0954ae34af09b50f0de26458fa95369a0d478d8236d3f93082b219bd29" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -874,7 +1082,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.2.1", ] [[package]] @@ -888,6 +1096,61 @@ dependencies = [ "zeroize", ] +[[package]] +name = "clap" +version = "4.5.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75ca66430e33a14957acc24c5077b503e7d374151b2b4b3a10c83b4ceb4be0e" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793207c7fa6300a0608d1080b858e5fdbe713cdc1c8db9fb17777d8a13e63df0" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.108", +] + +[[package]] +name = "clap_lex" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" + +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + [[package]] name = "combine" version = "4.6.7" @@ -898,6 +1161,17 @@ dependencies = [ "memchr", ] +[[package]] +name = "comfy-table" +version = "7.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958c5d6ecf1f214b4c2bbbbf6ab9523a864bd136dcf71a7e8904799acfe1ad47" +dependencies = [ + "crossterm", + "unicode-segmentation", + "unicode-width", +] + [[package]] name = "common-path" version = "1.0.0" @@ -913,6 +1187,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -960,6 +1247,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.1" @@ -985,6 +1282,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-queue" version = "0.3.12" @@ -1000,6 +1306,29 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crossterm" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" +dependencies = [ + "bitflags 2.10.0", + "crossterm_winapi", + "document-features", + "parking_lot", + "rustix", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.4" @@ -1039,6 +1368,21 @@ dependencies = [ "subtle", ] +[[package]] +name = "crypto_secretbox" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d6cf87adf719ddf43a805e92c6870a531aedda35ff640442cbaf8674e141e1" +dependencies = [ + "aead", + "cipher", + "generic-array", + "poly1305", + "salsa20", + "subtle", + "zeroize", +] + [[package]] name = "ctr" version = "0.9.2" @@ -1205,6 +1549,19 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror 1.0.69", + "zeroize", +] + [[package]] name = "digest" version = "0.9.0" @@ -1226,6 +1583,27 @@ dependencies = [ "subtle", ] +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -1264,12 +1642,27 @@ dependencies = [ "walkdir", ] +[[package]] +name = 
"document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + [[package]] name = "downcast-rs" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" version = "1.0.20" @@ -1369,6 +1762,21 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "enum-ordinalize" version = "4.3.2" @@ -1440,7 +1848,7 @@ checksum = "e2c470c71d91ecbd179935b24170459e926382eaaa86b590b78814e180d8a8e2" dependencies = [ "blake2", "file-guard", - "fs-err", + "fs-err 2.11.0", "prettyplease", "proc-macro2", "quote", @@ -1481,9 +1889,9 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.4" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" [[package]] name = "fixed-hash" @@ -1509,6 +1917,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = 
"foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -1555,6 +1978,22 @@ dependencies = [ "autocfg", ] +[[package]] +name = "fs-err" +version = "3.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf68cef89750956493a66a10f512b9e58d9db21f2a573c079c0bdf1207a54a7" +dependencies = [ + "autocfg", + "tokio", +] + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -1818,6 +2257,16 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +[[package]] +name = "hdrhistogram" +version = "7.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "byteorder", + "num-traits", +] + [[package]] name = "heck" version = "0.5.0" @@ -1938,6 +2387,39 @@ dependencies = [ "pin-utils", "smallvec", "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", ] [[package]] @@ -1946,13 +2428,24 @@ version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ + "base64", "bytes", + "futures-channel", "futures-core", + "futures-util", "http", "http-body", "hyper", + "ipnet", + "libc", + "percent-encoding", "pin-project-lite", + "socket2", + "system-configuration", "tokio", + "tower-service", + "tracing", + "windows-registry", ] [[package]] @@ -2137,6 +2630,19 @@ dependencies = [ "hashbrown 0.16.0", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "inout" version = "0.1.4" @@ -2155,6 +2661,28 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.10.5" @@ -2247,6 +2775,16 @@ version = 
"0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.82" @@ -2388,6 +2926,16 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "libredox" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +dependencies = [ + "bitflags 2.10.0", + "libc", +] + [[package]] name = "libsecp256k1" version = "0.7.2" @@ -2448,6 +2996,12 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + [[package]] name = "lock_api" version = "0.4.14" @@ -2481,6 +3035,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "memchr" version = "2.7.6" @@ -2510,6 +3070,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2536,6 +3102,23 @@ version 
= "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "685a9ac4b61f4e728e1d2c6a7844609c16527aeb5e6c865915c08e619c16410f" +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nodrop" version = "0.1.14" @@ -2631,6 +3214,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.37.3" @@ -2646,18 +3235,68 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "opaque-debug" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", +] + [[package]] name = "openssl-probe" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "parity-bip39" version = "2.0.1" @@ -2726,7 +3365,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-link", + "windows-link 0.2.1", ] [[package]] @@ -2824,6 +3463,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "polkavm-common" version = "0.26.0" @@ -2898,6 +3543,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.4" @@ -3089,6 +3740,17 @@ dependencies = [ "bitflags 2.10.0", ] +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom 0.2.16", + "libredox", + "thiserror 1.0.69", +] + [[package]] name = "ref-cast" version = "1.0.25" 
@@ -3132,12 +3794,55 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "regex-syntax" -version = "0.8.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" - +[[package]] +name = "regex-syntax" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" + +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "mime", + "native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", +] + [[package]] name = "rfc6979" version = "0.4.0" @@ -3208,6 +3913,7 @@ version = "0.23.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a9586e9ee2b4f8fab52a0048ca7334d7024eef48e2cb9407e3497bb7cab7fa7" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", @@ -3226,7 +3932,16 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.5.1", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", ] [[package]] @@ -3244,7 +3959,7 @@ version = "0.5.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "19787cda76408ec5404443dc8b31795c87cd8fec49762dc75fa727740d34acc1" dependencies = [ - "core-foundation", + "core-foundation 0.10.1", "core-foundation-sys", "jni", "log", @@ -3253,7 +3968,7 @@ dependencies = [ "rustls-native-certs", "rustls-platform-verifier-android", "rustls-webpki", - "security-framework", + "security-framework 3.5.1", "security-framework-sys", "webpki-root-certs 0.26.11", "windows-sys 0.59.0", @@ -3271,6 +3986,7 @@ version = "0.103.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -3294,6 +4010,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "salsa20" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213" +dependencies = [ + "cipher", +] + [[package]] name = "same-file" version = "1.0.6" @@ -3525,6 +4250,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.5.1" @@ -3532,7 +4270,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" dependencies = [ "bitflags 2.10.0", - "core-foundation", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -3617,6 +4355,17 @@ 
dependencies = [ "serde_core", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_spanned" version = "0.6.9" @@ -3626,6 +4375,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serdect" version = "0.2.0" @@ -3690,6 +4451,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77" + [[package]] name = "shlex" version = "1.3.0" @@ -4474,6 +5241,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + [[package]] name = "synstructure" version = "0.13.2" @@ -4485,12 +5261,46 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + 
"core-foundation-sys", + "libc", +] + [[package]] name = "tap" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "termcolor" version = "1.4.1" @@ -4670,6 +5480,16 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -4776,6 +5596,58 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "hdrhistogram", + "indexmap", + "pin-project-lite", + "slab", + "sync_wrapper", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags 2.10.0", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.41" @@ -4788,6 +5660,18 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf" +dependencies = [ + "crossbeam-channel", + "thiserror 2.0.17", + "time", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.30" @@ -4820,6 +5704,16 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-serde" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704b1aeb7be0d0a84fc9828cae51dab5970fee5088f83d1dd7ee6f6246fc6ff1" +dependencies = [ + "serde", + "tracing-core", +] + [[package]] name = "tracing-subscriber" version = "0.3.20" @@ -4830,6 +5724,8 @@ dependencies = [ "nu-ansi-term", "once_cell", "regex-automata", + "serde", + "serde_json", "sharded-slab", "smallvec", "thread_local", @@ -4837,6 +5733,7 @@ dependencies = [ "tracing", "tracing-core", "tracing-log", + "tracing-serde", ] [[package]] @@ -4860,6 +5757,12 @@ dependencies = [ "hash-db", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "tuplex" version = "0.1.2" @@ -4923,6 +5826,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.2.2" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -4963,12 +5872,35 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "uuid" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" +dependencies = [ + "getrandom 0.3.4", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "valuable" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -5074,6 +6006,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -5147,6 +6088,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + 
"web-sys", +] + [[package]] name = "wasmi" version = "0.40.0" @@ -5274,9 +6228,9 @@ checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link", - "windows-result", - "windows-strings", + "windows-link 0.2.1", + "windows-result 0.4.1", + "windows-strings 0.5.1", ] [[package]] @@ -5301,19 +6255,54 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-registry" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e" +dependencies = [ + "windows-link 0.1.3", + "windows-result 0.3.4", + "windows-strings 0.4.2", +] + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link 0.1.3", +] + [[package]] name = "windows-result" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link", + "windows-link 0.2.1", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link 0.1.3", ] [[package]] @@ -5322,7 +6311,7 @@ version = "0.5.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link", + "windows-link 0.2.1", ] [[package]] @@ -5334,6 +6323,15 @@ dependencies = [ "windows-targets 0.42.2", ] +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -5367,7 +6365,7 @@ version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link", + "windows-link 0.2.1", ] [[package]] @@ -5385,6 +6383,21 @@ dependencies = [ "windows_x86_64_msvc 0.42.2", ] +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -5407,7 +6420,7 @@ version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link", + "windows-link 0.2.1", "windows_aarch64_gnullvm 0.53.1", "windows_aarch64_msvc 0.53.1", "windows_i686_gnu 0.53.1", @@ -5424,6 +6437,12 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +[[package]] +name = "windows_aarch64_gnullvm" +version = 
"0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -5442,6 +6461,12 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" @@ -5460,6 +6485,12 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -5490,6 +6521,12 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + [[package]] name = "windows_i686_msvc" version = "0.52.6" @@ -5508,6 +6545,12 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + [[package]] name = 
"windows_x86_64_gnu" version = "0.52.6" @@ -5526,6 +6569,12 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" @@ -5544,6 +6593,12 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 9dfba8f..835ba47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,8 @@ tokio = { version = "1.48.0", features = ["full"] } rand = "0.9" regex = "1.10" tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } +tracing-appender = "0.2" chrono = { version = "0.4", features = ["serde"] } # CRv4 / Timelock encryption dependencies (same as subtensor) @@ -40,5 +42,39 @@ ark-std = "0.4.0" sha2 = "0.10" rand_chacha = "0.3" +# Wallet module dependencies +bip39 = "2.0" +argon2 = "0.5" +crypto_secretbox = "0.1" +zeroize = { version = "1.6", features = ["derive"] } +base64 = "0.22" +dirs = "5.0" + +# Dendrite HTTP client dependencies +reqwest = { version = "0.12", features = ["json", "stream"] } +http = "1.0" +bytes = "1.5" +uuid = { version = "1.7", features = ["v4"] } + +# Axon HTTP server dependencies +axum = { version = "0.7", features = ["macros"] } +tower = { version = "0.5", features = ["full"] } +tower-http = { version = "0.6", features = ["cors", 
"trace", "timeout"] } +axum-server = { version = "0.7", features = ["tls-rustls"] } + +# CLI dependencies +clap = { version = "4.5", features = ["derive"] } +dialoguer = "0.11" +indicatif = "0.17" +console = "0.15" +comfy-table = "7.1" + +[dev-dependencies] +tempfile = "3.23" + +[[bin]] +name = "btcli" +path = "src/bin/btcli.rs" + [patch.crates-io] w3f-bls = { git = "https://github.com/opentensor/bls", branch = "fix-no-std" } diff --git a/src/axon/handlers.rs b/src/axon/handlers.rs new file mode 100644 index 0000000..0d8ea10 --- /dev/null +++ b/src/axon/handlers.rs @@ -0,0 +1,445 @@ +//! Request handlers for the Axon HTTP server +//! +//! This module provides the core request handling logic including +//! signature verification, synapse extraction, and response building. + +use crate::dendrite::request::header_names; +use crate::errors::{AxonError, SynapseUnauthorized}; +use crate::types::{Synapse, TerminalInfo}; +use axum::body::Bytes; +use axum::response::{IntoResponse, Response}; +use http::{HeaderMap, HeaderValue, StatusCode}; +use sha2::{Digest, Sha256}; +use sp_core::{sr25519, Pair}; +use std::time::Instant; + +/// Bittensor protocol version +pub const AXON_VERSION: u64 = 100; + +/// Status codes matching the Python SDK +pub mod status_codes { + pub const SUCCESS: i32 = 200; + pub const UNAUTHORIZED: i32 = 401; + pub const FORBIDDEN: i32 = 403; + pub const NOT_FOUND: i32 = 404; + pub const TIMEOUT: i32 = 408; + pub const INTERNAL_ERROR: i32 = 500; + pub const SERVICE_UNAVAILABLE: i32 = 503; +} + +/// Status messages for response headers +pub mod status_messages { + pub const SUCCESS: &str = "Success"; + pub const UNAUTHORIZED: &str = "Signature verification failed"; + pub const FORBIDDEN: &str = "Blacklisted"; + pub const NOT_FOUND: &str = "Synapse not found"; + pub const TIMEOUT: &str = "Request timeout"; + pub const INTERNAL_ERROR: &str = "Internal server error"; + pub const SERVICE_UNAVAILABLE: &str = "Service unavailable"; +} + +/// Verified request 
information extracted from headers +#[derive(Debug, Clone)] +pub struct VerifiedRequest { + /// The dendrite's hotkey SS58 address + pub dendrite_hotkey: String, + /// Request nonce for replay protection + pub nonce: u64, + /// Signature from the dendrite + pub signature: String, + /// Request UUID + pub uuid: String, + /// Computed body hash + pub body_hash: String, +} + +/// Extract and verify a request from headers +/// +/// # Arguments +/// +/// * `headers` - The HTTP request headers +/// * `body` - The request body bytes +/// * `axon_hotkey` - The axon's hotkey SS58 address for verification +/// +/// # Returns +/// +/// A VerifiedRequest if verification succeeds, or an error +pub fn verify_request( + headers: &HeaderMap, + body: &[u8], + axon_hotkey: &str, +) -> Result { + // Extract required headers + let dendrite_hotkey = get_header_string(headers, header_names::DENDRITE_HOTKEY) + .ok_or_else(|| SynapseUnauthorized { + message: "Missing dendrite hotkey header".to_string(), + hotkey: None, + })?; + + let nonce_str = + get_header_string(headers, header_names::DENDRITE_NONCE).ok_or_else(|| { + SynapseUnauthorized { + message: "Missing dendrite nonce header".to_string(), + hotkey: Some(dendrite_hotkey.clone()), + } + })?; + + let nonce: u64 = nonce_str.parse().map_err(|_| SynapseUnauthorized { + message: format!("Invalid nonce format: {}", nonce_str), + hotkey: Some(dendrite_hotkey.clone()), + })?; + + let signature = + get_header_string(headers, header_names::DENDRITE_SIGNATURE).ok_or_else(|| { + SynapseUnauthorized { + message: "Missing dendrite signature header".to_string(), + hotkey: Some(dendrite_hotkey.clone()), + } + })?; + + let uuid = get_header_string(headers, header_names::DENDRITE_UUID).unwrap_or_default(); + + // Compute body hash + let body_hash = compute_body_hash(body); + + // Verify signature + verify_signature(&dendrite_hotkey, nonce, axon_hotkey, &body_hash, &signature).map_err( + |e| SynapseUnauthorized { + message: e.to_string(), + hotkey: 
Some(dendrite_hotkey.clone()), + }, + )?; + + Ok(VerifiedRequest { + dendrite_hotkey, + nonce, + signature, + uuid, + body_hash, + }) +} + +/// Verify a request signature +/// +/// The signature format matches the Python SDK: +/// `sign(message = "{nonce}.{dendrite_hotkey}.{axon_hotkey}.{body_hash}")` +/// +/// # Arguments +/// +/// * `dendrite_hotkey` - The dendrite's hotkey SS58 address +/// * `nonce` - The request nonce +/// * `axon_hotkey` - The axon's hotkey SS58 address +/// * `body_hash` - The SHA-256 hash of the request body +/// * `signature` - The hex-encoded signature +/// +/// # Returns +/// +/// Ok(()) if verification succeeds, or an error +pub fn verify_signature( + dendrite_hotkey: &str, + nonce: u64, + axon_hotkey: &str, + body_hash: &str, + signature: &str, +) -> Result<(), AxonError> { + // Decode the signature from hex + let sig_bytes = + hex::decode(signature).map_err(|e| AxonError::new(format!("Invalid signature hex: {}", e)))?; + + if sig_bytes.len() != 64 { + return Err(AxonError::new(format!( + "Invalid signature length: expected 64 bytes, got {}", + sig_bytes.len() + ))); + } + + let mut sig_arr = [0u8; 64]; + sig_arr.copy_from_slice(&sig_bytes); + let sig = sr25519::Signature::from_raw(sig_arr); + + // Decode the dendrite's public key from SS58 + let public = ss58_to_public(dendrite_hotkey) + .map_err(|e| AxonError::new(format!("Invalid dendrite hotkey: {}", e)))?; + + // Create the message to verify + let message = format!("{}.{}.{}.{}", nonce, dendrite_hotkey, axon_hotkey, body_hash); + + // Verify the signature + if sr25519::Pair::verify(&sig, message.as_bytes(), &public) { + Ok(()) + } else { + Err(AxonError::new("Signature verification failed")) + } +} + +/// Decode an SS58 address to a public key +fn ss58_to_public(ss58: &str) -> Result { + use sp_core::crypto::Ss58Codec; + sr25519::Public::from_ss58check(ss58).map_err(|e| format!("{:?}", e)) +} + +/// Compute SHA-256 hash of data and return as hex string +pub fn 
compute_body_hash(data: &[u8]) -> String { + let mut hasher = Sha256::new(); + hasher.update(data); + hex::encode(hasher.finalize()) +} + +/// Extract a synapse from request headers and body +/// +/// # Arguments +/// +/// * `headers` - The HTTP request headers +/// * `body` - The request body bytes +/// +/// # Returns +/// +/// The extracted Synapse +pub fn extract_synapse(headers: &HeaderMap, body: &[u8]) -> Result { + // Parse body as JSON extra fields + let extra: std::collections::HashMap = if body.is_empty() { + std::collections::HashMap::new() + } else { + serde_json::from_slice(body) + .map_err(|e| AxonError::new(format!("Invalid JSON body: {}", e)))? + }; + + // Build dendrite terminal info from headers + let dendrite = TerminalInfo { + ip: get_header_string(headers, header_names::DENDRITE_IP), + port: get_header_u16(headers, header_names::DENDRITE_PORT), + version: get_header_u64(headers, header_names::DENDRITE_VERSION), + nonce: get_header_u64(headers, header_names::DENDRITE_NONCE), + uuid: get_header_string(headers, header_names::DENDRITE_UUID), + hotkey: get_header_string(headers, header_names::DENDRITE_HOTKEY), + signature: get_header_string(headers, header_names::DENDRITE_SIGNATURE), + ..Default::default() + }; + + Ok(Synapse { + name: get_header_string(headers, header_names::NAME), + timeout: get_header_f64(headers, header_names::TIMEOUT), + total_size: get_header_u64(headers, header_names::TOTAL_SIZE), + header_size: get_header_u64(headers, header_names::HEADER_SIZE), + computed_body_hash: get_header_string(headers, header_names::BODY_HASH), + dendrite: Some(dendrite), + axon: Some(TerminalInfo::default()), + extra, + }) +} + +/// Build response headers for a synapse response +/// +/// # Arguments +/// +/// * `hotkey` - The axon's hotkey SS58 address +/// * `status_code` - The response status code +/// * `status_message` - The response status message +/// * `process_time` - Processing time in seconds +/// +/// # Returns +/// +/// HeaderMap with all 
required response headers +pub fn build_response_headers( + hotkey: &str, + status_code: i32, + status_message: &str, + process_time: f64, +) -> HeaderMap { + let mut headers = HeaderMap::new(); + + // Nonce for response + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos() as u64) + .unwrap_or(0); + + // Add axon headers + if let Ok(hv) = HeaderValue::from_str(&status_code.to_string()) { + headers.insert(header_names::AXON_STATUS_CODE, hv); + } + if let Ok(hv) = HeaderValue::from_str(status_message) { + headers.insert(header_names::AXON_STATUS_MESSAGE, hv); + } + if let Ok(hv) = HeaderValue::from_str(&format!("{:.6}", process_time)) { + headers.insert(header_names::AXON_PROCESS_TIME, hv); + } + if let Ok(hv) = HeaderValue::from_str(hotkey) { + headers.insert(header_names::AXON_HOTKEY, hv); + } + if let Ok(hv) = HeaderValue::from_str(&AXON_VERSION.to_string()) { + headers.insert(header_names::AXON_VERSION, hv); + } + if let Ok(hv) = HeaderValue::from_str(&nonce.to_string()) { + headers.insert(header_names::AXON_NONCE, hv); + } + + headers +} + +/// Build an error response +/// +/// # Arguments +/// +/// * `hotkey` - The axon's hotkey SS58 address +/// * `status_code` - The HTTP status code +/// * `bt_status_code` - The Bittensor status code +/// * `message` - The error message +/// * `process_time` - Processing time in seconds +/// +/// # Returns +/// +/// An axum Response +pub fn build_error_response( + hotkey: &str, + status_code: StatusCode, + bt_status_code: i32, + message: &str, + process_time: f64, +) -> Response { + let headers = build_response_headers(hotkey, bt_status_code, message, process_time); + (status_code, headers, message.to_string()).into_response() +} + +/// Build a success response with JSON body +/// +/// # Arguments +/// +/// * `hotkey` - The axon's hotkey SS58 address +/// * `body` - The response body +/// * `process_time` - Processing time in seconds +/// +/// # Returns +/// +/// An axum 
Response +pub fn build_success_response(hotkey: &str, body: Bytes, process_time: f64) -> Response { + let headers = build_response_headers( + hotkey, + status_codes::SUCCESS, + status_messages::SUCCESS, + process_time, + ); + (StatusCode::OK, headers, body).into_response() +} + +// Helper functions for header extraction + +fn get_header_string(headers: &HeaderMap, name: &str) -> Option { + headers + .get(name) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) +} + +fn get_header_u64(headers: &HeaderMap, name: &str) -> Option { + get_header_string(headers, name).and_then(|s| s.parse().ok()) +} + +fn get_header_u16(headers: &HeaderMap, name: &str) -> Option { + get_header_string(headers, name).and_then(|s| s.parse().ok()) +} + +fn get_header_f64(headers: &HeaderMap, name: &str) -> Option { + get_header_string(headers, name).and_then(|s| s.parse().ok()) +} + +/// Handler context for processing requests +#[derive(Clone)] +pub struct HandlerContext { + /// The axon's hotkey + pub hotkey: String, + /// Request start time + pub start_time: Instant, +} + +impl HandlerContext { + /// Create a new handler context + pub fn new(hotkey: impl Into) -> Self { + Self { + hotkey: hotkey.into(), + start_time: Instant::now(), + } + } + + /// Get elapsed time in seconds + pub fn elapsed_secs(&self) -> f64 { + self.start_time.elapsed().as_secs_f64() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compute_body_hash() { + let data = b"test data"; + let hash = compute_body_hash(data); + // SHA-256 hash should be 64 hex characters + assert_eq!(hash.len(), 64); + // Should be deterministic + assert_eq!(hash, compute_body_hash(data)); + } + + #[test] + fn test_build_response_headers() { + let headers = build_response_headers( + "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY", + 200, + "Success", + 0.123456, + ); + + assert!(headers.contains_key(header_names::AXON_STATUS_CODE)); + assert!(headers.contains_key(header_names::AXON_STATUS_MESSAGE)); + 
assert!(headers.contains_key(header_names::AXON_PROCESS_TIME)); + assert!(headers.contains_key(header_names::AXON_HOTKEY)); + assert!(headers.contains_key(header_names::AXON_VERSION)); + assert!(headers.contains_key(header_names::AXON_NONCE)); + } + + #[test] + fn test_extract_synapse_empty_body() { + let mut headers = HeaderMap::new(); + headers.insert(header_names::NAME, "TestSynapse".parse().unwrap()); + headers.insert(header_names::TIMEOUT, "12.0".parse().unwrap()); + + let synapse = extract_synapse(&headers, &[]).unwrap(); + + assert_eq!(synapse.name, Some("TestSynapse".to_string())); + assert_eq!(synapse.timeout, Some(12.0)); + assert!(synapse.extra.is_empty()); + } + + #[test] + fn test_extract_synapse_with_body() { + let headers = HeaderMap::new(); + let body = br#"{"key": "value"}"#; + + let synapse = extract_synapse(&headers, body).unwrap(); + + assert!(synapse.extra.contains_key("key")); + assert_eq!( + synapse.extra.get("key"), + Some(&serde_json::json!("value")) + ); + } + + #[test] + fn test_handler_context() { + let ctx = HandlerContext::new("test_hotkey"); + assert_eq!(ctx.hotkey, "test_hotkey"); + // Small sleep to ensure elapsed time is > 0 + std::thread::sleep(std::time::Duration::from_millis(1)); + assert!(ctx.elapsed_secs() > 0.0); + } + + #[test] + fn test_status_codes() { + assert_eq!(status_codes::SUCCESS, 200); + assert_eq!(status_codes::UNAUTHORIZED, 401); + assert_eq!(status_codes::FORBIDDEN, 403); + assert_eq!(status_codes::TIMEOUT, 408); + assert_eq!(status_codes::INTERNAL_ERROR, 500); + } +} diff --git a/src/axon/info.rs b/src/axon/info.rs new file mode 100644 index 0000000..0e140b0 --- /dev/null +++ b/src/axon/info.rs @@ -0,0 +1,172 @@ +//! Axon configuration and info types +//! +//! This module provides configuration structures for the Axon HTTP server +//! and re-exports the on-chain AxonInfo type. 
+ +use serde::{Deserialize, Serialize}; + +/// Re-export the on-chain AxonInfo type +pub use crate::types::axon::AxonInfo; + +/// Axon server configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AxonConfig { + /// The port to listen on + pub port: u16, + /// The IP address to bind to (default: "0.0.0.0") + pub ip: String, + /// External IP address for registration (if different from binding IP) + pub external_ip: Option, + /// External port for registration (if different from listening port) + pub external_port: Option, + /// Maximum number of worker threads + pub max_workers: usize, + /// Maximum concurrent requests to process + pub max_concurrent_requests: usize, + /// Default request timeout in seconds + pub default_timeout_secs: u64, + /// Whether to verify request signatures + pub verify_signatures: bool, + /// Whether to trust X-Forwarded-For and X-Real-IP headers. + /// Only enable this when running behind a trusted reverse proxy. + /// When disabled (default), only the direct connection IP is used for IP blacklisting. 
+ pub trust_proxy_headers: bool, +} + +impl Default for AxonConfig { + fn default() -> Self { + Self { + port: 8091, + ip: "0.0.0.0".to_string(), + external_ip: None, + external_port: None, + max_workers: 10, + max_concurrent_requests: 256, + default_timeout_secs: 12, + verify_signatures: true, + trust_proxy_headers: false, + } + } +} + +impl AxonConfig { + /// Create a new AxonConfig with default settings + pub fn new() -> Self { + Self::default() + } + + /// Set the port + pub fn with_port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Set the binding IP address + pub fn with_ip(mut self, ip: impl Into) -> Self { + self.ip = ip.into(); + self + } + + /// Set the external IP for chain registration + pub fn with_external_ip(mut self, ip: impl Into) -> Self { + self.external_ip = Some(ip.into()); + self + } + + /// Set the external port for chain registration + pub fn with_external_port(mut self, port: u16) -> Self { + self.external_port = Some(port); + self + } + + /// Set the maximum worker threads + pub fn with_max_workers(mut self, workers: usize) -> Self { + self.max_workers = workers; + self + } + + /// Set the maximum concurrent requests + pub fn with_max_concurrent_requests(mut self, max: usize) -> Self { + self.max_concurrent_requests = max; + self + } + + /// Set the default timeout + pub fn with_default_timeout(mut self, timeout_secs: u64) -> Self { + self.default_timeout_secs = timeout_secs; + self + } + + /// Enable or disable signature verification + pub fn with_signature_verification(mut self, enabled: bool) -> Self { + self.verify_signatures = enabled; + self + } + + /// Enable or disable trusting proxy headers (X-Forwarded-For, X-Real-IP). + /// Only enable this when running behind a trusted reverse proxy. + /// When disabled (default), only the direct connection IP is used. 
+ pub fn with_trust_proxy_headers(mut self, enabled: bool) -> Self { + self.trust_proxy_headers = enabled; + self + } + + /// Get the socket address string for binding + pub fn socket_addr(&self) -> String { + format!("{}:{}", self.ip, self.port) + } + + /// Get the external IP to use for chain registration + pub fn get_external_ip(&self) -> &str { + self.external_ip.as_deref().unwrap_or(&self.ip) + } + + /// Get the external port to use for chain registration + pub fn get_external_port(&self) -> u16 { + self.external_port.unwrap_or(self.port) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = AxonConfig::default(); + assert_eq!(config.port, 8091); + assert_eq!(config.ip, "0.0.0.0"); + assert_eq!(config.max_concurrent_requests, 256); + } + + #[test] + fn test_builder_pattern() { + let config = AxonConfig::new() + .with_port(9000) + .with_ip("127.0.0.1") + .with_external_ip("1.2.3.4") + .with_external_port(9001) + .with_max_workers(20); + + assert_eq!(config.port, 9000); + assert_eq!(config.ip, "127.0.0.1"); + assert_eq!(config.external_ip, Some("1.2.3.4".to_string())); + assert_eq!(config.external_port, Some(9001)); + assert_eq!(config.max_workers, 20); + } + + #[test] + fn test_socket_addr() { + let config = AxonConfig::new().with_ip("192.168.1.1").with_port(8080); + assert_eq!(config.socket_addr(), "192.168.1.1:8080"); + } + + #[test] + fn test_external_ip_fallback() { + let config = AxonConfig::new().with_ip("127.0.0.1"); + assert_eq!(config.get_external_ip(), "127.0.0.1"); + + let config_with_external = config.with_external_ip("1.2.3.4"); + assert_eq!(config_with_external.get_external_ip(), "1.2.3.4"); + } +} diff --git a/src/axon/middleware.rs b/src/axon/middleware.rs new file mode 100644 index 0000000..e5d8c5a --- /dev/null +++ b/src/axon/middleware.rs @@ -0,0 +1,409 @@ +//! Middleware for the Axon HTTP server +//! +//! This module provides middleware functions for request processing including: +//! 
- Blacklist checking +//! - Priority queuing +//! - Signature verification +//! - Request logging + +use crate::axon::handlers::{build_error_response, status_codes, status_messages}; +use crate::axon::server::AxonState; +use crate::dendrite::request::header_names; +use axum::body::Body; +use axum::extract::State; +use axum::http::{Request, StatusCode}; +use axum::middleware::Next; +use axum::response::Response; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::RwLock; +use tracing::{debug, info, warn}; + +/// Blacklist middleware - reject requests from blacklisted hotkeys +/// +/// Checks if the dendrite's hotkey is in the blacklist and rejects +/// the request with a 403 Forbidden status if so. +pub async fn blacklist_middleware( + State(state): State>>, + req: Request, + next: Next, +) -> Response { + let start_time = Instant::now(); + + // Extract dendrite hotkey from headers + let dendrite_hotkey = req + .headers() + .get(header_names::DENDRITE_HOTKEY) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + // Check blacklist + let state_read = state.read().await; + + // Extract IP address from request. + // Only trust proxy headers (X-Forwarded-For, X-Real-IP) if trust_proxy_headers is enabled. + // This prevents IP blacklist bypass via header spoofing when not behind a trusted proxy. + let client_ip = if state_read.trust_proxy_headers { + req.headers() + .get("x-forwarded-for") + .or_else(|| req.headers().get("x-real-ip")) + .and_then(|v| v.to_str().ok()) + .and_then(|s| { + // X-Forwarded-For may contain multiple IPs, take the first (original client) + s.split(',').next().map(|ip| ip.trim().to_string()) + }) + } else { + // When trust_proxy_headers is disabled, we don't use proxy headers. + // The actual client IP would be obtained from the connection itself, + // but that's not available in this middleware context without ConnectInfo. + // For now, return None to avoid trusting spoofable headers. 
+ None + }; + + // Check hotkey blacklist + if let Some(ref hotkey) = dendrite_hotkey { + if state_read.blacklist.contains(hotkey) { + warn!("Blocked blacklisted hotkey: {}", hotkey); + let process_time = start_time.elapsed().as_secs_f64(); + return build_error_response( + &state_read.axon_hotkey, + StatusCode::FORBIDDEN, + status_codes::FORBIDDEN, + status_messages::FORBIDDEN, + process_time, + ); + } + } + + // Check IP blacklist + if let Some(ref ip) = client_ip { + if state_read.ip_blacklist.contains(ip) { + warn!("Blocked blacklisted IP: {}", ip); + let process_time = start_time.elapsed().as_secs_f64(); + return build_error_response( + &state_read.axon_hotkey, + StatusCode::FORBIDDEN, + status_codes::FORBIDDEN, + status_messages::FORBIDDEN, + process_time, + ); + } + } + + // Check custom blacklist function + if let Some(ref blacklist_fn) = state_read.blacklist_fn { + if let Some(ref hotkey) = dendrite_hotkey { + let synapse_name = req + .headers() + .get(header_names::NAME) + .and_then(|v| v.to_str().ok()) + .unwrap_or("unknown"); + + if blacklist_fn(hotkey, synapse_name) { + warn!( + "Blocked by custom blacklist function: hotkey={}, synapse={}", + hotkey, synapse_name + ); + let process_time = start_time.elapsed().as_secs_f64(); + return build_error_response( + &state_read.axon_hotkey, + StatusCode::FORBIDDEN, + status_codes::FORBIDDEN, + status_messages::FORBIDDEN, + process_time, + ); + } + } + } + + drop(state_read); + next.run(req).await +} + +/// Priority middleware - track request priority +/// +/// Extracts the priority for this request based on the dendrite's hotkey +/// and adds it to the request extensions for later use. 
+pub async fn priority_middleware(
+    State(state): State<Arc<RwLock<AxonState>>>,
+    mut req: Request<Body>,
+    next: Next,
+) -> Response {
+    // Extract dendrite hotkey from headers
+    let dendrite_hotkey = req
+        .headers()
+        .get(header_names::DENDRITE_HOTKEY)
+        .and_then(|v| v.to_str().ok())
+        .map(|s| s.to_string());
+
+    // Get priority for this hotkey
+    let priority = if let Some(ref hotkey) = dendrite_hotkey {
+        let state_read = state.read().await;
+
+        // Check custom priority function first
+        if let Some(ref priority_fn) = state_read.priority_fn {
+            let synapse_name = req
+                .headers()
+                .get(header_names::NAME)
+                .and_then(|v| v.to_str().ok())
+                .unwrap_or("unknown");
+            priority_fn(hotkey, synapse_name)
+        } else {
+            // Fall back to priority list
+            state_read.priority_list.get(hotkey).copied().unwrap_or(0.0)
+        }
+    } else {
+        0.0
+    };
+
+    // Add priority to request extensions
+    req.extensions_mut().insert(RequestPriority(priority));
+
+    next.run(req).await
+}
+
+/// Request priority extension
+#[derive(Debug, Clone, Copy)]
+pub struct RequestPriority(pub f32);
+
+/// Verification middleware - verify request signatures
+///
+/// Verifies the dendrite's signature on the request if signature
+/// verification is enabled in the state.
+pub async fn verify_middleware(
+    State(state): State<Arc<RwLock<AxonState>>>,
+    req: Request<Body>,
+    next: Next,
+) -> Response {
+    let start_time = Instant::now();
+    let state_read = state.read().await;
+
+    // Skip verification if disabled
+    if !state_read.verify_signatures {
+        drop(state_read);
+        return next.run(req).await;
+    }
+
+    let axon_hotkey = state_read.axon_hotkey.clone();
+
+    // Check custom verify function first
+    if let Some(ref verify_fn) = state_read.verify_fn {
+        let synapse_name = req
+            .headers()
+            .get(header_names::NAME)
+            .and_then(|v| v.to_str().ok())
+            .unwrap_or("unknown");
+
+        if !verify_fn(synapse_name) {
+            debug!(
+                "Request failed custom verification for synapse: {}",
+                synapse_name
+            );
+            let process_time = start_time.elapsed().as_secs_f64();
+            return build_error_response(
+                &axon_hotkey,
+                StatusCode::UNAUTHORIZED,
+                status_codes::UNAUTHORIZED,
+                status_messages::UNAUTHORIZED,
+                process_time,
+            );
+        }
+    }
+
+    drop(state_read);
+
+    // For full signature verification, we need the body
+    // This is done in the handler since we need to consume the body
+    next.run(req).await
+}
+
+/// Logging middleware - log request details
+///
+/// Logs incoming requests and their processing time.
+pub async fn logging_middleware(req: Request, next: Next) -> Response { + let start_time = Instant::now(); + + let method = req.method().clone(); + let uri = req.uri().clone(); + let synapse_name = req + .headers() + .get(header_names::NAME) + .and_then(|v| v.to_str().ok()) + .unwrap_or("unknown") + .to_string(); + let dendrite_hotkey = req + .headers() + .get(header_names::DENDRITE_HOTKEY) + .and_then(|v| v.to_str().ok()) + .unwrap_or("anonymous") + .to_string(); + + debug!( + "Incoming request: {} {} synapse={} from={}", + method, uri, synapse_name, dendrite_hotkey + ); + + let response = next.run(req).await; + + let status = response.status(); + let process_time = start_time.elapsed().as_secs_f64(); + + info!( + "Request completed: {} {} synapse={} from={} status={} time={:.3}s", + method, uri, synapse_name, dendrite_hotkey, status, process_time + ); + + response +} + +/// Request counter middleware - track request counts +/// +/// Increments the request counter in the axon state. +pub async fn counter_middleware( + State(state): State>>, + req: Request, + next: Next, +) -> Response { + // Increment request counter + { + let mut state_write = state.write().await; + state_write.request_count += 1; + state_write.total_requests += 1; + } + + let response = next.run(req).await; + + // Decrement active request counter + { + let mut state_write = state.write().await; + state_write.request_count = state_write.request_count.saturating_sub(1); + } + + response +} + +/// Timeout middleware - enforce request timeouts +/// +/// Extracts the timeout from request headers and enforces it. +/// Uses the default timeout if not specified. 
+pub async fn timeout_middleware(
+    State(state): State<Arc<RwLock<AxonState>>>,
+    req: Request<Body>,
+    next: Next,
+) -> Response {
+    let start_time = Instant::now();
+
+    // Extract timeout from headers or use default
+    let timeout_secs = req
+        .headers()
+        .get(header_names::TIMEOUT)
+        .and_then(|v| v.to_str().ok())
+        .and_then(|s| s.parse::<f64>().ok())
+        .unwrap_or(12.0);
+
+    // Clamp timeout to reasonable bounds: min 1 second, max 5 minutes
+    let timeout_secs = timeout_secs.clamp(1.0, 300.0);
+
+    let timeout_duration = std::time::Duration::from_secs_f64(timeout_secs);
+
+    // Create a future that completes when the request is done or times out
+    let response_future = next.run(req);
+
+    match tokio::time::timeout(timeout_duration, response_future).await {
+        Ok(response) => response,
+        Err(_) => {
+            let state_read = state.read().await;
+            let process_time = start_time.elapsed().as_secs_f64();
+            warn!("Request timed out after {:.3}s", process_time);
+            build_error_response(
+                &state_read.axon_hotkey,
+                StatusCode::REQUEST_TIMEOUT,
+                status_codes::TIMEOUT,
+                status_messages::TIMEOUT,
+                process_time,
+            )
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::collections::{HashMap, HashSet};
+
+    fn create_test_state() -> Arc<RwLock<AxonState>> {
+        Arc::new(RwLock::new(AxonState {
+            request_count: 0,
+            total_requests: 0,
+            blacklist: HashSet::new(),
+            ip_blacklist: HashSet::new(),
+            priority_list: HashMap::new(),
+            axon_hotkey: "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY".to_string(),
+            verify_signatures: true,
+            trust_proxy_headers: false,
+            blacklist_fn: None,
+            priority_fn: None,
+            verify_fn: None,
+        }))
+    }
+
+    #[test]
+    fn test_request_priority() {
+        let priority = RequestPriority(0.75);
+        assert_eq!(priority.0, 0.75);
+    }
+
+    #[tokio::test]
+    async fn test_state_counter() {
+        let state = create_test_state();
+
+        // Simulate increment
+        {
+            let mut state_write = state.write().await;
+            state_write.request_count += 1;
+            state_write.total_requests += 1;
+        }
+
+        let state_read = state.read().await;
+        
assert_eq!(state_read.request_count, 1); + assert_eq!(state_read.total_requests, 1); + } + + #[tokio::test] + async fn test_blacklist_check() { + let state = create_test_state(); + + // Add a hotkey to blacklist + { + let mut state_write = state.write().await; + state_write + .blacklist + .insert("blacklisted_hotkey".to_string()); + } + + let state_read = state.read().await; + assert!(state_read.blacklist.contains("blacklisted_hotkey")); + assert!(!state_read.blacklist.contains("allowed_hotkey")); + } + + #[tokio::test] + async fn test_priority_list() { + let state = create_test_state(); + + // Add priorities + { + let mut state_write = state.write().await; + state_write + .priority_list + .insert("high_priority".to_string(), 1.0); + state_write + .priority_list + .insert("low_priority".to_string(), 0.1); + } + + let state_read = state.read().await; + assert_eq!(state_read.priority_list.get("high_priority"), Some(&1.0)); + assert_eq!(state_read.priority_list.get("low_priority"), Some(&0.1)); + assert_eq!(state_read.priority_list.get("unknown"), None); + } +} diff --git a/src/axon/mod.rs b/src/axon/mod.rs new file mode 100644 index 0000000..a09ad2f --- /dev/null +++ b/src/axon/mod.rs @@ -0,0 +1,85 @@ +//! Axon HTTP server module for Bittensor network communication +//! +//! The Axon is an HTTP server that receives requests from Dendrites in the +//! Bittensor network. It handles: +//! +//! - Request signature verification +//! - Blacklist/whitelist enforcement +//! - Priority-based request handling +//! - Custom synapse handlers +//! +//! # Example +//! +//! ```ignore +//! use bittensor_rs::axon::{Axon, AxonConfig}; +//! use bittensor_rs::wallet::Keypair; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! // Create a keypair for the axon +//! let keypair = Keypair::from_uri("//Alice")?; +//! +//! // Configure the axon +//! let config = AxonConfig::new() +//! .with_port(8091) +//! .with_ip("0.0.0.0"); +//! +//! // Create the axon server +//! 
let mut axon = Axon::new(keypair, config); +//! +//! // Attach a handler for a specific synapse type +//! axon.attach("MyQuery", |synapse| async move { +//! // Process the synapse and return the response +//! let mut response = synapse; +//! response.set_field("result", serde_json::json!("Hello from Axon!")); +//! response +//! }); +//! +//! // Set a custom blacklist function +//! axon.set_blacklist(|hotkey, synapse_name| { +//! // Return true to blacklist, false to allow +//! false +//! }); +//! +//! // Set a custom priority function +//! axon.set_priority(|hotkey, synapse_name| { +//! // Return a priority value (higher = more priority) +//! 1.0 +//! }); +//! +//! // Start serving +//! axon.serve().await?; +//! +//! Ok(()) +//! } +//! ``` +//! +//! # Architecture +//! +//! The Axon uses the following middleware stack (in order): +//! +//! 1. **Logging** - Logs all incoming requests +//! 2. **Blacklist** - Rejects blacklisted hotkeys/IPs +//! 3. **Priority** - Assigns priority to requests +//! 4. **Verify** - Verifies request signatures +//! 5. **Timeout** - Enforces request timeouts +//! 6. **Counter** - Tracks request counts +//! +//! Each synapse type has its own route handler registered via `attach()`. 
+ +pub mod handlers; +pub mod info; +pub mod middleware; +pub mod server; + +pub use handlers::{ + build_error_response, build_response_headers, build_success_response, compute_body_hash, + extract_synapse, status_codes, status_messages, verify_request, verify_signature, + HandlerContext, VerifiedRequest, AXON_VERSION, +}; +pub use info::{AxonConfig, AxonInfo}; +pub use middleware::{ + blacklist_middleware, counter_middleware, logging_middleware, priority_middleware, + timeout_middleware, verify_middleware, RequestPriority, +}; +pub use server::{Axon, AxonState, BlacklistFn, PriorityFn, SynapseHandler, VerifyFn}; diff --git a/src/axon/server.rs b/src/axon/server.rs new file mode 100644 index 0000000..b14ba4c --- /dev/null +++ b/src/axon/server.rs @@ -0,0 +1,703 @@ +//! Axon HTTP server implementation +//! +//! The Axon is an HTTP server that receives requests from Dendrites in the +//! Bittensor network. It handles request verification, routing, and response +//! generation. + +use crate::axon::handlers::{ + build_error_response, build_success_response, extract_synapse, status_codes, verify_request, + AXON_VERSION, +}; +use crate::axon::info::{AxonConfig, AxonInfo}; +use crate::axon::middleware::{ + blacklist_middleware, counter_middleware, logging_middleware, priority_middleware, + timeout_middleware, verify_middleware, +}; +use crate::errors::{AxonConfigError, AxonError}; +use crate::types::Synapse; +use crate::wallet::Keypair; +use axum::body::Bytes; + +use axum::http::{HeaderMap, StatusCode}; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; +use axum::{middleware as axum_middleware, Router}; +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::net::{IpAddr, SocketAddr}; +use std::path::Path; +use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::RwLock; +use tower_http::cors::{Any, CorsLayer}; + +use tower_http::trace::TraceLayer; +use tracing::{error, info}; + +/// Type alias for synapse handler 
function
+pub type SynapseHandler = Arc<
+    dyn Fn(Synapse) -> Pin<Box<dyn Future<Output = Synapse> + Send>> + Send + Sync,
+>;
+
+/// Type alias for blacklist check function
+pub type BlacklistFn = Arc<dyn Fn(&str, &str) -> bool + Send + Sync>;
+
+/// Type alias for priority function
+pub type PriorityFn = Arc<dyn Fn(&str, &str) -> f32 + Send + Sync>;
+
+/// Type alias for verify function
+pub type VerifyFn = Arc<dyn Fn(&str) -> bool + Send + Sync>;
+
+/// Axon server state
+pub struct AxonState {
+    /// Number of currently active requests
+    pub request_count: u64,
+    /// Total requests received since startup
+    pub total_requests: u64,
+    /// Set of blacklisted hotkeys
+    pub blacklist: HashSet<String>,
+    /// Set of blacklisted IPs
+    pub ip_blacklist: HashSet<String>,
+    /// Priority mapping: hotkey -> priority (higher = more priority)
+    pub priority_list: HashMap<String, f32>,
+    /// The axon's hotkey SS58 address
+    pub axon_hotkey: String,
+    /// Whether to verify request signatures
+    pub verify_signatures: bool,
+    /// Whether to trust X-Forwarded-For and X-Real-IP headers.
+    /// Only enable when running behind a trusted reverse proxy.
+ pub trust_proxy_headers: bool, + /// Custom blacklist function + pub blacklist_fn: Option, + /// Custom priority function + pub priority_fn: Option, + /// Custom verify function + pub verify_fn: Option, +} + +impl Default for AxonState { + fn default() -> Self { + Self { + request_count: 0, + total_requests: 0, + blacklist: HashSet::new(), + ip_blacklist: HashSet::new(), + priority_list: HashMap::new(), + axon_hotkey: String::new(), + verify_signatures: true, + trust_proxy_headers: false, + blacklist_fn: None, + priority_fn: None, + verify_fn: None, + } + } +} + +/// Axon HTTP server for receiving Bittensor network requests +/// +/// The Axon handles: +/// - Request signature verification +/// - Blacklist/whitelist enforcement +/// - Priority-based request queuing +/// - Custom synapse handlers +/// +/// # Example +/// +/// ```ignore +/// use bittensor_rs::axon::{Axon, AxonConfig}; +/// use bittensor_rs::wallet::Keypair; +/// +/// let keypair = Keypair::from_uri("//Alice").unwrap(); +/// let config = AxonConfig::new().with_port(8091); +/// +/// let mut axon = Axon::new(keypair, config); +/// +/// // Attach a handler for a specific synapse +/// axon.attach("MyQuery", |synapse| async move { +/// // Process the synapse and return response +/// synapse +/// }); +/// +/// // Start serving +/// axon.serve().await?; +/// ``` +pub struct Axon { + /// The keypair for signing responses + keypair: Keypair, + /// Server configuration + config: AxonConfig, + /// Server state (shared across handlers) + state: Arc>, + /// Registered synapse handlers + handlers: HashMap, +} + +impl Axon { + /// Create a new Axon server + /// + /// # Arguments + /// + /// * `keypair` - The hotkey keypair for signing responses + /// * `config` - The server configuration + /// + /// # Returns + /// + /// A new Axon instance + pub fn new(keypair: Keypair, config: AxonConfig) -> Self { + let state = AxonState { + axon_hotkey: keypair.ss58_address().to_string(), + verify_signatures: 
config.verify_signatures, + trust_proxy_headers: config.trust_proxy_headers, + ..Default::default() + }; + + Self { + keypair, + config, + state: Arc::new(RwLock::new(state)), + handlers: HashMap::new(), + } + } + + /// Attach a synapse handler for a specific route + /// + /// # Arguments + /// + /// * `name` - The synapse name (route path) + /// * `handler` - The async handler function + /// + /// # Returns + /// + /// Mutable reference to self for chaining + /// + /// # Example + /// + /// ```ignore + /// axon.attach("Query", |synapse| async move { + /// // Process synapse + /// synapse + /// }); + /// ``` + pub fn attach(&mut self, name: &str, handler: F) -> &mut Self + where + F: Fn(Synapse) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + let handler = Arc::new(move |synapse: Synapse| { + let fut = handler(synapse); + Box::pin(fut) as Pin + Send>> + }); + self.handlers.insert(name.to_string(), handler); + self + } + + /// Set a custom blacklist check function + /// + /// The function receives (hotkey, synapse_name) and returns true if blacklisted. + /// + /// # Arguments + /// + /// * `f` - The blacklist check function + /// + /// # Returns + /// + /// Mutable reference to self for chaining + pub fn set_blacklist(&mut self, f: F) -> &mut Self + where + F: Fn(&str, &str) -> bool + Send + Sync + 'static, + { + // Use try_write first for immediate update without blocking. + // If the lock is held, spawn a task to update when available. + // This is acceptable since set_blacklist is typically called during setup + // before the server starts handling requests. 
+ let blacklist_fn = Arc::new(f); + if let Ok(mut state_write) = self.state.try_write() { + state_write.blacklist_fn = Some(blacklist_fn); + } else { + // Lock is held, spawn a task to update when available + let state = self.state.clone(); + tokio::spawn(async move { + let mut state_write = state.write().await; + state_write.blacklist_fn = Some(blacklist_fn); + }); + } + self + } + + /// Set a custom priority function + /// + /// The function receives (hotkey, synapse_name) and returns a priority value. + /// Higher values indicate higher priority. + /// + /// # Arguments + /// + /// * `f` - The priority function + /// + /// # Returns + /// + /// Mutable reference to self for chaining + pub fn set_priority(&mut self, f: F) -> &mut Self + where + F: Fn(&str, &str) -> f32 + Send + Sync + 'static, + { + // Use try_write first for immediate update without blocking. + // If the lock is held, spawn a task to update when available. + // This is acceptable since set_priority is typically called during setup + // before the server starts handling requests. + let priority_fn = Arc::new(f); + if let Ok(mut state_write) = self.state.try_write() { + state_write.priority_fn = Some(priority_fn); + } else { + // Lock is held, spawn a task to update when available + let state = self.state.clone(); + tokio::spawn(async move { + let mut state_write = state.write().await; + state_write.priority_fn = Some(priority_fn); + }); + } + self + } + + /// Set a custom verification function + /// + /// The function receives the synapse_name and returns true if verification passes. + /// This is called before signature verification. + /// + /// # Arguments + /// + /// * `f` - The verification function + /// + /// # Returns + /// + /// Mutable reference to self for chaining + pub fn set_verify(&mut self, f: F) -> &mut Self + where + F: Fn(&str) -> bool + Send + Sync + 'static, + { + // Use try_write first for immediate update without blocking. 
+ // If the lock is held, spawn a task to update when available. + // This is acceptable since set_verify is typically called during setup + // before the server starts handling requests. + let verify_fn = Arc::new(f); + if let Ok(mut state_write) = self.state.try_write() { + state_write.verify_fn = Some(verify_fn); + } else { + // Lock is held, spawn a task to update when available + let state = self.state.clone(); + tokio::spawn(async move { + let mut state_write = state.write().await; + state_write.verify_fn = Some(verify_fn); + }); + } + self + } + + /// Add a hotkey to the blacklist + /// + /// # Arguments + /// + /// * `hotkey` - The hotkey SS58 address to blacklist + pub async fn blacklist_hotkey(&self, hotkey: impl Into) { + let mut state_write = self.state.write().await; + state_write.blacklist.insert(hotkey.into()); + } + + /// Remove a hotkey from the blacklist + /// + /// # Arguments + /// + /// * `hotkey` - The hotkey SS58 address to remove + pub async fn unblacklist_hotkey(&self, hotkey: &str) { + let mut state_write = self.state.write().await; + state_write.blacklist.remove(hotkey); + } + + /// Add an IP address to the blacklist + /// + /// # Arguments + /// + /// * `ip` - The IP address to blacklist + pub async fn blacklist_ip(&self, ip: impl Into) { + let mut state_write = self.state.write().await; + state_write.ip_blacklist.insert(ip.into()); + } + + /// Remove an IP address from the blacklist + /// + /// # Arguments + /// + /// * `ip` - The IP address to remove + pub async fn unblacklist_ip(&self, ip: &str) { + let mut state_write = self.state.write().await; + state_write.ip_blacklist.remove(ip); + } + + /// Set the priority for a hotkey + /// + /// # Arguments + /// + /// * `hotkey` - The hotkey SS58 address + /// * `priority` - The priority value (higher = more priority) + pub async fn set_hotkey_priority(&self, hotkey: impl Into, priority: f32) { + let mut state_write = self.state.write().await; + 
state_write.priority_list.insert(hotkey.into(), priority); + } + + /// Get the current request count + pub async fn request_count(&self) -> u64 { + self.state.read().await.request_count + } + + /// Get the total requests received + pub async fn total_requests(&self) -> u64 { + self.state.read().await.total_requests + } + + /// Get the axon's hotkey SS58 address + pub fn hotkey(&self) -> &str { + self.keypair.ss58_address() + } + + /// Get the configuration + pub fn config(&self) -> &AxonConfig { + &self.config + } + + /// Build the axum Router with all handlers and middleware + fn build_router(&self) -> Router<()> { + let state = self.state.clone(); + let handlers = self.handlers.clone(); + let keypair = self.keypair.clone(); + + // Create base router with state + let mut router: Router>> = Router::new(); + + // Health check endpoint + router = router.route("/health", get(health_handler)); + + // Add synapse handlers + for (name, handler) in handlers.iter() { + let handler = handler.clone(); + let keypair = keypair.clone(); + let state_clone = state.clone(); + + let route_handler = move |headers: HeaderMap, body: Bytes| { + let handler = handler.clone(); + let keypair = keypair.clone(); + let state = state_clone.clone(); + + async move { + handle_synapse_request(state, keypair, headers, body, handler).await + } + }; + + router = router.route(&format!("/{}", name), post(route_handler)); + } + + // Add middleware layers + router + .layer(axum_middleware::from_fn_with_state(state.clone(), counter_middleware)) + .layer(axum_middleware::from_fn_with_state(state.clone(), timeout_middleware)) + .layer(axum_middleware::from_fn_with_state(state.clone(), verify_middleware)) + .layer(axum_middleware::from_fn_with_state(state.clone(), priority_middleware)) + .layer(axum_middleware::from_fn_with_state(state.clone(), blacklist_middleware)) + .layer(axum_middleware::from_fn(logging_middleware)) + .layer(TraceLayer::new_for_http()) + .layer( + CorsLayer::new() + .allow_origin(Any) 
+ .allow_methods(Any) + .allow_headers(Any), + ) + .with_state(state) + } + + /// Start the HTTP server + /// + /// # Returns + /// + /// Ok(()) on successful shutdown, or an error + pub async fn serve(self) -> Result<(), AxonError> { + let addr: SocketAddr = self + .config + .socket_addr() + .parse() + .map_err(|e| AxonError::new(format!("Invalid socket address: {}", e)))?; + + let router = self.build_router(); + + info!( + "Axon server starting on {} (hotkey: {})", + addr, + self.keypair.ss58_address() + ); + + let listener = tokio::net::TcpListener::bind(addr) + .await + .map_err(|e| AxonError::new(format!("Failed to bind to {}: {}", addr, e)))?; + + axum::serve(listener, router) + .await + .map_err(|e| AxonError::new(format!("Server error: {}", e)))?; + + Ok(()) + } + + /// Start the HTTP server with TLS + /// + /// # Arguments + /// + /// * `cert_path` - Path to the TLS certificate file + /// * `key_path` - Path to the TLS private key file + /// + /// # Returns + /// + /// Ok(()) on successful shutdown, or an error + pub async fn serve_tls(self, cert_path: &Path, key_path: &Path) -> Result<(), AxonError> { + use axum_server::tls_rustls::RustlsConfig; + + let addr: SocketAddr = self + .config + .socket_addr() + .parse() + .map_err(|e| AxonError::new(format!("Invalid socket address: {}", e)))?; + + let router = self.build_router(); + + // Load TLS configuration + let tls_config = RustlsConfig::from_pem_file(cert_path, key_path) + .await + .map_err(|e| AxonError::new(format!("Failed to load TLS config: {}", e)))?; + + info!( + "Axon server starting on {} with TLS (hotkey: {})", + addr, + self.keypair.ss58_address() + ); + + axum_server::bind_rustls(addr, tls_config) + .serve(router.into_make_service()) + .await + .map_err(|e| AxonError::new(format!("Server error: {}", e)))?; + + Ok(()) + } + + /// Get the AxonInfo for chain registration + /// + /// This returns the information needed to register the axon on-chain. 
+ /// + /// # Arguments + /// + /// * `block` - The current block number + /// + /// # Returns + /// + /// AxonInfo ready for chain registration + pub fn info(&self, block: u64) -> Result { + let external_ip = self.config.get_external_ip(); + let external_port = self.config.get_external_port(); + + let ip: IpAddr = external_ip.parse().map_err(|e| { + AxonConfigError::with_field( + format!("Invalid IP address '{}': {}", external_ip, e), + "external_ip", + ) + })?; + + let ip_type = match ip { + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 6, + }; + + Ok(AxonInfo { + block, + version: AXON_VERSION as u32, + ip, + port: external_port, + ip_type, + protocol: 4, // TCP + placeholder1: 0, + placeholder2: 0, + }) + } +} + +/// Health check handler +async fn health_handler() -> impl IntoResponse { + (StatusCode::OK, "OK") +} + +/// Handle a synapse request +async fn handle_synapse_request( + state: Arc>, + keypair: Keypair, + headers: HeaderMap, + body: Bytes, + handler: SynapseHandler, +) -> Response { + let start_time = std::time::Instant::now(); + let hotkey = keypair.ss58_address().to_string(); + + // Verify the request signature if enabled + { + let state_read = state.read().await; + if state_read.verify_signatures { + match verify_request(&headers, &body, &hotkey) { + Ok(_verified) => {} + Err(e) => { + let process_time = start_time.elapsed().as_secs_f64(); + return build_error_response( + &hotkey, + StatusCode::UNAUTHORIZED, + status_codes::UNAUTHORIZED, + &e.message, + process_time, + ); + } + } + } + } + + // Extract the synapse from the request + let synapse = match extract_synapse(&headers, &body) { + Ok(s) => s, + Err(e) => { + let process_time = start_time.elapsed().as_secs_f64(); + return build_error_response( + &hotkey, + StatusCode::BAD_REQUEST, + status_codes::INTERNAL_ERROR, + &e.message, + process_time, + ); + } + }; + + // Call the handler + let response_synapse = handler(synapse).await; + + // Serialize the response + let response_body = match 
serde_json::to_vec(&response_synapse.extra) { + Ok(b) => Bytes::from(b), + Err(e) => { + let process_time = start_time.elapsed().as_secs_f64(); + error!("Failed to serialize response: {}", e); + return build_error_response( + &hotkey, + StatusCode::INTERNAL_SERVER_ERROR, + status_codes::INTERNAL_ERROR, + "Failed to serialize response", + process_time, + ); + } + }; + + let process_time = start_time.elapsed().as_secs_f64(); + build_success_response(&hotkey, response_body, process_time) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_keypair() -> Keypair { + Keypair::from_uri("//Alice").expect("Failed to create test keypair") + } + + #[test] + fn test_axon_new() { + let keypair = create_test_keypair(); + let config = AxonConfig::default(); + let axon = Axon::new(keypair, config); + + assert!(!axon.hotkey().is_empty()); + assert_eq!(axon.config().port, 8091); + } + + #[test] + fn test_axon_attach_handler() { + let keypair = create_test_keypair(); + let config = AxonConfig::default(); + let mut axon = Axon::new(keypair, config); + + axon.attach("TestQuery", |synapse| async move { synapse }); + + assert!(axon.handlers.contains_key("TestQuery")); + } + + #[test] + fn test_axon_info() { + let keypair = create_test_keypair(); + let config = AxonConfig::new() + .with_ip("192.168.1.1") + .with_port(9000) + .with_external_ip("1.2.3.4") + .with_external_port(9001); + let axon = Axon::new(keypair, config); + + let info = axon.info(1000).unwrap(); + + assert_eq!(info.block, 1000); + assert_eq!(info.port, 9001); + assert_eq!(info.ip_type, 4); // IPv4 + } + + #[tokio::test] + async fn test_axon_blacklist() { + let keypair = create_test_keypair(); + let config = AxonConfig::default(); + let axon = Axon::new(keypair, config); + + axon.blacklist_hotkey("test_hotkey").await; + + let state = axon.state.read().await; + assert!(state.blacklist.contains("test_hotkey")); + } + + #[tokio::test] + async fn test_axon_priority() { + let keypair = create_test_keypair(); + let 
config = AxonConfig::default(); + let axon = Axon::new(keypair, config); + + axon.set_hotkey_priority("high_priority", 1.0).await; + axon.set_hotkey_priority("low_priority", 0.1).await; + + let state = axon.state.read().await; + assert_eq!(state.priority_list.get("high_priority"), Some(&1.0)); + assert_eq!(state.priority_list.get("low_priority"), Some(&0.1)); + } + + #[tokio::test] + async fn test_axon_request_counters() { + let keypair = create_test_keypair(); + let config = AxonConfig::default(); + let axon = Axon::new(keypair, config); + + assert_eq!(axon.request_count().await, 0); + assert_eq!(axon.total_requests().await, 0); + + // Simulate some requests + { + let mut state = axon.state.write().await; + state.request_count = 5; + state.total_requests = 100; + } + + assert_eq!(axon.request_count().await, 5); + assert_eq!(axon.total_requests().await, 100); + } + + #[test] + fn test_axon_state_default() { + let state = AxonState::default(); + + assert_eq!(state.request_count, 0); + assert_eq!(state.total_requests, 0); + assert!(state.blacklist.is_empty()); + assert!(state.priority_list.is_empty()); + assert!(state.verify_signatures); + assert!(!state.trust_proxy_headers); + } +} diff --git a/src/bin/btcli.rs b/src/bin/btcli.rs new file mode 100644 index 0000000..5bb3ae8 --- /dev/null +++ b/src/bin/btcli.rs @@ -0,0 +1,21 @@ +//! Bittensor CLI binary entrypoint. +//! +//! This is the main entry point for the btcli command-line tool, +//! providing wallet, stake, subnet, root, and weight management commands. 
+ +use bittensor_rs::cli; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize basic logging for CLI + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::WARN.into()), + ) + .with_target(false) + .init(); + + // Run the CLI + cli::run().await +} diff --git a/src/blocks/listener.rs b/src/blocks/listener.rs index ef08866..a5b4ccc 100644 --- a/src/blocks/listener.rs +++ b/src/blocks/listener.rs @@ -6,7 +6,7 @@ //! - Phase changes (evaluation -> commit -> reveal) use crate::blocks::epoch_tracker::{EpochInfo, EpochPhase, EpochTracker, EpochTransition}; -use crate::chain::{BittensorClient, Error as ChainError}; +use crate::chain::BittensorClient; use futures::StreamExt; use std::sync::Arc; use tokio::sync::{broadcast, RwLock}; diff --git a/src/cli/commands/mod.rs b/src/cli/commands/mod.rs new file mode 100644 index 0000000..6113329 --- /dev/null +++ b/src/cli/commands/mod.rs @@ -0,0 +1,10 @@ +//! CLI command implementations +//! +//! Each module contains the command definitions and execution logic +//! for a specific category of operations. + +pub mod root; +pub mod stake; +pub mod subnet; +pub mod wallet; +pub mod weights; diff --git a/src/cli/commands/root.rs b/src/cli/commands/root.rs new file mode 100644 index 0000000..6b04db7 --- /dev/null +++ b/src/cli/commands/root.rs @@ -0,0 +1,446 @@ +//! Root network commands for managing the root subnet (netuid 0). 
+ +use crate::cli::utils::{ + confirm, create_table_with_headers, format_address, format_tao, keypair_to_signer, + parse_f64_list, parse_u16_list, print_error, print_info, print_success, print_warning, + prompt_password_optional, resolve_endpoint, spinner, +}; +use crate::cli::Cli; +use crate::wallet::Wallet; +use clap::{Args, Subcommand}; + +/// Root network command container +#[derive(Args, Clone)] +pub struct RootCommand { + #[command(subcommand)] + pub command: RootCommands, +} + +/// Available root network operations +#[derive(Subcommand, Clone)] +pub enum RootCommands { + /// Register on the root network (netuid 0) + Register { + /// Wallet name + #[arg(short, long)] + wallet: String, + }, + + /// List all root network validators + List, + + /// Set root network weights + SetWeights { + /// Wallet name + #[arg(short, long)] + wallet: String, + /// Hotkey name + #[arg(short = 'k', long, default_value = "default")] + hotkey: String, + /// Network UIDs (comma-separated, e.g., "1,2,3") + #[arg(long)] + netuids: String, + /// Weights (comma-separated, e.g., "0.3,0.5,0.2") + #[arg(long)] + weights: String, + }, + + /// Get root network weights for a validator + GetWeights { + /// Hotkey address (SS58 format) + #[arg(long)] + hotkey: String, + }, + + /// Show root network information + Info, + + /// Show root network delegates + Delegates, +} + +/// Execute root network commands +pub async fn execute(cmd: RootCommand, cli: &Cli) -> anyhow::Result<()> { + match cmd.command { + RootCommands::Register { wallet } => register(&wallet, cli).await, + RootCommands::List => list_root_validators(cli).await, + RootCommands::SetWeights { + wallet, + hotkey, + netuids, + weights, + } => set_weights(&wallet, &hotkey, &netuids, &weights, cli).await, + RootCommands::GetWeights { hotkey } => get_weights(&hotkey, cli).await, + RootCommands::Info => show_info(cli).await, + RootCommands::Delegates => show_delegates(cli).await, + } +} + +/// Register on the root network +async fn 
register(wallet_name: &str, cli: &Cli) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::validator::root::root_register; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let wallet = match Wallet::new(wallet_name, "default", None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let coldkey = wallet + .coldkey_keypair(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + let signer = keypair_to_signer(&coldkey); + + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + let hotkey = wallet + .hotkey_keypair(hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?; + + print_info("Registering on root network (subnet 0)"); + print_info(&format!("Coldkey: {}", coldkey.ss58_address())); + print_info(&format!("Hotkey: {}", hotkey.ss58_address())); + + if !confirm("Proceed with root registration?", cli.no_prompt) { + print_info("Registration cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Submitting root registration..."); + let result = root_register(&client, &signer, ExtrinsicWait::Finalized).await; + sp.finish_and_clear(); + + match result { + Ok(tx_hash) => { + print_success("Root registration successful!"); + print_info(&format!("Transaction hash: {}", tx_hash)); + } + Err(e) => { + 
print_error(&format!("Root registration failed: {}", e)); + return Err(anyhow::anyhow!("Registration failed: {}", e)); + } + } + + Ok(()) +} + +/// List all root network validators +async fn list_root_validators(cli: &Cli) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + use crate::metagraph::sync_metagraph; + + const ROOT_NETUID: u16 = 0; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Fetching root network metagraph..."); + let metagraph = sync_metagraph(&client, ROOT_NETUID) + .await + .map_err(|e| anyhow::anyhow!("Failed to sync root metagraph: {}", e))?; + sp.finish_and_clear(); + + println!("\nRoot Network Validators (Subnet 0)"); + println!("═══════════════════════════════════════════════════════════════"); + + let mut table = create_table_with_headers(&[ + "UID", + "Hotkey", + "Coldkey", + "Stake", + "Trust", + "Consensus", + "Incentive", + ]); + + let n = metagraph.n as usize; + for uid in 0..n as u64 { + if let Some(neuron) = metagraph.neurons.get(&uid) { + table.add_row(vec![ + uid.to_string(), + format_address(&neuron.hotkey.to_string()), + format_address(&neuron.coldkey.to_string()), + format_tao(neuron.total_stake), + format!("{:.4}", neuron.trust), + format!("{:.4}", neuron.consensus), + format!("{:.4}", neuron.incentive), + ]); + } + } + + println!("{table}"); + println!("\nTotal root validators: {}", n); + + Ok(()) +} + +/// Set root network weights +async fn set_weights( + wallet_name: &str, + hotkey_name: &str, + netuids_str: &str, + weights_str: &str, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::validator::root::root_set_weights; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + // Parse 
netuids
+    let netuids = parse_u16_list(netuids_str)?;
+
+    // Parse weights
+    let weight_values = parse_f64_list(weights_str)?;
+
+    if netuids.len() != weight_values.len() {
+        print_error("Number of netuids must match number of weights");
+        return Err(anyhow::anyhow!("Mismatched netuids and weights"));
+    }
+
+    // Normalize weights to f32 for the API
+    let sum: f64 = weight_values.iter().sum();
+    if sum <= 0.0 {
+        print_error("Weights must sum to a positive value");
+        return Err(anyhow::anyhow!("Invalid weights"));
+    }
+
+    let normalized: Vec<f32> = weight_values
+        .iter()
+        .map(|w| (*w / sum) as f32)
+        .collect();
+
+    let wallet = match Wallet::new(wallet_name, hotkey_name, None) {
+        Ok(w) => w,
+        Err(e) => {
+            print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e));
+            return Err(anyhow::anyhow!("Invalid wallet name: {}", e));
+        }
+    };
+    if !wallet.coldkey_exists() {
+        print_error(&format!("Wallet '{}' not found", wallet_name));
+        return Err(anyhow::anyhow!("Wallet not found"));
+    }
+
+    let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)");
+    let coldkey = wallet
+        .coldkey_keypair(coldkey_password.as_deref())
+        .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?;
+    let signer = keypair_to_signer(&coldkey);
+
+    let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)");
+    let _hotkey = wallet
+        .hotkey_keypair(hotkey_password.as_deref())
+        .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?;
+
+    print_info("Setting root network weights");
+    print_info(&format!("Coldkey: {}", coldkey.ss58_address()));
+    print_info(&format!("Netuids: {:?}", netuids));
+    print_info(&format!("Weights (normalized): {:?}", normalized));
+
+    if !confirm("Proceed with setting weights?", cli.no_prompt) {
+        print_info("Weights setting cancelled");
+        return Ok(());
+    }
+
+    let sp = spinner(&format!("Connecting to {}...", endpoint));
+    let client = BittensorClient::new(&endpoint)
+        .await
+        
.map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
+    sp.finish_and_clear();
+
+    // Convert netuids to u64 for the API
+    let netuids_u64: Vec<u64> = netuids.iter().map(|n| *n as u64).collect();
+
+    let sp = spinner("Submitting root weights...");
+    let result = root_set_weights(&client, &signer, &netuids_u64, &normalized, Some(0), ExtrinsicWait::Finalized).await;
+    sp.finish_and_clear();
+
+    match result {
+        Ok(tx_hash) => {
+            print_success("Root weights set successfully!");
+            print_info(&format!("Transaction hash: {}", tx_hash));
+        }
+        Err(e) => {
+            print_error(&format!("Failed to set weights: {}", e));
+            return Err(anyhow::anyhow!("Set weights failed: {}", e));
+        }
+    }
+
+    Ok(())
+}
+
+/// Get root network weights for a validator
+async fn get_weights(hotkey_addr: &str, cli: &Cli) -> anyhow::Result<()> {
+    use crate::chain::BittensorClient;
+    use crate::metagraph::sync_metagraph;
+
+    const ROOT_NETUID: u16 = 0;
+
+    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());
+
+    let sp = spinner(&format!("Connecting to {}...", endpoint));
+    let client = BittensorClient::new(&endpoint)
+        .await
+        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
+    sp.finish_and_clear();
+
+    let sp = spinner("Fetching root metagraph...");
+    let metagraph = sync_metagraph(&client, ROOT_NETUID)
+        .await
+        .map_err(|e| anyhow::anyhow!("Failed to fetch metagraph: {}", e))?;
+    sp.finish_and_clear();
+
+    println!("\nRoot Network Weights for {}", format_address(hotkey_addr));
+    println!("═══════════════════════════════════════════════════");
+
+    // Find the UID for this hotkey
+    let mut found_uid: Option<u64> = None;
+    for (uid, neuron) in &metagraph.neurons {
+        if neuron.hotkey.to_string() == hotkey_addr {
+            found_uid = Some(*uid);
+            break;
+        }
+    }
+
+    match found_uid {
+        Some(uid) => {
+            print_info(&format!("Hotkey found at UID {}", uid));
+            // Display root network neuron info
+            let mut table = create_table_with_headers(&["UID", "Incentive", "Consensus"]);
+            for uid in 
0..metagraph.n { + if let Some(neuron) = metagraph.neurons.get(&uid) { + table.add_row(vec![ + uid.to_string(), + format!("{:.4}", neuron.incentive), + format!("{:.4}", neuron.consensus), + ]); + } + } + println!("{table}"); + } + None => { + print_warning(&format!("Hotkey {} not found in root network", hotkey_addr)); + } + } + + Ok(()) +} + +/// Show root network information +async fn show_info(cli: &Cli) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + use crate::queries::subnets::{subnet_info, tempo, difficulty, immunity_period}; + + const ROOT_NETUID: u16 = 0; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Fetching root network info..."); + let info = subnet_info(&client, ROOT_NETUID) + .await + .map_err(|e| anyhow::anyhow!("Failed to fetch root info: {}", e))?; + + let tempo_val = tempo(&client, ROOT_NETUID).await.ok().flatten().unwrap_or(0); + let diff_val = difficulty(&client, ROOT_NETUID).await.ok().flatten().unwrap_or(0); + let immunity_val = immunity_period(&client, ROOT_NETUID).await.ok().flatten().unwrap_or(0); + sp.finish_and_clear(); + + match info { + Some(info) => { + println!("\nRoot Network (Subnet 0)"); + println!("═══════════════════════════════════════════════"); + println!( + "Name: {}", + info.name.unwrap_or_else(|| "Root Network".to_string()) + ); + println!("Validators: {}", info.neuron_count); + println!("Tempo: {} blocks", tempo_val); + println!("Difficulty: {}", diff_val); + println!("Immunity Period: {} blocks", immunity_val); + println!("Total Stake: {}", format_tao(info.total_stake)); + } + None => { + print_warning("Could not fetch root network info"); + } + } + + Ok(()) +} + +/// Show root network delegates +async fn show_delegates(cli: &Cli) -> anyhow::Result<()> 
{ + use crate::chain::BittensorClient; + use crate::queries::delegates::get_delegates; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Fetching delegates..."); + let delegates = get_delegates(&client) + .await + .map_err(|e| anyhow::anyhow!("Failed to fetch delegates: {}", e))?; + sp.finish_and_clear(); + + if delegates.is_empty() { + print_info("No delegates found"); + return Ok(()); + } + + println!("\nRoot Network Delegates"); + println!("═══════════════════════════════════════════════════════════════"); + + let mut table = create_table_with_headers(&[ + "Hotkey", + "Total Stake", + "Take", + "Owner", + ]); + + for delegate in &delegates { + // Calculate total stake across all subnets + let total_stake: u128 = delegate.total_stake.values().sum(); + table.add_row(vec![ + format_address(&delegate.base.hotkey_ss58.to_string()), + format_tao(total_stake), + format!("{:.2}%", delegate.base.take * 100.0), + format_address(&delegate.base.owner_ss58.to_string()), + ]); + } + + println!("{table}"); + println!("\nTotal delegates: {}", delegates.len()); + + Ok(()) +} diff --git a/src/cli/commands/stake.rs b/src/cli/commands/stake.rs new file mode 100644 index 0000000..b7380e6 --- /dev/null +++ b/src/cli/commands/stake.rs @@ -0,0 +1,607 @@ +//! Stake commands for managing TAO delegation. 
+
+use crate::cli::utils::{
+    confirm, create_table_with_headers, format_address, format_tao, keypair_to_signer,
+    print_error, print_info, print_success, print_warning, prompt_password_optional,
+    resolve_endpoint, spinner, tao_to_rao,
+};
+use crate::cli::Cli;
+use crate::wallet::Wallet;
+use clap::{Args, Subcommand};
+
+/// Stake command container
+#[derive(Args, Clone)]
+pub struct StakeCommand {
+    #[command(subcommand)]
+    pub command: StakeCommands,
+}
+
+/// Available stake operations
+#[derive(Subcommand, Clone)]
+pub enum StakeCommands {
+    /// Add stake to a hotkey on a subnet
+    Add {
+        /// Wallet name
+        #[arg(short, long)]
+        wallet: String,
+        /// Hotkey name
+        #[arg(short = 'k', long)]
+        hotkey: String,
+        /// Subnet ID
+        #[arg(short, long)]
+        netuid: u16,
+        /// Amount in TAO to stake
+        #[arg(short, long)]
+        amount: f64,
+    },
+
+    /// Remove stake from a hotkey on a subnet
+    Remove {
+        /// Wallet name
+        #[arg(short, long)]
+        wallet: String,
+        /// Hotkey name
+        #[arg(short = 'k', long)]
+        hotkey: String,
+        /// Subnet ID
+        #[arg(short, long)]
+        netuid: u16,
+        /// Amount in TAO to unstake
+        #[arg(short, long)]
+        amount: f64,
+    },
+
+    /// Show stake information
+    Show {
+        /// Wallet name (shows all if not specified)
+        #[arg(short, long)]
+        wallet: Option<String>,
+        /// Show all wallets
+        #[arg(long)]
+        all: bool,
+    },
+
+    /// Move stake between hotkeys or subnets
+    Move {
+        /// Wallet name
+        #[arg(short, long)]
+        wallet: String,
+        /// Source hotkey name
+        #[arg(long)]
+        from_hotkey: String,
+        /// Destination hotkey name
+        #[arg(long)]
+        to_hotkey: String,
+        /// Source subnet ID
+        #[arg(long)]
+        origin_netuid: u16,
+        /// Destination subnet ID
+        #[arg(long)]
+        dest_netuid: u16,
+        /// Amount in TAO to move
+        #[arg(short, long)]
+        amount: f64,
+    },
+
+    /// List all stake for a coldkey
+    List {
+        /// Wallet name
+        #[arg(short, long)]
+        wallet: String,
+    },
+}
+
+/// Execute stake commands
+pub async fn execute(cmd: StakeCommand, cli: &Cli) -> anyhow::Result<()> {
+    match 
cmd.command { + StakeCommands::Add { + wallet, + hotkey, + netuid, + amount, + } => add_stake(&wallet, &hotkey, netuid, amount, cli).await, + StakeCommands::Remove { + wallet, + hotkey, + netuid, + amount, + } => remove_stake(&wallet, &hotkey, netuid, amount, cli).await, + StakeCommands::Show { wallet, all } => show_stake(wallet.as_deref(), all, cli).await, + StakeCommands::Move { + wallet, + from_hotkey, + to_hotkey, + origin_netuid, + dest_netuid, + amount, + } => { + move_stake( + &wallet, + &from_hotkey, + &to_hotkey, + origin_netuid, + dest_netuid, + amount, + cli, + ) + .await + } + StakeCommands::List { wallet } => list_stake(&wallet, cli).await, + } +} + +/// Add stake to a hotkey +async fn add_stake( + wallet_name: &str, + hotkey_name: &str, + netuid: u16, + amount: f64, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::validator::staking::add_stake as stake_add; + use sp_core::crypto::AccountId32; + use std::str::FromStr; + + if amount <= 0.0 { + print_error("Amount must be positive"); + return Err(anyhow::anyhow!("Invalid amount")); + } + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let wallet = match Wallet::new(wallet_name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let coldkey = wallet + .coldkey_keypair(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + let signer = keypair_to_signer(&coldkey); + + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + let hotkey = wallet + 
.hotkey_keypair(hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?; + let hotkey_account = AccountId32::from_str(hotkey.ss58_address()) + .map_err(|e| anyhow::anyhow!("Invalid hotkey address: {:?}", e))?; + + let rao_amount = tao_to_rao(amount); + + print_info(&format!( + "Adding stake: {} TAO ({} RAO)", + amount, rao_amount + )); + print_info(&format!("Coldkey: {}", coldkey.ss58_address())); + print_info(&format!("Hotkey: {}", hotkey.ss58_address())); + print_info(&format!("Subnet: {}", netuid)); + + if !confirm("Proceed with staking?", cli.no_prompt) { + print_info("Staking cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Submitting stake transaction..."); + let result = stake_add( + &client, + &signer, + &hotkey_account, + netuid, + rao_amount, + ExtrinsicWait::Finalized, + ) + .await; + sp.finish_and_clear(); + + match result { + Ok(tx_hash) => { + print_success("Stake added successfully!"); + print_info(&format!("Transaction hash: {}", tx_hash)); + } + Err(e) => { + print_error(&format!("Failed to add stake: {}", e)); + return Err(anyhow::anyhow!("Staking failed: {}", e)); + } + } + + Ok(()) +} + +/// Remove stake from a hotkey +async fn remove_stake( + wallet_name: &str, + hotkey_name: &str, + netuid: u16, + amount: f64, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::validator::staking::unstake; + use sp_core::crypto::AccountId32; + use std::str::FromStr; + + if amount <= 0.0 { + print_error("Amount must be positive"); + return Err(anyhow::anyhow!("Invalid amount")); + } + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let wallet = match Wallet::new(wallet_name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + 
print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let coldkey = wallet + .coldkey_keypair(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + let signer = keypair_to_signer(&coldkey); + + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + let hotkey = wallet + .hotkey_keypair(hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?; + let hotkey_account = AccountId32::from_str(hotkey.ss58_address()) + .map_err(|e| anyhow::anyhow!("Invalid hotkey address: {:?}", e))?; + + let rao_amount = tao_to_rao(amount); + + print_info(&format!( + "Removing stake: {} TAO ({} RAO)", + amount, rao_amount + )); + print_info(&format!("Coldkey: {}", coldkey.ss58_address())); + print_info(&format!("Hotkey: {}", hotkey.ss58_address())); + print_info(&format!("Subnet: {}", netuid)); + + if !confirm("Proceed with unstaking?", cli.no_prompt) { + print_info("Unstaking cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Submitting unstake transaction..."); + let result = unstake( + &client, + &signer, + &hotkey_account, + netuid, + rao_amount, + ExtrinsicWait::Finalized, + ) + .await; + sp.finish_and_clear(); + + match result { + Ok(tx_hash) => { + print_success("Stake removed successfully!"); + print_info(&format!("Transaction hash: {}", tx_hash)); + } + Err(e) => { + print_error(&format!("Failed to remove stake: {}", e)); + 
return Err(anyhow::anyhow!("Unstaking failed: {}", e));
+        }
+    }
+
+    Ok(())
+}
+
+/// Show stake information for wallets
+async fn show_stake(wallet_name: Option<&str>, all: bool, cli: &Cli) -> anyhow::Result<()> {
+    use crate::chain::BittensorClient;
+    use crate::queries::stakes::get_stake_info_for_coldkey;
+    use crate::wallet::list_wallets;
+    use sp_core::crypto::AccountId32;
+    use std::str::FromStr;
+
+    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());
+
+    let sp = spinner(&format!("Connecting to {}...", endpoint));
+    let client = BittensorClient::new(&endpoint)
+        .await
+        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
+    sp.finish_and_clear();
+
+    let wallets: Vec<Wallet> = if let Some(name) = wallet_name {
+        match Wallet::new(name, "default", None) {
+            Ok(w) => vec![w],
+            Err(e) => {
+                print_error(&format!("Invalid wallet name '{}': {}", name, e));
+                return Err(anyhow::anyhow!("Invalid wallet name: {}", e));
+            }
+        }
+    } else if all {
+        let names = list_wallets().map_err(|e| anyhow::anyhow!("Failed to list wallets: {}", e))?;
+        names.iter()
+            .filter_map(|n| Wallet::new(n, "default", None).ok())
+            .collect()
+    } else {
+        match Wallet::new("default", "default", None) {
+            Ok(w) => vec![w],
+            Err(e) => {
+                print_error(&format!("Invalid wallet name: {}", e));
+                return Err(anyhow::anyhow!("Invalid wallet name: {}", e));
+            }
+        }
+    };
+
+    if wallets.is_empty() {
+        print_info("No wallets found");
+        return Ok(());
+    }
+
+    for wallet in &wallets {
+        let coldkey_password = prompt_password_optional(&format!(
+            "Password for '{}' (enter to skip)",
+            &wallet.name
+        ));
+
+        let coldkey_addr = match wallet.coldkey_ss58(coldkey_password.as_deref()) {
+            Ok(addr) => addr,
+            Err(e) => {
+                print_warning(&format!("Could not unlock '{}': {}", &wallet.name, e));
+                continue;
+            }
+        };
+
+        let coldkey_account = match AccountId32::from_str(&coldkey_addr) {
+            Ok(acc) => acc,
+            Err(e) => {
+                print_warning(&format!("Invalid coldkey address: {:?}", e));
+                continue;
+            
} + }; + + let sp = spinner(&format!( + "Fetching stake info for {}...", + format_address(&coldkey_addr) + )); + let stake_result = get_stake_info_for_coldkey(&client, &coldkey_account).await; + sp.finish_and_clear(); + + match stake_result { + Ok(stakes) => { + println!("\nWallet: {} ({})", &wallet.name, format_address(&coldkey_addr)); + + if stakes.is_empty() { + print_info("No stake found"); + continue; + } + + let mut table = + create_table_with_headers(&["Hotkey", "Subnet", "Stake (TAO)"]); + + for stake_info in stakes { + table.add_row(vec![ + format_address(&stake_info.hotkey.to_string()), + stake_info.netuid.to_string(), + format_tao(stake_info.stake), + ]); + } + + println!("{table}"); + } + Err(e) => { + print_warning(&format!( + "Failed to fetch stake for {}: {}", + &wallet.name, + e + )); + } + } + } + + Ok(()) +} + +/// Move stake between hotkeys or subnets +async fn move_stake( + wallet_name: &str, + from_hotkey: &str, + to_hotkey: &str, + origin_netuid: u16, + dest_netuid: u16, + amount: f64, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::validator::staking::move_stake as stake_move; + use sp_core::crypto::AccountId32; + use std::str::FromStr; + + if amount <= 0.0 { + print_error("Amount must be positive"); + return Err(anyhow::anyhow!("Invalid amount")); + } + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let from_wallet = match Wallet::new(wallet_name, from_hotkey, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + let to_wallet = match Wallet::new(wallet_name, to_hotkey, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + + if !from_wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not 
found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let coldkey = from_wallet + .coldkey_keypair(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + let signer = keypair_to_signer(&coldkey); + + let from_hotkey_password = prompt_password_optional("Source hotkey password (enter if unencrypted)"); + let from_hk = from_wallet + .hotkey_keypair(from_hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock source hotkey: {}", e))?; + let from_hk_account = AccountId32::from_str(from_hk.ss58_address()) + .map_err(|e| anyhow::anyhow!("Invalid source hotkey address: {:?}", e))?; + + let to_hotkey_password = prompt_password_optional("Destination hotkey password (enter if unencrypted)"); + let to_hk = to_wallet + .hotkey_keypair(to_hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock destination hotkey: {}", e))?; + let to_hk_account = AccountId32::from_str(to_hk.ss58_address()) + .map_err(|e| anyhow::anyhow!("Invalid destination hotkey address: {:?}", e))?; + + let rao_amount = tao_to_rao(amount); + + print_info(&format!("Moving stake: {} TAO ({} RAO)", amount, rao_amount)); + print_info(&format!( + "From: {} (subnet {})", + from_hk.ss58_address(), + origin_netuid + )); + print_info(&format!( + "To: {} (subnet {})", + to_hk.ss58_address(), + dest_netuid + )); + + if !confirm("Proceed with stake move?", cli.no_prompt) { + print_info("Move cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Submitting move stake transaction..."); + let result = stake_move( + &client, + &signer, + &from_hk_account, + &to_hk_account, + origin_netuid, + dest_netuid, + 
rao_amount, + ExtrinsicWait::Finalized, + ) + .await; + sp.finish_and_clear(); + + match result { + Ok(tx_hash) => { + print_success("Stake moved successfully!"); + print_info(&format!("Transaction hash: {}", tx_hash)); + } + Err(e) => { + print_error(&format!("Failed to move stake: {}", e)); + return Err(anyhow::anyhow!("Move stake failed: {}", e)); + } + } + + Ok(()) +} + +/// List all stakes for a coldkey +async fn list_stake(wallet_name: &str, cli: &Cli) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + use crate::queries::stakes::get_stake_info_for_coldkey; + use sp_core::crypto::AccountId32; + use std::str::FromStr; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let wallet = match Wallet::new(wallet_name, "default", None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let coldkey_addr = wallet + .coldkey_ss58(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + let coldkey_account = AccountId32::from_str(&coldkey_addr) + .map_err(|e| anyhow::anyhow!("Invalid coldkey address: {:?}", e))?; + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Fetching stake information..."); + let stakes = get_stake_info_for_coldkey(&client, &coldkey_account) + .await + .map_err(|e| anyhow::anyhow!("Failed to fetch stakes: {}", e))?; + sp.finish_and_clear(); + + println!( + "\nStake for wallet '{}' ({})", + wallet_name, + 
format_address(&coldkey_addr) + ); + + if stakes.is_empty() { + print_info("No stake found"); + return Ok(()); + } + + let mut table = create_table_with_headers(&["Hotkey", "Subnet", "Stake (TAO)"]); + let mut total_stake: u128 = 0; + + for stake_info in &stakes { + table.add_row(vec![ + format_address(&stake_info.hotkey.to_string()), + stake_info.netuid.to_string(), + format_tao(stake_info.stake), + ]); + total_stake += stake_info.stake; + } + + println!("{table}"); + println!("\nTotal stake: {}", format_tao(total_stake)); + + Ok(()) +} diff --git a/src/cli/commands/subnet.rs b/src/cli/commands/subnet.rs new file mode 100644 index 0000000..988a4a3 --- /dev/null +++ b/src/cli/commands/subnet.rs @@ -0,0 +1,413 @@ +//! Subnet commands for viewing subnet information and registration. + +use crate::cli::utils::{ + confirm, create_table_with_headers, format_address, format_tao, keypair_to_signer, + print_error, print_info, print_success, print_warning, prompt_password_optional, + resolve_endpoint, spinner, +}; +use crate::cli::Cli; +use crate::wallet::Wallet; +use clap::{Args, Subcommand}; + +/// Subnet command container +#[derive(Args, Clone)] +pub struct SubnetCommand { + #[command(subcommand)] + pub command: SubnetCommands, +} + +/// Available subnet operations +#[derive(Subcommand, Clone)] +pub enum SubnetCommands { + /// List all subnets + List, + + /// Show detailed subnet information + Show { + /// Subnet ID + #[arg(short, long)] + netuid: u16, + }, + + /// Show subnet metagraph + Metagraph { + /// Subnet ID + #[arg(short, long)] + netuid: u16, + }, + + /// Register on a subnet + Register { + /// Wallet name + #[arg(short, long)] + wallet: String, + /// Hotkey name + #[arg(short = 'k', long)] + hotkey: String, + /// Subnet ID + #[arg(short, long)] + netuid: u16, + /// Use burned (paid) registration + #[arg(long)] + burned: bool, + }, + + /// Show subnet hyperparameters + Hyperparams { + /// Subnet ID + #[arg(short, long)] + netuid: u16, + }, + + /// Create a new 
subnet
    Create {
        /// Wallet name
        #[arg(short, long)]
        wallet: String,
    },
}

/// Execute subnet commands
pub async fn execute(cmd: SubnetCommand, cli: &Cli) -> anyhow::Result<()> {
    match cmd.command {
        SubnetCommands::List => list_subnets(cli).await,
        SubnetCommands::Show { netuid } => show_subnet(netuid, cli).await,
        SubnetCommands::Metagraph { netuid } => show_metagraph(netuid, cli).await,
        SubnetCommands::Register {
            wallet,
            hotkey,
            netuid,
            burned,
        } => register(&wallet, &hotkey, netuid, burned, cli).await,
        SubnetCommands::Hyperparams { netuid } => show_hyperparams(netuid, cli).await,
        SubnetCommands::Create { wallet } => create_subnet(&wallet, cli).await,
    }
}

/// List all subnets
async fn list_subnets(cli: &Cli) -> anyhow::Result<()> {
    use crate::chain::BittensorClient;
    use crate::queries::subnets::all_subnets;

    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());

    let sp = spinner(&format!("Connecting to {}...", endpoint));
    let client = BittensorClient::new(&endpoint)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
    sp.finish_and_clear();

    let sp = spinner("Fetching subnet list...");
    let subnets = all_subnets(&client)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to fetch subnets: {}", e))?;
    sp.finish_and_clear();

    if subnets.is_empty() {
        print_info("No subnets found");
        return Ok(());
    }

    let mut table = create_table_with_headers(&["NetUID", "Name", "Neurons", "Emission"]);

    for info in &subnets {
        table.add_row(vec![
            info.netuid.to_string(),
            info.name.clone().unwrap_or_else(|| "N/A".to_string()),
            info.neuron_count.to_string(),
            format!("{:.6}", info.emission),
        ]);
    }

    println!("\n{table}");
    println!("\nTotal subnets: {}", subnets.len());

    Ok(())
}

/// Show detailed subnet information
async fn show_subnet(netuid: u16, cli: &Cli) -> anyhow::Result<()> {
    use crate::chain::BittensorClient;
    use crate::queries::subnets::{subnet_info, tempo, difficulty, immunity_period};

    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());

    let sp = spinner(&format!("Connecting to {}...", endpoint));
    let client = BittensorClient::new(&endpoint)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
    sp.finish_and_clear();

    let sp = spinner(&format!("Fetching subnet {} info...", netuid));
    let info = subnet_info(&client, netuid)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to fetch subnet info: {}", e))?;

    // Fetch additional params. These are best-effort: a failed or empty query
    // falls back to 0 rather than aborting the whole display (same convention
    // as show_hyperparams below).
    let tempo_val = tempo(&client, netuid).await.ok().flatten().unwrap_or(0);
    let diff_val = difficulty(&client, netuid).await.ok().flatten().unwrap_or(0);
    let immunity_val = immunity_period(&client, netuid).await.ok().flatten().unwrap_or(0);
    sp.finish_and_clear();

    match info {
        Some(info) => {
            println!("\nSubnet {}", netuid);
            println!("═════════════════════════════════════════");
            println!(
                "Name: {}",
                info.name.unwrap_or_else(|| "N/A".to_string())
            );
            println!("Neurons: {}", info.neuron_count);
            println!("Emission: {:.6}", info.emission);
            println!("Total Stake: {}", format_tao(info.total_stake));
            println!("Tempo: {} blocks", tempo_val);
            println!("Difficulty: {}", diff_val);
            println!("Immunity Period: {} blocks", immunity_val);
        }
        None => {
            print_error(&format!("Subnet {} not found", netuid));
            return Err(anyhow::anyhow!("Subnet not found"));
        }
    }

    Ok(())
}

/// Show subnet metagraph
async fn show_metagraph(netuid: u16, cli: &Cli) -> anyhow::Result<()> {
    use crate::chain::BittensorClient;
    use crate::metagraph::sync_metagraph;

    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());

    let sp = spinner(&format!("Connecting to {}...", endpoint));
    let client = BittensorClient::new(&endpoint)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
    sp.finish_and_clear();

    let sp = spinner(&format!("Syncing metagraph for subnet {}...", netuid));
    let metagraph = sync_metagraph(&client, netuid)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to sync metagraph: {}", e))?;
    sp.finish_and_clear();

    println!("\nMetagraph for Subnet {}", netuid);
    println!("═════════════════════════════════════════════════════════════════");

    let mut table = create_table_with_headers(&[
        "UID",
        "Hotkey",
        "Coldkey",
        "Stake",
        "Trust",
        "Consensus",
        "Incentive",
        "Active",
    ]);

    // Cap terminal output at 50 rows; the totals below still reflect all neurons.
    let n = metagraph.n;
    let display_count = n.min(50);
    for uid in 0..display_count {
        if let Some(neuron) = metagraph.neurons.get(&uid) {
            table.add_row(vec![
                uid.to_string(),
                format_address(&neuron.hotkey.to_string()),
                format_address(&neuron.coldkey.to_string()),
                format_tao(neuron.total_stake),
                format!("{:.4}", neuron.trust),
                format!("{:.4}", neuron.consensus),
                format!("{:.4}", neuron.incentive),
                if neuron.active { "✓" } else { "✗" }.to_string(),
            ]);
        }
    }

    println!("{table}");

    if n > 50 {
        print_info(&format!("Showing first 50 of {} neurons", n));
    }

    println!("\nTotal neurons: {}", n);
    println!("Block: {}", metagraph.block);

    Ok(())
}

/// Register on a subnet
async fn register(
    wallet_name: &str,
    hotkey_name: &str,
    netuid: u16,
    burned: bool,
    cli: &Cli,
) -> anyhow::Result<()> {
    use crate::chain::{BittensorClient, ExtrinsicWait};
    use crate::validator::registration::{burned_register, register as pow_register};

    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());

    let wallet = match Wallet::new(wallet_name, hotkey_name, None) {
        Ok(w) => w,
        Err(e) => {
            print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e));
            return Err(anyhow::anyhow!("Invalid wallet name: {}", e));
        }
    };
    if !wallet.coldkey_exists() {
        print_error(&format!("Wallet '{}' not found", wallet_name));
        return Err(anyhow::anyhow!("Wallet not found"));
    }

    let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)");
    let coldkey = wallet
        .coldkey_keypair(coldkey_password.as_deref())
        .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?;
    let signer = keypair_to_signer(&coldkey);

    let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)");
    let hotkey = wallet
        .hotkey_keypair(hotkey_password.as_deref())
        .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?;

    print_info(&format!("Registering on subnet {}", netuid));
    print_info(&format!("Coldkey: {}", coldkey.ss58_address()));
    print_info(&format!("Hotkey: {}", hotkey.ss58_address()));
    print_info(&format!(
        "Method: {}",
        if burned { "Burned (paid)" } else { "PoW" }
    ));

    if !confirm("Proceed with registration?", cli.no_prompt) {
        print_info("Registration cancelled");
        return Ok(());
    }

    let sp = spinner(&format!("Connecting to {}...", endpoint));
    let client = BittensorClient::new(&endpoint)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
    sp.finish_and_clear();

    let result = if burned {
        let sp = spinner("Submitting burned registration...");
        let r = burned_register(&client, &signer, netuid, ExtrinsicWait::Finalized).await;
        sp.finish_and_clear();
        r
    } else {
        let sp = spinner("Performing PoW registration (this may take a while)...");
        // Standard register
        let r = pow_register(&client, &signer, netuid, ExtrinsicWait::Finalized).await;
        sp.finish_and_clear();
        r
    };

    match result {
        Ok(tx_hash) => {
            print_success("Registration successful!");
            print_info(&format!("Transaction hash: {}", tx_hash));
        }
        Err(e) => {
            print_error(&format!("Registration failed: {}", e));
            return Err(anyhow::anyhow!("Registration failed: {}", e));
        }
    }

    Ok(())
}

/// Show subnet hyperparameters
async fn show_hyperparams(netuid: u16, cli: &Cli) -> anyhow::Result<()> {
    use crate::chain::BittensorClient;
    use crate::queries::subnets::{
        difficulty, immunity_period, max_weight_limit, min_allowed_weights, tempo,
        weights_rate_limit,
    };

    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());

    let sp = spinner(&format!("Connecting to {}...", endpoint));
    let client = BittensorClient::new(&endpoint)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
    sp.finish_and_clear();

    let sp = spinner(&format!("Fetching hyperparameters for subnet {}...", netuid));

    // Fetch available hyperparameters (best-effort; missing values display as 0)
    let tempo_val = tempo(&client, netuid).await.ok().flatten().unwrap_or(0);
    let difficulty_val = difficulty(&client, netuid).await.ok().flatten().unwrap_or(0);
    let immunity_val = immunity_period(&client, netuid).await.ok().flatten().unwrap_or(0);
    let max_weights = max_weight_limit(&client, netuid).await.ok().flatten().unwrap_or(0.0);
    let min_weights = min_allowed_weights(&client, netuid).await.ok().flatten().unwrap_or(0);
    let weights_rate = weights_rate_limit(&client, netuid).await.ok().flatten().unwrap_or(0);

    sp.finish_and_clear();

    println!("\nHyperparameters for Subnet {}", netuid);
    println!("═══════════════════════════════════════════════");

    let mut table = create_table_with_headers(&["Parameter", "Value"]);
    table.add_row(vec!["Tempo", &tempo_val.to_string()]);
    table.add_row(vec!["Difficulty", &difficulty_val.to_string()]);
    table.add_row(vec!["Immunity Period", &immunity_val.to_string()]);
    table.add_row(vec!["Max Weight Limit", &format!("{:.4}", max_weights)]);
    table.add_row(vec!["Min Allowed Weights", &min_weights.to_string()]);
    table.add_row(vec!["Weights Rate Limit", &weights_rate.to_string()]);

    println!("{table}");

    Ok(())
}

/// Create a new subnet
async fn create_subnet(wallet_name: &str, cli: &Cli) -> anyhow::Result<()> {
    use crate::chain::BittensorClient;

    let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref());

    let wallet = match Wallet::new(wallet_name, "default", None) {
        Ok(w) => w,
        Err(e) => {
            print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e));
            return Err(anyhow::anyhow!("Invalid wallet name: {}", e));
        }
    };
    if !wallet.coldkey_exists() {
        print_error(&format!("Wallet '{}' not found", wallet_name));
        return Err(anyhow::anyhow!("Wallet not found"));
    }

    let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)");
    let coldkey = wallet
        .coldkey_keypair(coldkey_password.as_deref())
        .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?;

    print_info("Creating new subnet");
    print_info(&format!("Coldkey: {}", coldkey.ss58_address()));
    print_warning("This will cost TAO to register a new subnet");

    if !confirm("Proceed with subnet creation?", cli.no_prompt) {
        print_info("Subnet creation cancelled");
        return Ok(());
    }

    let sp = spinner(&format!("Connecting to {}...", endpoint));
    let _client = BittensorClient::new(&endpoint)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?;
    sp.finish_and_clear();

    // Note: register_network is not currently available in the validator module
    // This would require adding the extrinsic call for RegisterNetwork
    print_warning("Subnet registration requires a specialized extrinsic call.");
    print_info("Please use the Python btcli or submit the RegisterNetwork extrinsic directly.");

    Ok(())
}
diff --git a/src/cli/commands/wallet.rs b/src/cli/commands/wallet.rs
new file mode 100644
index 0000000..cb692e9
--- /dev/null
+++ b/src/cli/commands/wallet.rs
@@ -0,0 +1,876 @@
//! Wallet commands for managing coldkeys and hotkeys.
+ +use crate::cli::utils::{ + confirm, create_table_with_headers, format_address, format_tao, keypair_to_signer, + print_error, print_info, print_success, print_warning, prompt_password, + prompt_password_optional, resolve_endpoint, spinner, tao_to_rao, +}; +use crate::cli::Cli; +use crate::wallet::{Mnemonic, Wallet}; +use clap::{Args, Subcommand}; + +/// Wallet command container +#[derive(Args, Clone)] +pub struct WalletCommand { + #[command(subcommand)] + pub command: WalletCommands, +} + +/// Available wallet operations +#[derive(Subcommand, Clone)] +pub enum WalletCommands { + /// Create a new wallet (coldkey and hotkey) + Create { + /// Wallet name + #[arg(short, long, default_value = "default")] + name: String, + /// Hotkey name + #[arg(short = 'k', long, default_value = "default")] + hotkey: String, + /// Number of mnemonic words (12, 15, 18, 21, 24) + #[arg(long, default_value = "12")] + words: usize, + /// Skip password for coldkey encryption + #[arg(long)] + no_password: bool, + }, + + /// Regenerate wallet from mnemonic phrase + Regen { + /// Wallet name + #[arg(short, long)] + name: String, + /// Mnemonic phrase (space-separated words) + #[arg(long)] + mnemonic: String, + /// Skip password for encryption + #[arg(long)] + no_password: bool, + }, + + /// List all wallets + List { + /// Custom wallet path + #[arg(long)] + path: Option, + }, + + /// Show wallet overview (balances and registrations) + Overview { + /// Wallet name (default: all wallets) + #[arg(short, long)] + name: Option, + /// Show registrations on all subnets + #[arg(long)] + all: bool, + }, + + /// Show wallet balance + Balance { + /// Wallet name + #[arg(short, long)] + name: Option, + /// Show all wallets + #[arg(long)] + all: bool, + }, + + /// Transfer TAO to another address + Transfer { + /// Source wallet name + #[arg(short, long)] + name: String, + /// Destination address (SS58 format) + #[arg(short, long)] + dest: String, + /// Amount in TAO + #[arg(short, long)] + amount: f64, + 
}, + + /// Create a new hotkey + NewHotkey { + /// Wallet name + #[arg(short, long)] + name: String, + /// Hotkey name + #[arg(short = 'k', long)] + hotkey: String, + /// Number of mnemonic words (12, 15, 18, 21, 24) + #[arg(long, default_value = "12")] + words: usize, + /// Skip password for hotkey encryption + #[arg(long)] + no_password: bool, + }, + + /// Create a new coldkey + NewColdkey { + /// Wallet name + #[arg(short, long)] + name: String, + /// Number of mnemonic words (12, 15, 18, 21, 24) + #[arg(long, default_value = "12")] + words: usize, + /// Skip password for encryption + #[arg(long)] + no_password: bool, + }, + + /// Regenerate coldkey from mnemonic + RegenColdkey { + /// Wallet name + #[arg(short, long)] + name: String, + /// Mnemonic phrase + #[arg(long)] + mnemonic: String, + /// Skip password for encryption + #[arg(long)] + no_password: bool, + }, + + /// Regenerate hotkey from mnemonic + RegenHotkey { + /// Wallet name + #[arg(short, long)] + name: String, + /// Hotkey name + #[arg(short = 'k', long)] + hotkey: String, + /// Mnemonic phrase + #[arg(long)] + mnemonic: String, + /// Skip password for encryption + #[arg(long)] + no_password: bool, + }, + + /// Show wallet addresses + Address { + /// Wallet name + #[arg(short, long, default_value = "default")] + name: String, + /// Hotkey name + #[arg(short = 'k', long, default_value = "default")] + hotkey: String, + }, +} + +/// Execute wallet commands +pub async fn execute(cmd: WalletCommand, cli: &Cli) -> anyhow::Result<()> { + match cmd.command { + WalletCommands::Create { + name, + hotkey, + words, + no_password, + } => create_wallet(&name, &hotkey, words, no_password, cli).await, + WalletCommands::Regen { + name, + mnemonic, + no_password, + } => regen_wallet(&name, &mnemonic, no_password, cli).await, + WalletCommands::List { path } => list_wallets(path.as_deref()).await, + WalletCommands::Overview { name, all } => overview(name.as_deref(), all, cli).await, + WalletCommands::Balance { name, 
all } => balance(name.as_deref(), all, cli).await, + WalletCommands::Transfer { name, dest, amount } => { + transfer(&name, &dest, amount, cli).await + } + WalletCommands::NewHotkey { + name, + hotkey, + words, + no_password, + } => new_hotkey(&name, &hotkey, words, no_password).await, + WalletCommands::NewColdkey { + name, + words, + no_password, + } => new_coldkey(&name, words, no_password).await, + WalletCommands::RegenColdkey { + name, + mnemonic, + no_password, + } => regen_coldkey(&name, &mnemonic, no_password).await, + WalletCommands::RegenHotkey { + name, + hotkey, + mnemonic, + no_password, + } => regen_hotkey(&name, &hotkey, &mnemonic, no_password).await, + WalletCommands::Address { name, hotkey } => show_address(&name, &hotkey).await, + } +} + +/// Create a new wallet with coldkey and hotkey +async fn create_wallet( + name: &str, + hotkey_name: &str, + words: usize, + no_password: bool, + cli: &Cli, +) -> anyhow::Result<()> { + // Validate word count + if ![12, 15, 18, 21, 24].contains(&words) { + print_error("Word count must be 12, 15, 18, 21, or 24"); + return Err(anyhow::anyhow!("Invalid word count")); + } + + // Check if wallet already exists + let mut wallet = match Wallet::new(name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if wallet.coldkey_exists() { + print_warning(&format!("Wallet '{}' already exists", name)); + if !confirm("Overwrite existing wallet?", cli.no_prompt) { + print_info("Aborted"); + return Ok(()); + } + } + + // Generate mnemonics + let coldkey_mnemonic = Mnemonic::generate_with_words(words) + .map_err(|e| anyhow::anyhow!("Failed to generate coldkey mnemonic: {}", e))?; + + let hotkey_mnemonic = Mnemonic::generate_with_words(words) + .map_err(|e| anyhow::anyhow!("Failed to generate hotkey mnemonic: {}", e))?; + + // Get password for coldkey + let coldkey_password = if no_password { + None 
+ } else { + let pwd = prompt_password("Enter password for coldkey encryption"); + let confirm = prompt_password("Confirm password"); + if pwd != confirm { + print_error("Passwords do not match"); + return Err(anyhow::anyhow!("Password mismatch")); + } + Some(pwd) + }; + + // Create coldkey + let sp = spinner("Creating coldkey..."); + wallet + .create_coldkey( + coldkey_password.as_deref(), + Some(coldkey_mnemonic.phrase()), + false, + ) + .map_err(|e| anyhow::anyhow!("Failed to create coldkey: {}", e))?; + sp.finish_and_clear(); + + // Create hotkey (typically no password) + let sp = spinner("Creating hotkey..."); + wallet + .create_hotkey(None, Some(hotkey_mnemonic.phrase()), false) + .map_err(|e| anyhow::anyhow!("Failed to create hotkey: {}", e))?; + sp.finish_and_clear(); + + // Display results + print_success(&format!("Wallet '{}' created successfully!", name)); + println!(); + + print_warning("IMPORTANT: Save these mnemonic phrases securely!"); + println!(); + + let coldkey_addr = wallet + .coldkey_ss58(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to get coldkey address: {}", e))?; + let hotkey_addr = wallet + .hotkey_ss58(None) + .map_err(|e| anyhow::anyhow!("Failed to get hotkey address: {}", e))?; + + println!("Coldkey address: {}", coldkey_addr); + println!("Coldkey mnemonic: {}", coldkey_mnemonic.phrase()); + println!(); + println!("Hotkey address: {}", hotkey_addr); + println!("Hotkey mnemonic: {}", hotkey_mnemonic.phrase()); + + Ok(()) +} + +/// Regenerate wallet from mnemonic +async fn regen_wallet( + name: &str, + mnemonic: &str, + no_password: bool, + cli: &Cli, +) -> anyhow::Result<()> { + // Validate mnemonic + if !Mnemonic::validate(mnemonic) { + print_error("Invalid mnemonic phrase"); + return Err(anyhow::anyhow!("Invalid mnemonic")); + } + + let mut wallet = match Wallet::new(name, "default", None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return 
Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + + if wallet.coldkey_exists() { + print_warning(&format!("Wallet '{}' already exists", name)); + if !confirm("Overwrite existing wallet?", cli.no_prompt) { + print_info("Aborted"); + return Ok(()); + } + } + + let password = if no_password { + None + } else { + let pwd = prompt_password("Enter password for encryption"); + let confirm = prompt_password("Confirm password"); + if pwd != confirm { + print_error("Passwords do not match"); + return Err(anyhow::anyhow!("Password mismatch")); + } + Some(pwd) + }; + + let sp = spinner("Regenerating wallet from mnemonic..."); + wallet + .create_coldkey(password.as_deref(), Some(mnemonic), false) + .map_err(|e| anyhow::anyhow!("Failed to regenerate coldkey: {}", e))?; + sp.finish_and_clear(); + + let addr = wallet + .coldkey_ss58(password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to get address: {}", e))?; + + print_success(&format!("Wallet '{}' regenerated successfully!", name)); + println!("Coldkey address: {}", addr); + + Ok(()) +} + +/// List all wallets +async fn list_wallets(path: Option<&str>) -> anyhow::Result<()> { + use crate::wallet::{list_wallets as get_wallet_names, wallet_path}; + use std::path::Path; + + let wallet_names = if let Some(p) = path { + crate::wallet::list_wallets_at(Path::new(p)) + .map_err(|e| anyhow::anyhow!("Failed to list wallets: {}", e))? + } else { + get_wallet_names() + .map_err(|e| anyhow::anyhow!("Failed to list wallets: {}", e))? 
+ }; + + if wallet_names.is_empty() { + print_info("No wallets found"); + return Ok(()); + } + + let mut table = create_table_with_headers(&["Wallet", "Coldkey Path"]); + + for wallet_name in &wallet_names { + table.add_row(vec![ + wallet_name.clone(), + wallet_path(wallet_name).display().to_string(), + ]); + } + + println!("{table}"); + Ok(()) +} + +/// Show wallet overview +async fn overview(name: Option<&str>, _all: bool, cli: &Cli) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + use crate::queries::balances::get_balance; + use crate::queries::stakes::get_stake_info_for_coldkey; + use crate::wallet::list_wallets as get_wallet_names; + use sp_core::crypto::AccountId32; + use std::str::FromStr; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let wallets: Vec = if let Some(wallet_name) = name { + match Wallet::new(wallet_name, "default", None) { + Ok(w) => vec![w], + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + } + } else { + let names = get_wallet_names() + .map_err(|e| anyhow::anyhow!("Failed to list wallets: {}", e))?; + names.iter() + .filter_map(|n| Wallet::new(n, "default", None).ok()) + .collect() + }; + + if wallets.is_empty() { + print_info("No wallets found"); + return Ok(()); + } + + let mut table = create_table_with_headers(&["Wallet", "Coldkey", "Free Balance", "Staked"]); + + for wallet in &wallets { + let password = prompt_password_optional(&format!( + "Password for '{}' (enter to skip)", + &wallet.name + )); + + let coldkey_addr = match wallet.coldkey_ss58(password.as_deref()) { + Ok(addr) => addr, + Err(e) => { + print_warning(&format!( + "Could not unlock '{}': {}", + &wallet.name, + e + )); 
+ continue; + } + }; + + let sp = spinner(&format!("Fetching balance for {}...", format_address(&coldkey_addr))); + + // Parse SS58 to AccountId32 + let account = AccountId32::from_str(&coldkey_addr) + .map_err(|e| anyhow::anyhow!("Invalid SS58 address: {}", e))?; + + let balance_result = get_balance(&client, &account).await; + let stake_result = get_stake_info_for_coldkey(&client, &account).await; + sp.finish_and_clear(); + + let free = balance_result.unwrap_or(0); + let staked: u128 = stake_result + .map(|stakes| stakes.iter().map(|s| s.stake).sum()) + .unwrap_or(0); + + table.add_row(vec![ + wallet.name.to_string(), + format_address(&coldkey_addr), + format_tao(free), + format_tao(staked), + ]); + } + + println!("\n{table}"); + Ok(()) +} + +/// Show wallet balance +async fn balance(name: Option<&str>, all: bool, cli: &Cli) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + use crate::queries::balances::get_balance; + use crate::queries::stakes::get_stake_info_for_coldkey; + use crate::wallet::list_wallets as get_wallet_names; + use sp_core::crypto::AccountId32; + use std::str::FromStr; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let wallets: Vec = if let Some(wallet_name) = name { + match Wallet::new(wallet_name, "default", None) { + Ok(w) => vec![w], + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + } + } else if all { + let names = get_wallet_names() + .map_err(|e| anyhow::anyhow!("Failed to list wallets: {}", e))?; + names.iter() + .filter_map(|n| Wallet::new(n, "default", None).ok()) + .collect() + } else { + match Wallet::new("default", "default", None) { + Ok(w) => vec![w], + Err(e) => { + 
print_error(&format!("Invalid wallet name: {}", e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + } + }; + + if wallets.is_empty() { + print_info("No wallets found"); + return Ok(()); + } + + let mut table = + create_table_with_headers(&["Wallet", "Coldkey", "Free Balance", "Staked", "Total"]); + + for wallet in &wallets { + let password = prompt_password_optional(&format!( + "Password for '{}' (enter to skip)", + &wallet.name + )); + + let coldkey_addr = match wallet.coldkey_ss58(password.as_deref()) { + Ok(addr) => addr, + Err(e) => { + print_warning(&format!("Could not unlock '{}': {}", &wallet.name, e)); + continue; + } + }; + + let sp = spinner(&format!("Fetching balance for {}...", format_address(&coldkey_addr))); + + // Parse SS58 to AccountId32 + let account = AccountId32::from_str(&coldkey_addr) + .map_err(|e| anyhow::anyhow!("Invalid SS58 address: {}", e))?; + + let balance_result = get_balance(&client, &account).await; + let stake_result = get_stake_info_for_coldkey(&client, &account).await; + sp.finish_and_clear(); + + let free = balance_result.unwrap_or(0); + let staked: u128 = stake_result + .map(|stakes| stakes.iter().map(|s| s.stake).sum()) + .unwrap_or(0); + let total = free + staked; + + table.add_row(vec![ + wallet.name.to_string(), + format_address(&coldkey_addr), + format_tao(free), + format_tao(staked), + format_tao(total), + ]); + } + + println!("\n{table}"); + Ok(()) +} + +/// Transfer TAO to another address +async fn transfer(name: &str, dest: &str, amount: f64, cli: &Cli) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::validator::transfer::transfer as do_transfer; + use sp_core::crypto::AccountId32; + use std::str::FromStr; + + if amount <= 0.0 { + print_error("Amount must be positive"); + return Err(anyhow::anyhow!("Invalid amount")); + } + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let wallet = match Wallet::new(name, "default", None) { + 
Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let coldkey = wallet + .coldkey_keypair(password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + let signer = keypair_to_signer(&coldkey); + + let dest_account = AccountId32::from_str(dest) + .map_err(|e| anyhow::anyhow!("Invalid destination address: {:?}", e))?; + + let rao_amount = tao_to_rao(amount); + + print_info(&format!( + "Transfer {} TAO ({} RAO)", + amount, rao_amount + )); + print_info(&format!("From: {}", coldkey.ss58_address())); + print_info(&format!("To: {}", dest)); + + if !confirm("Proceed with transfer?", cli.no_prompt) { + print_info("Transfer cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Submitting transfer..."); + let result = do_transfer(&client, &signer, &dest_account, rao_amount, true, ExtrinsicWait::Finalized).await; + sp.finish_and_clear(); + + match result { + Ok(tx_hash) => { + print_success("Transfer successful!"); + print_info(&format!("Transaction hash: {}", tx_hash)); + } + Err(e) => { + print_error(&format!("Transfer failed: {}", e)); + return Err(anyhow::anyhow!("Transfer failed: {}", e)); + } + } + + Ok(()) +} + +/// Create a new hotkey +async fn new_hotkey( + name: &str, + hotkey_name: &str, + words: usize, + no_password: bool, +) -> anyhow::Result<()> { + if ![12, 15, 18, 21, 24].contains(&words) { + print_error("Word count must be 12, 15, 18, 21, or 24"); + return Err(anyhow::anyhow!("Invalid 
word count")); + } + + let mut wallet = match Wallet::new(name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' does not exist", name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let mnemonic = Mnemonic::generate_with_words(words) + .map_err(|e| anyhow::anyhow!("Failed to generate mnemonic: {}", e))?; + + let password = if no_password { + None + } else { + let pwd = prompt_password("Enter password for hotkey encryption (enter for none)"); + if pwd.is_empty() { + None + } else { + Some(pwd) + } + }; + + let sp = spinner("Creating hotkey..."); + wallet + .create_hotkey(password.as_deref(), Some(mnemonic.phrase()), false) + .map_err(|e| anyhow::anyhow!("Failed to create hotkey: {}", e))?; + sp.finish_and_clear(); + + let addr = wallet + .hotkey_ss58(password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to get hotkey address: {}", e))?; + + print_success(&format!( + "Hotkey '{}' created for wallet '{}'", + hotkey_name, name + )); + println!(); + print_warning("Save this mnemonic phrase securely!"); + println!("Hotkey address: {}", addr); + println!("Hotkey mnemonic: {}", mnemonic.phrase()); + + Ok(()) +} + +/// Create a new coldkey +async fn new_coldkey(name: &str, words: usize, no_password: bool) -> anyhow::Result<()> { + if ![12, 15, 18, 21, 24].contains(&words) { + print_error("Word count must be 12, 15, 18, 21, or 24"); + return Err(anyhow::anyhow!("Invalid word count")); + } + + let mut wallet = match Wallet::new(name, "default", None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + + let mnemonic = Mnemonic::generate_with_words(words) + .map_err(|e| anyhow::anyhow!("Failed to generate mnemonic: {}", e))?; + + let 
password = if no_password { + None + } else { + let pwd = prompt_password("Enter password for encryption"); + let confirm = prompt_password("Confirm password"); + if pwd != confirm { + print_error("Passwords do not match"); + return Err(anyhow::anyhow!("Password mismatch")); + } + Some(pwd) + }; + + let sp = spinner("Creating coldkey..."); + wallet + .create_coldkey(password.as_deref(), Some(mnemonic.phrase()), false) + .map_err(|e| anyhow::anyhow!("Failed to create coldkey: {}", e))?; + sp.finish_and_clear(); + + let addr = wallet + .coldkey_ss58(password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to get coldkey address: {}", e))?; + + print_success(&format!("Coldkey '{}' created!", name)); + println!(); + print_warning("IMPORTANT: Save this mnemonic phrase securely!"); + println!("Coldkey address: {}", addr); + println!("Coldkey mnemonic: {}", mnemonic.phrase()); + + Ok(()) +} + +/// Regenerate coldkey from mnemonic +async fn regen_coldkey(name: &str, mnemonic: &str, no_password: bool) -> anyhow::Result<()> { + if !Mnemonic::validate(mnemonic) { + print_error("Invalid mnemonic phrase"); + return Err(anyhow::anyhow!("Invalid mnemonic")); + } + + let mut wallet = match Wallet::new(name, "default", None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + + let password = if no_password { + None + } else { + let pwd = prompt_password("Enter password for encryption"); + let confirm = prompt_password("Confirm password"); + if pwd != confirm { + print_error("Passwords do not match"); + return Err(anyhow::anyhow!("Password mismatch")); + } + Some(pwd) + }; + + let sp = spinner("Regenerating coldkey..."); + wallet + .create_coldkey(password.as_deref(), Some(mnemonic), false) + .map_err(|e| anyhow::anyhow!("Failed to regenerate coldkey: {}", e))?; + sp.finish_and_clear(); + + let addr = wallet + .coldkey_ss58(password.as_deref()) + .map_err(|e| 
anyhow::anyhow!("Failed to get coldkey address: {}", e))?; + + print_success(&format!("Coldkey '{}' regenerated!", name)); + println!("Coldkey address: {}", addr); + + Ok(()) +} + +/// Regenerate hotkey from mnemonic +async fn regen_hotkey( + name: &str, + hotkey_name: &str, + mnemonic: &str, + no_password: bool, +) -> anyhow::Result<()> { + if !Mnemonic::validate(mnemonic) { + print_error("Invalid mnemonic phrase"); + return Err(anyhow::anyhow!("Invalid mnemonic")); + } + + let mut wallet = match Wallet::new(name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' does not exist", name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let password = if no_password { + None + } else { + let pwd = prompt_password("Enter password for encryption (enter for none)"); + if pwd.is_empty() { + None + } else { + Some(pwd) + } + }; + + let sp = spinner("Regenerating hotkey..."); + wallet + .create_hotkey(password.as_deref(), Some(mnemonic), false) + .map_err(|e| anyhow::anyhow!("Failed to regenerate hotkey: {}", e))?; + sp.finish_and_clear(); + + let addr = wallet + .hotkey_ss58(password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to get hotkey address: {}", e))?; + + print_success(&format!( + "Hotkey '{}' regenerated for wallet '{}'!", + hotkey_name, name + )); + println!("Hotkey address: {}", addr); + + Ok(()) +} + +/// Show wallet addresses +async fn show_address(name: &str, hotkey_name: &str) -> anyhow::Result<()> { + let wallet = match Wallet::new(name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", name)); + return 
Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + + let coldkey_addr = wallet.coldkey_ss58(coldkey_password.as_deref()); + let hotkey_addr = wallet.hotkey_ss58(hotkey_password.as_deref()); + + println!(); + match coldkey_addr { + Ok(addr) => println!("Coldkey address: {}", addr), + Err(e) => print_warning(&format!("Could not get coldkey address: {}", e)), + } + + match hotkey_addr { + Ok(addr) => println!("Hotkey address: {}", addr), + Err(e) => print_warning(&format!("Could not get hotkey address: {}", e)), + } + + Ok(()) +} diff --git a/src/cli/commands/weights.rs b/src/cli/commands/weights.rs new file mode 100644 index 0000000..11d96b7 --- /dev/null +++ b/src/cli/commands/weights.rs @@ -0,0 +1,473 @@ +//! Weight commands for commit-reveal and direct weight setting. + +use crate::cli::utils::{ + confirm, create_table_with_headers, format_address, keypair_to_signer, parse_f64_list, + parse_u16_list, print_error, print_info, print_success, print_warning, + prompt_password_optional, resolve_endpoint, spinner, +}; +use crate::cli::Cli; +use crate::wallet::Wallet; +use clap::{Args, Subcommand}; + +/// Weights command container +#[derive(Args, Clone)] +pub struct WeightsCommand { + #[command(subcommand)] + pub command: WeightsCommands, +} + +/// Available weight operations +#[derive(Subcommand, Clone)] +pub enum WeightsCommands { + /// Commit weights (for commit-reveal) + Commit { + /// Wallet name + #[arg(short, long)] + wallet: String, + /// Hotkey name + #[arg(short = 'k', long)] + hotkey: String, + /// Subnet ID + #[arg(short, long)] + netuid: u16, + /// Target UIDs (comma-separated, e.g., "1,2,3") + #[arg(long)] + uids: String, + /// Weights (comma-separated, e.g., "0.3,0.5,0.2") + #[arg(long)] + weights: String, + }, + + /// Reveal committed weights + Reveal { + /// Wallet name + 
#[arg(short, long)] + wallet: String, + /// Hotkey name + #[arg(short = 'k', long)] + hotkey: String, + /// Subnet ID + #[arg(short, long)] + netuid: u16, + }, + + /// Set weights directly (no commit-reveal) + Set { + /// Wallet name + #[arg(short, long)] + wallet: String, + /// Hotkey name + #[arg(short = 'k', long)] + hotkey: String, + /// Subnet ID + #[arg(short, long)] + netuid: u16, + /// Target UIDs (comma-separated, e.g., "1,2,3") + #[arg(long)] + uids: String, + /// Weights (comma-separated, e.g., "0.3,0.5,0.2") + #[arg(long)] + weights: String, + }, + + /// Check current weight information + Info { + /// Subnet ID + #[arg(short, long)] + netuid: u16, + /// Hotkey address to check (optional) + #[arg(long)] + hotkey: Option, + }, + + /// Show pending commits + Pending { + /// Wallet name + #[arg(short, long)] + wallet: String, + /// Hotkey name + #[arg(short = 'k', long)] + hotkey: String, + }, +} + +/// Execute weight commands +pub async fn execute(cmd: WeightsCommand, cli: &Cli) -> anyhow::Result<()> { + match cmd.command { + WeightsCommands::Commit { + wallet, + hotkey, + netuid, + uids, + weights, + } => commit_weights(&wallet, &hotkey, netuid, &uids, &weights, cli).await, + WeightsCommands::Reveal { + wallet, + hotkey, + netuid, + } => reveal_weights(&wallet, &hotkey, netuid, cli).await, + WeightsCommands::Set { + wallet, + hotkey, + netuid, + uids, + weights, + } => set_weights(&wallet, &hotkey, netuid, &uids, &weights, cli).await, + WeightsCommands::Info { netuid, hotkey } => weight_info(netuid, hotkey.as_deref(), cli).await, + WeightsCommands::Pending { wallet, hotkey } => pending_commits(&wallet, &hotkey, cli).await, + } +} + +/// Commit weights (for commit-reveal protocol) +async fn commit_weights( + wallet_name: &str, + hotkey_name: &str, + netuid: u16, + uids_str: &str, + weights_str: &str, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::utils::crypto::generate_subtensor_commit_hash; + 
use crate::validator::weights::commit_weights as raw_commit_weights; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + // Parse UIDs and weights + let uids = parse_u16_list(uids_str)?; + let weight_values = parse_f64_list(weights_str)?; + + if uids.len() != weight_values.len() { + print_error("Number of UIDs must match number of weights"); + return Err(anyhow::anyhow!("Mismatched UIDs and weights")); + } + + // Normalize and convert weights to u16 + let sum: f64 = weight_values.iter().sum(); + if sum <= 0.0 { + print_error("Weights must sum to a positive value"); + return Err(anyhow::anyhow!("Invalid weights")); + } + + let normalized: Vec = weight_values + .iter() + .map(|w| ((w / sum) * 65535.0) as u16) + .collect(); + + let wallet = match Wallet::new(wallet_name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let coldkey = wallet + .coldkey_keypair(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + let signer = keypair_to_signer(&coldkey); + + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + let hotkey = wallet + .hotkey_keypair(hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?; + + print_info(&format!("Committing weights for subnet {}", netuid)); + print_info(&format!("Hotkey: {}", hotkey.ss58_address())); + print_info(&format!("UIDs: {:?}", uids)); + print_info(&format!("Weights (normalized u16): {:?}", normalized)); + + if !confirm("Proceed with weight commit?", cli.no_prompt) { + print_info("Weight commit 
cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + // Generate commit hash + // Get the hotkey's public key bytes + use sp_core::Pair; + let hotkey_pubkey: [u8; 32] = hotkey.pair().public().0; + + // Generate random salt + let salt: Vec = (0..8).map(|_| rand::random::()).collect(); + + let commit_hash_bytes = generate_subtensor_commit_hash( + &hotkey_pubkey, + netuid, + None, // mechanism_id + &uids, + &normalized, + &salt, + 0, // version_key + ); + let commit_hash = hex::encode(commit_hash_bytes); + + let sp = spinner("Submitting weight commit..."); + let result = raw_commit_weights(&client, &signer, netuid, &commit_hash, ExtrinsicWait::Finalized).await; + sp.finish_and_clear(); + + match result { + Ok(tx_hash) => { + print_success("Weights committed successfully!"); + print_info(&format!("Transaction hash: {}", tx_hash)); + print_warning("Remember to reveal your weights before the reveal period ends!"); + } + Err(e) => { + print_error(&format!("Failed to commit weights: {}", e)); + return Err(anyhow::anyhow!("Weight commit failed: {}", e)); + } + } + + Ok(()) +} + +/// Reveal previously committed weights +async fn reveal_weights( + wallet_name: &str, + hotkey_name: &str, + netuid: u16, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let wallet = match Wallet::new(wallet_name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = 
prompt_password_optional("Coldkey password (enter if unencrypted)"); + let _coldkey = wallet + .coldkey_keypair(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + let hotkey = wallet + .hotkey_keypair(hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?; + + print_info(&format!("Revealing weights for subnet {}", netuid)); + print_info(&format!("Hotkey: {}", hotkey.ss58_address())); + + if !confirm("Proceed with weight reveal?", cli.no_prompt) { + print_info("Weight reveal cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let _client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + // Note: reveal_weights requires the original uids, weights, and salt that were committed + // This information is typically stored locally when commit is performed + print_warning("Weight reveal requires the original committed data (uids, weights, salt)."); + print_info("Use the high-level Subtensor API for automatic commit/reveal tracking."); + print_info("Or use 'btcli weights set' for direct weight setting if commit-reveal is disabled."); + + Ok(()) +} + +/// Set weights directly (no commit-reveal) +async fn set_weights( + wallet_name: &str, + hotkey_name: &str, + netuid: u16, + uids_str: &str, + weights_str: &str, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::{BittensorClient, ExtrinsicWait}; + use crate::validator::weights::set_weights as raw_set_weights; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + // Parse UIDs and weights + let uids = parse_u16_list(uids_str)?; + let weight_values = parse_f64_list(weights_str)?; + + if uids.len() != weight_values.len() { + print_error("Number of UIDs must match number of weights"); + 
return Err(anyhow::anyhow!("Mismatched UIDs and weights")); + } + + // Normalize and convert weights to f32 for the API + let sum: f64 = weight_values.iter().sum(); + if sum <= 0.0 { + print_error("Weights must sum to a positive value"); + return Err(anyhow::anyhow!("Invalid weights")); + } + + let normalized_f32: Vec = weight_values + .iter() + .map(|w| (*w / sum) as f32) + .collect(); + + let wallet = match Wallet::new(wallet_name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let coldkey_password = prompt_password_optional("Coldkey password (enter if unencrypted)"); + let _coldkey = wallet + .coldkey_keypair(coldkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock coldkey: {}", e))?; + + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + let hotkey = wallet + .hotkey_keypair(hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?; + let signer = keypair_to_signer(&hotkey); + + // Convert UIDs to u64 + let uids_u64: Vec = uids.iter().map(|u| *u as u64).collect(); + + print_info(&format!("Setting weights for subnet {}", netuid)); + print_info(&format!("Hotkey: {}", hotkey.ss58_address())); + print_info(&format!("UIDs: {:?}", uids)); + print_info(&format!("Weights (normalized): {:?}", normalized_f32)); + + if !confirm("Proceed with setting weights?", cli.no_prompt) { + print_info("Weight setting cancelled"); + return Ok(()); + } + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner("Submitting 
weights..."); + let result = raw_set_weights( + &client, + &signer, + netuid, + &uids_u64, + &normalized_f32, + Some(0), // version_key + ExtrinsicWait::Finalized + ).await; + sp.finish_and_clear(); + + match result { + Ok(tx_hash) => { + print_success("Weights set successfully!"); + print_info(&format!("Transaction hash: {}", tx_hash)); + } + Err(e) => { + print_error(&format!("Failed to set weights: {}", e)); + return Err(anyhow::anyhow!("Set weights failed: {}", e)); + } + } + + Ok(()) +} + +/// Show weight-related information for a subnet +async fn weight_info(netuid: u16, _hotkey: Option<&str>, cli: &Cli) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + use crate::queries::subnets::{commit_reveal_enabled, tempo, weights_rate_limit}; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + let sp = spinner(&format!("Fetching weight info for subnet {}...", netuid)); + + // Fetch weight-related parameters + let tempo_val = tempo(&client, netuid).await.ok().flatten().unwrap_or(0); + let rate_limit = weights_rate_limit(&client, netuid).await.ok().flatten().unwrap_or(0); + let cr_enabled = commit_reveal_enabled(&client, netuid).await.unwrap_or(false); + + sp.finish_and_clear(); + + println!("\nWeight Information for Subnet {}", netuid); + println!("═══════════════════════════════════════════════"); + + let mut table = create_table_with_headers(&["Parameter", "Value"]); + table.add_row(vec!["Tempo", &tempo_val.to_string()]); + table.add_row(vec!["Weights Rate Limit", &rate_limit.to_string()]); + table.add_row(vec![ + "Commit-Reveal Enabled", + if cr_enabled { "Yes" } else { "No" }, + ]); + + println!("{table}"); + + Ok(()) +} + +/// Show pending weight commits +async fn pending_commits( + wallet_name: &str, + hotkey_name: 
&str, + cli: &Cli, +) -> anyhow::Result<()> { + use crate::chain::BittensorClient; + + let endpoint = resolve_endpoint(&cli.network, cli.endpoint.as_deref()); + + let wallet = match Wallet::new(wallet_name, hotkey_name, None) { + Ok(w) => w, + Err(e) => { + print_error(&format!("Invalid wallet name '{}': {}", wallet_name, e)); + return Err(anyhow::anyhow!("Invalid wallet name: {}", e)); + } + }; + if !wallet.coldkey_exists() { + print_error(&format!("Wallet '{}' not found", wallet_name)); + return Err(anyhow::anyhow!("Wallet not found")); + } + + let hotkey_password = prompt_password_optional("Hotkey password (enter if unencrypted)"); + let hotkey = wallet + .hotkey_keypair(hotkey_password.as_deref()) + .map_err(|e| anyhow::anyhow!("Failed to unlock hotkey: {}", e))?; + + let sp = spinner(&format!("Connecting to {}...", endpoint)); + let _client = BittensorClient::new(&endpoint) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect: {}", e))?; + sp.finish_and_clear(); + + println!("\nPending Commits for {}", format_address(hotkey.ss58_address())); + println!("═══════════════════════════════════════════════"); + + // Note: Pending commits are typically stored locally by the application + // since the chain only stores the hash. Display any local state if available. + print_info("Pending commit tracking requires local state management."); + print_info("Use the Subtensor high-level API for automatic commit tracking."); + + Ok(()) +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 0000000..c155490 --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,78 @@ +//! CLI tool for Bittensor (btcli equivalent) +//! +//! This module provides a command-line interface for interacting with the +//! Bittensor network, similar to the Python btcli tool. +//! +//! # Commands +//! +//! - `wallet` - Wallet creation, management, and operations +//! - `stake` - Stake management (add, remove, move) +//! - `subnet` - Subnet information and registration +//! 
- `root` - Root network operations +//! - `weights` - Weight commit, reveal, and set operations + +use clap::{Parser, Subcommand}; + +pub mod commands; +pub mod utils; + +/// Bittensor CLI - Rust implementation +#[derive(Parser)] +#[command(name = "btcli")] +#[command(author = "Cortex Foundation")] +#[command(version = "0.1.0")] +#[command(about = "Bittensor CLI - Rust implementation", long_about = None)] +#[command(propagate_version = true)] +pub struct Cli { + #[command(subcommand)] + pub command: Commands, + + /// Network to connect to (finney, test, local, or custom URL) + #[arg(short, long, default_value = "finney", global = true)] + pub network: String, + + /// Custom RPC endpoint (overrides --network) + #[arg(long, global = true)] + pub endpoint: Option, + + /// Don't prompt for confirmations (auto-approve) + #[arg(long, global = true)] + pub no_prompt: bool, +} + +/// Available CLI commands +#[derive(Subcommand)] +pub enum Commands { + /// Wallet operations (create, list, transfer, etc.) 
+ #[command(alias = "w")] + Wallet(commands::wallet::WalletCommand), + + /// Stake operations (add, remove, move stake) + #[command(alias = "s")] + Stake(commands::stake::StakeCommand), + + /// Subnet operations (list, info, register) + #[command(alias = "sn")] + Subnet(commands::subnet::SubnetCommand), + + /// Root network operations + #[command(alias = "r")] + Root(commands::root::RootCommand), + + /// Weight operations (commit, reveal, set) + #[command(alias = "wt")] + Weights(commands::weights::WeightsCommand), +} + +/// Run the CLI application +pub async fn run() -> anyhow::Result<()> { + let cli = Cli::parse(); + + match &cli.command { + Commands::Wallet(cmd) => commands::wallet::execute(cmd.clone(), &cli).await, + Commands::Stake(cmd) => commands::stake::execute(cmd.clone(), &cli).await, + Commands::Subnet(cmd) => commands::subnet::execute(cmd.clone(), &cli).await, + Commands::Root(cmd) => commands::root::execute(cmd.clone(), &cli).await, + Commands::Weights(cmd) => commands::weights::execute(cmd.clone(), &cli).await, + } +} diff --git a/src/cli/utils.rs b/src/cli/utils.rs new file mode 100644 index 0000000..b9a787e --- /dev/null +++ b/src/cli/utils.rs @@ -0,0 +1,252 @@ +//! CLI utility functions for terminal interaction and formatting. + +use comfy_table::{presets::UTF8_FULL, ContentArrangement, Table}; +use console::{style, Term}; +use dialoguer::{Confirm, Input, Password}; +use indicatif::{ProgressBar, ProgressStyle}; +use std::time::Duration; + +/// Prompt for confirmation with default behavior based on `no_prompt` flag. +/// If `no_prompt` is true, returns true without prompting. +pub fn confirm(message: &str, no_prompt: bool) -> bool { + if no_prompt { + return true; + } + + Confirm::new() + .with_prompt(message) + .default(false) + .interact() + .unwrap_or(false) +} + +/// Prompt for password input (hidden characters). 
+pub fn prompt_password(message: &str) -> String { + Password::new() + .with_prompt(message) + .interact() + .unwrap_or_default() +} + +/// Prompt for optional password input. Returns None if empty. +pub fn prompt_password_optional(message: &str) -> Option { + let password = Password::new() + .with_prompt(message) + .allow_empty_password(true) + .interact() + .unwrap_or_default(); + + if password.is_empty() { + None + } else { + Some(password) + } +} + +/// Prompt for text input with a default value. +pub fn prompt_input(message: &str) -> String { + Input::new() + .with_prompt(message) + .interact_text() + .unwrap_or_default() +} + +/// Prompt for text input with a default value. +pub fn prompt_input_with_default(message: &str, default: &str) -> String { + Input::new() + .with_prompt(message) + .default(default.to_string()) + .interact_text() + .unwrap_or_else(|_| default.to_string()) +} + +/// Create a spinner progress bar with message. +pub fn spinner(message: &str) -> ProgressBar { + let pb = ProgressBar::new_spinner(); + pb.set_style( + ProgressStyle::default_spinner() + .tick_chars("⠁⠂⠄⡀⢀⠠⠐⠈ ") + .template("{spinner:.blue} {msg}") + .expect("valid template"), + ); + pb.set_message(message.to_string()); + pb.enable_steady_tick(Duration::from_millis(100)); + pb +} + +/// Print success message in green. +pub fn print_success(message: &str) { + let term = Term::stdout(); + let _ = term.write_line(&format!("{} {}", style("✓").green().bold(), message)); +} + +/// Print error message in red. +pub fn print_error(message: &str) { + let term = Term::stderr(); + let _ = term.write_line(&format!("{} {}", style("✗").red().bold(), message)); +} + +/// Print info message in blue. +pub fn print_info(message: &str) { + let term = Term::stdout(); + let _ = term.write_line(&format!("{} {}", style("ℹ").blue().bold(), message)); +} + +/// Print warning message in yellow. 
+pub fn print_warning(message: &str) { + let term = Term::stdout(); + let _ = term.write_line(&format!("{} {}", style("⚠").yellow().bold(), message)); +} + +/// Format RAO balance as TAO (1 TAO = 1e9 RAO). +pub fn format_tao(rao: u128) -> String { + const RAO_PER_TAO: u128 = 1_000_000_000; + let whole = rao / RAO_PER_TAO; + let fraction = rao % RAO_PER_TAO; + format!("{}.{:09} τ", whole, fraction) +} + +/// Format TAO as RAO for display (preserves decimal precision). +pub fn tao_to_rao(tao: f64) -> u128 { + const RAO_PER_TAO: f64 = 1_000_000_000.0; + (tao * RAO_PER_TAO) as u128 +} + +/// Format SS58 address (truncated for display). +/// Shows first 8 and last 8 characters with "..." in between. +pub fn format_address(address: &str) -> String { + if address.len() <= 18 { + return address.to_string(); + } + format!("{}...{}", &address[..8], &address[address.len() - 8..]) +} + +/// Format SS58 address for full display. +pub fn format_address_full(address: &str) -> String { + address.to_string() +} + +/// Create a styled table for CLI output. +pub fn create_table() -> Table { + let mut table = Table::new(); + table + .load_preset(UTF8_FULL) + .set_content_arrangement(ContentArrangement::Dynamic); + table +} + +/// Create a table with custom headers. +pub fn create_table_with_headers(headers: &[&str]) -> Table { + let mut table = create_table(); + table.set_header(headers.iter().map(|h| style(*h).bold().to_string())); + table +} + +/// Parse comma-separated list of u16 values. +pub fn parse_u16_list(input: &str) -> anyhow::Result> { + input + .split(',') + .map(|s| { + s.trim() + .parse::() + .map_err(|e| anyhow::anyhow!("Invalid u16 value '{}': {}", s.trim(), e)) + }) + .collect() +} + +/// Parse comma-separated list of f64 values. 
+pub fn parse_f64_list(input: &str) -> anyhow::Result> { + input + .split(',') + .map(|s| { + s.trim() + .parse::() + .map_err(|e| anyhow::anyhow!("Invalid f64 value '{}': {}", s.trim(), e)) + }) + .collect() +} + +/// Get network endpoint from network name or custom endpoint. +pub fn resolve_endpoint(network: &str, custom_endpoint: Option<&str>) -> String { + if let Some(endpoint) = custom_endpoint { + return endpoint.to_string(); + } + + match network.to_lowercase().as_str() { + "finney" => "wss://entrypoint-finney.opentensor.ai:443".to_string(), + "test" | "testnet" => "wss://test.finney.opentensor.ai:443".to_string(), + "local" | "localhost" => "ws://127.0.0.1:9944".to_string(), + "archive" => "wss://archive.chain.opentensor.ai:443".to_string(), + _ => network.to_string(), + } +} + +/// Format duration for display. +pub fn format_duration(seconds: u64) -> String { + if seconds < 60 { + format!("{}s", seconds) + } else if seconds < 3600 { + let mins = seconds / 60; + let secs = seconds % 60; + format!("{}m {}s", mins, secs) + } else { + let hours = seconds / 3600; + let mins = (seconds % 3600) / 60; + format!("{}h {}m", hours, mins) + } +} + +/// Validate SS58 address format. 
+pub fn is_valid_ss58(address: &str) -> bool { + if address.len() < 46 || address.len() > 48 { + return false; + } + address.chars().all(|c| c.is_alphanumeric()) +} + +/// Create a BittensorSigner from a wallet Keypair +pub fn keypair_to_signer(keypair: &crate::wallet::Keypair) -> crate::chain::BittensorSigner { + crate::chain::create_signer(keypair.pair().clone()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_tao() { + assert_eq!(format_tao(0), "0.000000000 τ"); + assert_eq!(format_tao(1_000_000_000), "1.000000000 τ"); + assert_eq!(format_tao(1_500_000_000), "1.500000000 τ"); + assert_eq!(format_tao(123_456_789_012), "123.456789012 τ"); + } + + #[test] + fn test_format_address() { + let addr = "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"; + assert_eq!(format_address(addr), "5GrwvaEF...oHGKutQY"); + + let short = "5GrwvaEF"; + assert_eq!(format_address(short), "5GrwvaEF"); + } + + #[test] + fn test_parse_u16_list() { + assert_eq!(parse_u16_list("1,2,3").unwrap(), vec![1, 2, 3]); + assert_eq!(parse_u16_list("1, 2, 3").unwrap(), vec![1, 2, 3]); + assert!(parse_u16_list("1,invalid").is_err()); + } + + #[test] + fn test_resolve_endpoint() { + assert_eq!( + resolve_endpoint("finney", None), + "wss://entrypoint-finney.opentensor.ai:443" + ); + assert_eq!(resolve_endpoint("local", None), "ws://127.0.0.1:9944"); + assert_eq!( + resolve_endpoint("finney", Some("ws://custom:9944")), + "ws://custom:9944" + ); + } +} diff --git a/src/crv4/encryption.rs b/src/crv4/encryption.rs index ada91a1..45d0a08 100644 --- a/src/crv4/encryption.rs +++ b/src/crv4/encryption.rs @@ -14,6 +14,7 @@ use tle::{ stream_ciphers::AESGCMStreamCipherProvider, tlock::tle, }; use w3f_bls::EngineBLS; +use zeroize::Zeroize; /// Encrypt weights payload for CRv4 commit /// @@ -73,7 +74,7 @@ pub fn encrypt_for_round(data: &[u8], reveal_round: u64) -> Result> { // Generate ephemeral secret key (random 32 bytes) let rng = ChaCha20Rng::from_entropy(); - let esk: [u8; 32] = 
rand::random(); + let mut esk: [u8; 32] = rand::random(); // Encrypt using TLE let ciphertext = tle::( @@ -81,6 +82,9 @@ pub fn encrypt_for_round(data: &[u8], reveal_round: u64) -> Result> { ) .map_err(|e| anyhow::anyhow!("TLE encryption failed: {:?}", e))?; + // SECURITY: Zeroize ephemeral secret key after use to prevent leakage + esk.zeroize(); + // Serialize compressed let mut commit_bytes = Vec::new(); ciphertext diff --git a/src/dendrite/client.rs b/src/dendrite/client.rs new file mode 100644 index 0000000..db8e2e2 --- /dev/null +++ b/src/dendrite/client.rs @@ -0,0 +1,663 @@ +//! Dendrite HTTP client for Bittensor network communication +//! +//! The Dendrite client is responsible for making HTTP requests to Axon servers. +//! It handles request signing, connection pooling, timeouts, and response parsing. + +use crate::dendrite::request::{DendriteRequest, RequestError}; +use crate::dendrite::response::{build_error_synapse, status_codes, DendriteResponse, ResponseError}; +use crate::dendrite::streaming::{StreamError, StreamingResponse, StreamingSynapse}; +use crate::types::{AxonInfo, Synapse, TerminalInfo}; +use crate::utils::ss58::AccountId32ToSS58; +use futures::Stream; +use reqwest::Client; +use sp_core::{sr25519, Pair}; +use std::time::{Duration, Instant}; +use thiserror::Error; +use uuid::Uuid; + +/// Default timeout for Dendrite requests (12 seconds, matching Python SDK) +pub const DEFAULT_TIMEOUT_SECS: u64 = 12; + +/// Default Dendrite version +pub const DEFAULT_DENDRITE_VERSION: u64 = 100; + +/// Errors that can occur during Dendrite operations +#[derive(Debug, Error)] +pub enum DendriteError { + #[error("Request building error: {0}")] + Request(#[from] RequestError), + #[error("Response error: {0}")] + Response(#[from] ResponseError), + #[error("HTTP client error: {0}")] + Http(#[from] reqwest::Error), + #[error("Timeout after {0:?}")] + Timeout(Duration), + #[error("Connection refused to {0}")] + ConnectionRefused(String), + #[error("Invalid axon: 
{0}")] + InvalidAxon(String), + #[error("Signing error: {0}")] + Signing(String), + #[error("Stream error: {0}")] + Stream(#[from] StreamError), +} + +/// Dendrite HTTP client for making requests to Axon servers +/// +/// The Dendrite handles all aspects of communicating with Axon servers: +/// - Building and signing requests +/// - Managing connection pooling +/// - Handling timeouts +/// - Parsing responses +/// +/// # Example +/// +/// ```ignore +/// use bittensor_rs::dendrite::Dendrite; +/// +/// // Create a dendrite without signing (anonymous) +/// let dendrite = Dendrite::new(None); +/// +/// // Or with a keypair for signed requests +/// let keypair = sr25519::Pair::from_string("//Alice", None)?; +/// let dendrite = Dendrite::new(Some(keypair)); +/// +/// // Make a request +/// let synapse = Synapse::new().with_name("Query"); +/// let response = dendrite.call(&axon, synapse).await?; +/// ``` +pub struct Dendrite { + /// The HTTP client (with connection pooling) + client: Client, + /// Optional keypair for signing requests + keypair: Option, + /// Default timeout for requests + timeout: Duration, + /// Dendrite version + version: u64, + /// Dendrite IP (optional, for headers) + ip: Option, + /// Dendrite port (optional, for headers) + port: Option, +} + +impl Dendrite { + /// Create a new Dendrite client + /// + /// # Arguments + /// + /// * `keypair` - Optional SR25519 keypair for signing requests + /// + /// # Returns + /// + /// A new Dendrite instance with default settings + pub fn new(keypair: Option) -> Self { + let client = Client::builder() + .pool_max_idle_per_host(10) + .pool_idle_timeout(Duration::from_secs(90)) + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) + .build() + .expect("Failed to build HTTP client"); + + Self { + client, + keypair, + timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS), + version: DEFAULT_DENDRITE_VERSION, + ip: None, + port: None, + } + } + + /// Set the default timeout for 
requests + /// + /// # Arguments + /// + /// * `timeout` - The timeout duration + /// + /// # Returns + /// + /// Self for method chaining + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + /// Set the Dendrite version + /// + /// # Arguments + /// + /// * `version` - The version number + /// + /// # Returns + /// + /// Self for method chaining + pub fn with_version(mut self, version: u64) -> Self { + self.version = version; + self + } + + /// Set the Dendrite IP address for headers + /// + /// # Arguments + /// + /// * `ip` - The IP address string + /// + /// # Returns + /// + /// Self for method chaining + pub fn with_ip(mut self, ip: impl Into) -> Self { + self.ip = Some(ip.into()); + self + } + + /// Set the Dendrite port for headers + /// + /// # Arguments + /// + /// * `port` - The port number + /// + /// # Returns + /// + /// Self for method chaining + pub fn with_port(mut self, port: u16) -> Self { + self.port = Some(port); + self + } + + /// Get the hotkey SS58 address if a keypair is set + pub fn hotkey(&self) -> Option { + self.keypair.as_ref().map(|kp| kp.public().to_ss58()) + } + + /// Build the dendrite terminal info for request headers + fn build_dendrite_info(&self) -> TerminalInfo { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos() as u64) + .unwrap_or(0); + + TerminalInfo { + ip: self.ip.clone(), + port: self.port, + version: Some(self.version), + nonce: Some(nonce), + uuid: Some(Uuid::new_v4().to_string()), + hotkey: self.hotkey(), + ..Default::default() + } + } + + /// Send a synapse to a single axon + /// + /// # Arguments + /// + /// * `axon` - The target Axon server + /// * `synapse` - The Synapse to send + /// + /// # Returns + /// + /// The response Synapse with updated terminal info and response data + pub async fn call(&self, axon: &AxonInfo, synapse: Synapse) -> Result { + self.call_with_timeout(axon, synapse, self.timeout).await 
+ } + + /// Send a synapse to a single axon with a specific timeout + /// + /// # Arguments + /// + /// * `axon` - The target Axon server + /// * `synapse` - The Synapse to send + /// * `timeout` - Request timeout + /// + /// # Returns + /// + /// The response Synapse with updated terminal info and response data + pub async fn call_with_timeout( + &self, + axon: &AxonInfo, + synapse: Synapse, + timeout: Duration, + ) -> Result { + // Validate axon is serving + if !axon.is_serving() { + return Err(DendriteError::InvalidAxon( + "Axon is not serving (0.0.0.0)".to_string(), + )); + } + + let start_time = Instant::now(); + + // Build the request + let dendrite_info = self.build_dendrite_info(); + let mut request = DendriteRequest::new(axon, &synapse, &dendrite_info, timeout)?; + + // Sign the request if we have a keypair + // For signing, we need the axon's hotkey - we'll use the IP:port as identifier for unsigned + let axon_hotkey = axon.ip_str(); + if let Some(ref keypair) = self.keypair { + request.sign(keypair, &axon_hotkey)?; + } + + // Convert to HTTP headers + let headers = crate::dendrite::request::synapse_to_headers(&Synapse { + name: synapse.name.clone(), + timeout: Some(timeout.as_secs_f64()), + dendrite: Some(TerminalInfo { + ip: request.headers.dendrite_ip.clone(), + port: request + .headers + .dendrite_port + .as_ref() + .and_then(|p| p.parse().ok()), + version: request + .headers + .dendrite_version + .as_ref() + .and_then(|v| v.parse().ok()), + nonce: request + .headers + .dendrite_nonce + .as_ref() + .and_then(|n| n.parse().ok()), + uuid: request.headers.dendrite_uuid.clone(), + hotkey: request.headers.dendrite_hotkey.clone(), + signature: request.headers.dendrite_signature.clone(), + ..Default::default() + }), + computed_body_hash: request.headers.computed_body_hash.clone(), + ..Default::default() + }); + + // Build the HTTP request + let http_request = self + .client + .post(&request.url) + .headers(headers) + .body(request.body) + .timeout(timeout); 
+ + // Execute the request + let result = http_request.send().await; + let process_time = start_time.elapsed().as_secs_f64(); + + match result { + Ok(response) => { + let status = response.status().as_u16(); + let response_headers = response.headers().clone(); + let body = response.bytes().await?.to_vec(); + + let dendrite_response = + DendriteResponse::new(status, response_headers, body, process_time); + Ok(dendrite_response.into_synapse()?) + } + Err(e) => { + if e.is_timeout() { + Ok(build_error_synapse( + &synapse, + status_codes::TIMEOUT, + "Request timeout", + process_time, + )) + } else if e.is_connect() { + Ok(build_error_synapse( + &synapse, + status_codes::SERVICE_UNAVAILABLE, + &format!("Connection failed: {}", e), + process_time, + )) + } else { + Err(DendriteError::Http(e)) + } + } + } + } + + /// Send a synapse to multiple axons concurrently + /// + /// # Arguments + /// + /// * `axons` - List of target Axon servers + /// * `synapse` - The Synapse to send (cloned for each request) + /// + /// # Returns + /// + /// A vector of results, one for each axon in the same order + pub async fn call_many( + &self, + axons: &[AxonInfo], + synapse: Synapse, + ) -> Vec> { + self.forward(axons, synapse, None).await + } + + /// Forward a synapse to multiple axons (like Python dendrite.forward) + /// + /// This is the main method for sending requests to multiple axons, + /// equivalent to the Python SDK's `dendrite.forward()` method. 
+ /// + /// # Arguments + /// + /// * `axons` - List of target Axon servers + /// * `synapse` - The Synapse to send (cloned for each request) + /// * `timeout` - Optional timeout override + /// + /// # Returns + /// + /// A vector of results, one for each axon in the same order + pub async fn forward( + &self, + axons: &[AxonInfo], + synapse: Synapse, + timeout: Option, + ) -> Vec> { + let timeout = timeout.unwrap_or(self.timeout); + + // Create futures for all requests + let futures: Vec<_> = axons + .iter() + .map(|axon| { + let synapse_clone = synapse.clone(); + self.call_with_timeout(axon, synapse_clone, timeout) + }) + .collect(); + + // Execute all concurrently + futures::future::join_all(futures).await + } + + /// Send a streaming synapse to a single axon + /// + /// # Arguments + /// + /// * `axon` - The target Axon server + /// * `synapse` - The streaming synapse to send + /// + /// # Returns + /// + /// A Stream that yields chunks as they arrive + pub async fn call_stream( + &self, + axon: &AxonInfo, + synapse: S, + ) -> Result>, DendriteError> + where + S: StreamingSynapse + Unpin + 'static, + { + self.call_stream_with_timeout(axon, synapse, self.timeout) + .await + } + + /// Send a streaming synapse to a single axon with a specific timeout + /// + /// # Arguments + /// + /// * `axon` - The target Axon server + /// * `synapse` - The streaming synapse to send + /// * `timeout` - Connection timeout (not stream timeout) + /// + /// # Returns + /// + /// A Stream that yields chunks as they arrive + pub async fn call_stream_with_timeout( + &self, + axon: &AxonInfo, + synapse: S, + timeout: Duration, + ) -> Result>, DendriteError> + where + S: StreamingSynapse + Unpin + 'static, + { + // Validate axon is serving + if !axon.is_serving() { + return Err(DendriteError::InvalidAxon( + "Axon is not serving (0.0.0.0)".to_string(), + )); + } + + // Build the endpoint URL + let url = format!("{}/{}", axon.to_endpoint(), synapse.name()); + + // Build headers + let 
dendrite_info = self.build_dendrite_info(); + let mut headers = http::HeaderMap::new(); + + // Add dendrite headers + if let Some(ref ip) = dendrite_info.ip { + if let Ok(hv) = http::HeaderValue::from_str(ip) { + headers.insert("bt_header_dendrite_ip", hv); + } + } + if let Some(port) = dendrite_info.port { + if let Ok(hv) = http::HeaderValue::from_str(&port.to_string()) { + headers.insert("bt_header_dendrite_port", hv); + } + } + if let Some(version) = dendrite_info.version { + if let Ok(hv) = http::HeaderValue::from_str(&version.to_string()) { + headers.insert("bt_header_dendrite_version", hv); + } + } + if let Some(nonce) = dendrite_info.nonce { + if let Ok(hv) = http::HeaderValue::from_str(&nonce.to_string()) { + headers.insert("bt_header_dendrite_nonce", hv); + } + } + if let Some(ref uuid) = dendrite_info.uuid { + if let Ok(hv) = http::HeaderValue::from_str(uuid) { + headers.insert("bt_header_dendrite_uuid", hv); + } + } + if let Some(ref hotkey) = dendrite_info.hotkey { + if let Ok(hv) = http::HeaderValue::from_str(hotkey) { + headers.insert("bt_header_dendrite_hotkey", hv); + } + } + + // Add name header + if let Ok(hv) = http::HeaderValue::from_str(synapse.name()) { + headers.insert("name", hv); + } + + // Add timeout header + if let Ok(hv) = http::HeaderValue::from_str(&timeout.as_secs_f64().to_string()) { + headers.insert("bt_header_timeout", hv); + } + + // Build the HTTP request + let http_request = self + .client + .post(&url) + .headers(headers) + .timeout(timeout); + + // Execute the request and get the response stream + let response = http_request.send().await?; + + if !response.status().is_success() { + return Err(DendriteError::Response(ResponseError::HttpError { + status: response.status().as_u16(), + })); + } + + // Get the byte stream from the response + let byte_stream = response.bytes_stream(); + + // Create the streaming response + Ok(StreamingResponse::new(synapse, byte_stream)) + } +} + +impl Clone for Dendrite { + fn clone(&self) -> Self 
{ + Self { + client: self.client.clone(), + keypair: self.keypair.clone(), + timeout: self.timeout, + version: self.version, + ip: self.ip.clone(), + port: self.port, + } + } +} + +impl Default for Dendrite { + fn default() -> Self { + Self::new(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::IpAddr; + + fn create_test_axon() -> AxonInfo { + AxonInfo { + block: 1000, + version: 100, + ip: IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + port: 8091, + ip_type: 4, + protocol: 0, + placeholder1: 0, + placeholder2: 0, + } + } + + fn create_non_serving_axon() -> AxonInfo { + AxonInfo { + block: 1000, + version: 100, + ip: IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), + port: 8091, + ip_type: 4, + protocol: 0, + placeholder1: 0, + placeholder2: 0, + } + } + + #[test] + fn test_dendrite_new_without_keypair() { + let dendrite = Dendrite::new(None); + assert!(dendrite.keypair.is_none()); + assert!(dendrite.hotkey().is_none()); + assert_eq!(dendrite.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS)); + } + + #[test] + fn test_dendrite_new_with_keypair() { + let keypair = + sr25519::Pair::from_string("//Alice", None).expect("Failed to create test keypair"); + let dendrite = Dendrite::new(Some(keypair)); + + assert!(dendrite.keypair.is_some()); + assert!(dendrite.hotkey().is_some()); + } + + #[test] + fn test_dendrite_with_timeout() { + let dendrite = Dendrite::new(None).with_timeout(Duration::from_secs(30)); + assert_eq!(dendrite.timeout, Duration::from_secs(30)); + } + + #[test] + fn test_dendrite_with_version() { + let dendrite = Dendrite::new(None).with_version(200); + assert_eq!(dendrite.version, 200); + } + + #[test] + fn test_dendrite_with_ip_and_port() { + let dendrite = Dendrite::new(None) + .with_ip("192.168.1.1") + .with_port(8080); + + assert_eq!(dendrite.ip, Some("192.168.1.1".to_string())); + assert_eq!(dendrite.port, Some(8080)); + } + + #[test] + fn test_build_dendrite_info() { + let keypair = + 
sr25519::Pair::from_string("//Alice", None).expect("Failed to create test keypair"); + let dendrite = Dendrite::new(Some(keypair)) + .with_ip("10.0.0.1") + .with_port(9000) + .with_version(150); + + let info = dendrite.build_dendrite_info(); + + assert_eq!(info.ip, Some("10.0.0.1".to_string())); + assert_eq!(info.port, Some(9000)); + assert_eq!(info.version, Some(150)); + assert!(info.nonce.is_some()); + assert!(info.uuid.is_some()); + assert!(info.hotkey.is_some()); + } + + #[test] + fn test_dendrite_clone() { + let dendrite = Dendrite::new(None) + .with_timeout(Duration::from_secs(20)) + .with_version(300); + + let cloned = dendrite.clone(); + + assert_eq!(cloned.timeout, Duration::from_secs(20)); + assert_eq!(cloned.version, 300); + } + + #[test] + fn test_dendrite_default() { + let dendrite = Dendrite::default(); + + assert!(dendrite.keypair.is_none()); + assert_eq!(dendrite.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS)); + assert_eq!(dendrite.version, DEFAULT_DENDRITE_VERSION); + } + + #[tokio::test] + async fn test_call_non_serving_axon() { + let dendrite = Dendrite::new(None); + let axon = create_non_serving_axon(); + let synapse = Synapse::new().with_name("Test"); + + let result = dendrite.call(&axon, synapse).await; + + assert!(result.is_err()); + match result { + Err(DendriteError::InvalidAxon(_)) => {} + _ => panic!("Expected InvalidAxon error"), + } + } + + #[tokio::test] + async fn test_call_many_empty() { + let dendrite = Dendrite::new(None); + let axons: Vec = vec![]; + let synapse = Synapse::new().with_name("Test"); + + let results = dendrite.call_many(&axons, synapse).await; + + assert!(results.is_empty()); + } + + #[tokio::test] + async fn test_forward_with_timeout() { + let dendrite = Dendrite::new(None); + let axon = create_test_axon(); + let synapse = Synapse::new().with_name("Test"); + + // This will fail to connect but should respect the timeout + let results = dendrite + .forward(&[axon], synapse, Some(Duration::from_millis(100))) + 
.await; + + assert_eq!(results.len(), 1); + // The result should be an error synapse (connection failed) or an error + // We're just testing the API works correctly + } +} diff --git a/src/dendrite/mod.rs b/src/dendrite/mod.rs new file mode 100644 index 0000000..edafb8e --- /dev/null +++ b/src/dendrite/mod.rs @@ -0,0 +1,28 @@ +//! Dendrite HTTP client module for Bittensor network communication +//! +//! The Dendrite is responsible for making HTTP requests to Axon servers +//! in the Bittensor network. It handles request signing, response parsing, +//! and supports both standard and streaming communication patterns. +//! +//! # Example +//! +//! ```ignore +//! use bittensor_rs::dendrite::Dendrite; +//! use bittensor_rs::types::{AxonInfo, Synapse}; +//! +//! let dendrite = Dendrite::new(None); +//! let axon = // ... get axon info from metagraph +//! let synapse = Synapse::new().with_name("MyQuery"); +//! +//! let response = dendrite.call(&axon, synapse).await?; +//! ``` + +pub mod client; +pub mod request; +pub mod response; +pub mod streaming; + +pub use client::Dendrite; +pub use request::{headers_to_synapse, synapse_to_headers, DendriteRequest}; +pub use response::DendriteResponse; +pub use streaming::{StreamingResponse, StreamingSynapse}; diff --git a/src/dendrite/request.rs b/src/dendrite/request.rs new file mode 100644 index 0000000..c0cdf5b --- /dev/null +++ b/src/dendrite/request.rs @@ -0,0 +1,486 @@ +//! Request building and signing for Dendrite HTTP requests +//! +//! This module handles the construction of HTTP requests to Axon servers, +//! including header generation, body hashing, and cryptographic signing. 
use crate::types::{AxonInfo, Synapse, SynapseHeaders, TerminalInfo};
use http::header::HeaderMap;
use sha2::{Digest, Sha256};
use sp_core::{sr25519, Pair};
use std::time::Duration;
use thiserror::Error;

/// Errors that can occur during request building
#[derive(Debug, Error)]
pub enum RequestError {
    #[error("Serialization error: {0}")]
    Serialization(String),
    // NOTE(review): InvalidUrl and InvalidHeader are not constructed anywhere
    // in this module's visible code — confirm they are used by callers.
    #[error("Invalid URL: {0}")]
    InvalidUrl(String),
    #[error("Signing error: {0}")]
    Signing(String),
    #[error("Invalid header value: {0}")]
    InvalidHeader(String),
}

/// A prepared Dendrite request ready for transmission
#[derive(Debug, Clone)]
pub struct DendriteRequest {
    /// The target URL for the request
    pub url: String,
    /// Synapse headers for the request
    pub headers: SynapseHeaders,
    /// Serialized request body (JSON encoding of the synapse's extra fields only)
    pub body: Vec<u8>,
    /// Request timeout
    pub timeout: Duration,
}

impl DendriteRequest {
    /// Create a new DendriteRequest from an AxonInfo and Synapse
    ///
    /// # Arguments
    ///
    /// * `axon` - The target Axon server information
    /// * `synapse` - The Synapse to send
    /// * `dendrite_info` - Terminal info for the dendrite (sender)
    /// * `timeout` - Request timeout duration
    ///
    /// # Returns
    ///
    /// A prepared request, or `RequestError::Serialization` if the synapse's
    /// extra fields cannot be encoded as JSON
    pub fn new(
        axon: &AxonInfo,
        synapse: &Synapse,
        dendrite_info: &TerminalInfo,
        timeout: Duration,
    ) -> Result<Self, RequestError> {
        // Build the endpoint URL; an unnamed synapse falls back to the
        // generic "Synapse" route.
        let synapse_name = synapse.name.as_deref().unwrap_or("Synapse");
        let url = format!("{}/{}", axon.to_endpoint(), synapse_name);

        // Serialize the synapse body (just the extra fields, not the headers)
        let body = serde_json::to_vec(&synapse.extra)
            .map_err(|e| RequestError::Serialization(e.to_string()))?;

        // Build initial headers from synapse
        let mut headers = synapse.to_headers();

        // Add dendrite terminal info; numeric fields are stringified because
        // they travel as HTTP header values.
        headers.dendrite_ip = dendrite_info.ip.clone();
        headers.dendrite_port = dendrite_info.port.map(|p| p.to_string());
        headers.dendrite_version = dendrite_info.version.map(|v| v.to_string());
        headers.dendrite_nonce = dendrite_info.nonce.map(|n| n.to_string());
        headers.dendrite_uuid = dendrite_info.uuid.clone();
        headers.dendrite_hotkey = dendrite_info.hotkey.clone();

        // Set timeout in headers (fractional seconds, e.g. "12.0")
        headers.timeout = Some(timeout.as_secs_f64().to_string());

        Ok(Self {
            url,
            headers,
            body,
            timeout,
        })
    }

    /// Compute the SHA-256 hash of the request body
    ///
    /// # Returns
    ///
    /// Hexadecimal string of the body hash (64 lowercase hex characters)
    pub fn compute_body_hash(&self) -> String {
        let mut hasher = Sha256::new();
        hasher.update(&self.body);
        let result = hasher.finalize();
        hex::encode(result)
    }

    /// Sign the request with the given keypair
    ///
    /// The signature format matches the Python SDK:
    /// `sign(message = "{nonce}.{dendrite_hotkey}.{axon_hotkey}.{body_hash}")`
    ///
    /// Side effects: stores the body hash in `headers.computed_body_hash` and
    /// the hex-encoded signature in `headers.dendrite_signature`.
    ///
    /// # Arguments
    ///
    /// * `keypair` - The SR25519 keypair to sign with
    /// * `axon_hotkey` - The axon's hotkey SS58 address
    ///
    /// # Returns
    ///
    /// Ok(()) if signing succeeds; `RequestError::Signing` if the nonce or
    /// dendrite hotkey header is missing (both must be set before signing)
    pub fn sign(&mut self, keypair: &sr25519::Pair, axon_hotkey: &str) -> Result<(), RequestError> {
        // Compute body hash
        let body_hash = self.compute_body_hash();
        self.headers.computed_body_hash = Some(body_hash.clone());

        // Get nonce from headers
        let nonce = self
            .headers
            .dendrite_nonce
            .as_ref()
            .ok_or_else(|| RequestError::Signing("Missing nonce".to_string()))?;

        // Get dendrite hotkey from headers
        let dendrite_hotkey = self
            .headers
            .dendrite_hotkey
            .as_ref()
            .ok_or_else(|| RequestError::Signing("Missing dendrite hotkey".to_string()))?;

        // Create message to sign: "{nonce}.{dendrite_hotkey}.{axon_hotkey}.{body_hash}"
        let message = format!("{}.{}.{}.{}", nonce, dendrite_hotkey, axon_hotkey, body_hash);

        // Sign the message and store the raw 64-byte signature as hex
        let signature = keypair.sign(message.as_bytes());
        self.headers.dendrite_signature = Some(hex::encode(signature.0));
        Ok(())
    }
}

/// Header name constants matching Python SDK
pub mod header_names {
    // Dendrite (sender) terminal headers
    pub const DENDRITE_IP: &str = "bt_header_dendrite_ip";
    pub const DENDRITE_PORT: &str = "bt_header_dendrite_port";
    pub const DENDRITE_VERSION: &str = "bt_header_dendrite_version";
    pub const DENDRITE_NONCE: &str = "bt_header_dendrite_nonce";
    pub const DENDRITE_UUID: &str = "bt_header_dendrite_uuid";
    pub const DENDRITE_HOTKEY: &str = "bt_header_dendrite_hotkey";
    pub const DENDRITE_SIGNATURE: &str = "bt_header_dendrite_signature";
    // Axon (receiver) terminal headers
    pub const AXON_IP: &str = "bt_header_axon_ip";
    pub const AXON_PORT: &str = "bt_header_axon_port";
    pub const AXON_VERSION: &str = "bt_header_axon_version";
    pub const AXON_NONCE: &str = "bt_header_axon_nonce";
    pub const AXON_UUID: &str = "bt_header_axon_uuid";
    pub const AXON_HOTKEY: &str = "bt_header_axon_hotkey";
    pub const AXON_SIGNATURE: &str = "bt_header_axon_signature";
    pub const AXON_STATUS_CODE: &str = "bt_header_axon_status_code";
    pub const AXON_STATUS_MESSAGE: &str = "bt_header_axon_status_message";
    pub const AXON_PROCESS_TIME: &str = "bt_header_axon_process_time";
    pub const INPUT_OBJ: &str = "bt_header_input_obj";
    pub const OUTPUT_OBJ: &str = "bt_header_output_obj";
    pub const TIMEOUT: &str = "bt_header_timeout";
    // NOTE(review): the following names intentionally have no "bt_header_"
    // prefix — presumably matching the Python SDK's wire format; confirm.
    pub const BODY_HASH: &str = "computed_body_hash";
    pub const NAME: &str = "name";
    pub const TOTAL_SIZE: &str = "total_size";
    pub const HEADER_SIZE: &str = "header_size";
}

/// Convert a Synapse to HTTP headers for transmission
///
/// # Arguments
///
/// * `synapse` - The Synapse to convert
///
/// # Returns
///
/// An HTTP HeaderMap containing all synapse fields as headers.
/// Fields that are `None`, or whose values are not valid HTTP header
/// values, are silently omitted rather than failing the conversion.
pub fn synapse_to_headers(synapse: &Synapse) -> HeaderMap {
    let mut headers = HeaderMap::new();
    let synapse_headers = synapse.to_headers();

    // Helper macro to add optional header values
    macro_rules! add_header {
        ($name:expr, $value:expr) => {
            if let Some(ref v) = $value {
                if let Ok(hv) = http::header::HeaderValue::from_str(v) {
                    headers.insert($name, hv);
                }
            }
        };
    }

    // Synapse metadata
    add_header!(header_names::NAME, synapse_headers.name);
    add_header!(header_names::TIMEOUT, synapse_headers.timeout);
    add_header!(header_names::TOTAL_SIZE, synapse_headers.total_size);
    add_header!(header_names::HEADER_SIZE, synapse_headers.header_size);
    add_header!(header_names::BODY_HASH, synapse_headers.computed_body_hash);

    // Dendrite terminal info
    add_header!(header_names::DENDRITE_IP, synapse_headers.dendrite_ip);
    add_header!(header_names::DENDRITE_PORT, synapse_headers.dendrite_port);
    add_header!(
        header_names::DENDRITE_VERSION,
        synapse_headers.dendrite_version
    );
    add_header!(
        header_names::DENDRITE_NONCE,
        synapse_headers.dendrite_nonce
    );
    add_header!(header_names::DENDRITE_UUID, synapse_headers.dendrite_uuid);
    add_header!(
        header_names::DENDRITE_HOTKEY,
        synapse_headers.dendrite_hotkey
    );
    add_header!(
        header_names::DENDRITE_SIGNATURE,
        synapse_headers.dendrite_signature
    );

    // Axon terminal info
    add_header!(header_names::AXON_IP, synapse_headers.axon_ip);
    add_header!(header_names::AXON_PORT, synapse_headers.axon_port);
    add_header!(header_names::AXON_VERSION, synapse_headers.axon_version);
    add_header!(header_names::AXON_NONCE, synapse_headers.axon_nonce);
    add_header!(header_names::AXON_UUID, synapse_headers.axon_uuid);
    add_header!(header_names::AXON_HOTKEY, synapse_headers.axon_hotkey);
    add_header!(
        header_names::AXON_SIGNATURE,
        synapse_headers.axon_signature
    );

    headers
}

/// Parse HTTP headers into a Synapse
///
/// # Arguments
///
/// * `headers` - The HTTP response headers
/// * `body` - The response body bytes
///
/// # Returns
///
/// A reconstructed Synapse with data from headers and body, or
/// `RequestError::Serialization` if the body is non-empty but not valid JSON
pub fn headers_to_synapse(headers: &HeaderMap, body: &[u8]) -> Result<Synapse, RequestError> {
    // Helper to get header value as string; non-UTF8 values yield None
    fn get_header(headers: &HeaderMap, name: &str) -> Option<String> {
        headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string())
    }

    // Helpers to parse a header as a number; unparseable values yield None
    fn get_header_u64(headers: &HeaderMap, name: &str) -> Option<u64> {
        get_header(headers, name).and_then(|s| s.parse().ok())
    }

    fn get_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
        get_header(headers, name).and_then(|s| s.parse().ok())
    }

    fn get_header_i32(headers: &HeaderMap, name: &str) -> Option<i32> {
        get_header(headers, name).and_then(|s| s.parse().ok())
    }

    fn get_header_u16(headers: &HeaderMap, name: &str) -> Option<u16> {
        get_header(headers, name).and_then(|s| s.parse().ok())
    }

    // Build dendrite terminal info from headers.
    // NOTE(review): the dendrite status/process-time names below have no
    // constants in `header_names` — they are read with literal strings.
    let dendrite = TerminalInfo {
        status_code: get_header_i32(headers, "bt_header_dendrite_status_code"),
        status_message: get_header(headers, "bt_header_dendrite_status_message"),
        process_time: get_header_f64(headers, "bt_header_dendrite_process_time"),
        ip: get_header(headers, header_names::DENDRITE_IP),
        port: get_header_u16(headers, header_names::DENDRITE_PORT),
        version: get_header_u64(headers, header_names::DENDRITE_VERSION),
        nonce: get_header_u64(headers, header_names::DENDRITE_NONCE),
        uuid: get_header(headers, header_names::DENDRITE_UUID),
        hotkey: get_header(headers, header_names::DENDRITE_HOTKEY),
        signature: get_header(headers, header_names::DENDRITE_SIGNATURE),
    };

    // Build axon terminal info from headers
    let axon = TerminalInfo {
        status_code: get_header_i32(headers, header_names::AXON_STATUS_CODE),
        status_message: get_header(headers, header_names::AXON_STATUS_MESSAGE),
        process_time: get_header_f64(headers, header_names::AXON_PROCESS_TIME),
        ip: get_header(headers, header_names::AXON_IP),
        port: get_header_u16(headers, header_names::AXON_PORT),
        version: get_header_u64(headers, header_names::AXON_VERSION),
        nonce: get_header_u64(headers,
header_names::AXON_NONCE), + uuid: get_header(headers, header_names::AXON_UUID), + hotkey: get_header(headers, header_names::AXON_HOTKEY), + signature: get_header(headers, header_names::AXON_SIGNATURE), + }; + + // Parse body as JSON extra fields + let extra = if body.is_empty() { + std::collections::HashMap::new() + } else { + serde_json::from_slice(body).map_err(|e| RequestError::Serialization(e.to_string()))? + }; + + Ok(Synapse { + name: get_header(headers, header_names::NAME), + timeout: get_header_f64(headers, header_names::TIMEOUT), + total_size: get_header_u64(headers, header_names::TOTAL_SIZE), + header_size: get_header_u64(headers, header_names::HEADER_SIZE), + dendrite: Some(dendrite), + axon: Some(axon), + computed_body_hash: get_header(headers, header_names::BODY_HASH), + extra, + }) +} + +/// Create a signature message for request authentication +/// +/// The signature format matches the Python SDK: +/// `"{nonce}.{dendrite_hotkey}.{axon_hotkey}.{body_hash}"` +/// +/// # Arguments +/// +/// * `nonce` - Request nonce +/// * `dendrite_hotkey` - The sender's hotkey SS58 address +/// * `axon_hotkey` - The target axon's hotkey SS58 address +/// * `body_hash` - SHA-256 hash of the request body +/// +/// # Returns +/// +/// The signature message string +pub fn create_signature_message( + nonce: u64, + dendrite_hotkey: &str, + axon_hotkey: &str, + body_hash: &str, +) -> String { + format!("{}.{}.{}.{}", nonce, dendrite_hotkey, axon_hotkey, body_hash) +} + +/// Sign a message with the given keypair and return hex-encoded signature +/// +/// # Arguments +/// +/// * `keypair` - The SR25519 keypair to sign with +/// * `message` - The message bytes to sign +/// +/// # Returns +/// +/// Hex-encoded signature string +pub fn sign_message(keypair: &sr25519::Pair, message: &[u8]) -> String { + let signature = keypair.sign(message); + hex::encode(signature.0) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::IpAddr; + + fn create_test_axon() -> AxonInfo { 
+ AxonInfo { + block: 1000, + version: 100, + ip: IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + port: 8091, + ip_type: 4, + protocol: 0, + placeholder1: 0, + placeholder2: 0, + } + } + + fn create_test_synapse() -> Synapse { + Synapse::new().with_name("TestSynapse").with_timeout(12.0) + } + + fn create_test_dendrite_info() -> TerminalInfo { + TerminalInfo { + ip: Some("192.168.1.1".to_string()), + port: Some(8080), + version: Some(100), + nonce: Some(12345678), + uuid: Some("test-uuid-1234".to_string()), + hotkey: Some("5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY".to_string()), + ..Default::default() + } + } + + #[test] + fn test_dendrite_request_new() { + let axon = create_test_axon(); + let synapse = create_test_synapse(); + let dendrite_info = create_test_dendrite_info(); + + let request = + DendriteRequest::new(&axon, &synapse, &dendrite_info, Duration::from_secs(12)).unwrap(); + + assert_eq!(request.url, "http://127.0.0.1:8091/TestSynapse"); + assert_eq!( + request.headers.dendrite_hotkey.as_deref(), + Some("5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY") + ); + assert_eq!(request.headers.dendrite_nonce.as_deref(), Some("12345678")); + } + + #[test] + fn test_compute_body_hash() { + let axon = create_test_axon(); + let synapse = create_test_synapse(); + let dendrite_info = create_test_dendrite_info(); + + let request = + DendriteRequest::new(&axon, &synapse, &dendrite_info, Duration::from_secs(12)).unwrap(); + + let hash = request.compute_body_hash(); + // SHA-256 hash should be 64 hex characters + assert_eq!(hash.len(), 64); + // Hash should be deterministic + assert_eq!(hash, request.compute_body_hash()); + } + + #[test] + fn test_create_signature_message() { + let message = create_signature_message( + 12345, + "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY", + "5FHneW46xGXgs5mUiveU4sbTyGBzmstUspZC92UhjJM694ty", + "abc123def456", + ); + + assert_eq!( + message, + 
"12345.5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY.5FHneW46xGXgs5mUiveU4sbTyGBzmstUspZC92UhjJM694ty.abc123def456" + ); + } + + #[test] + fn test_synapse_to_headers_and_back() { + let mut synapse = Synapse::new().with_name("TestSynapse").with_timeout(15.0); + + synapse.dendrite = Some(TerminalInfo { + ip: Some("10.0.0.1".to_string()), + port: Some(8080), + version: Some(200), + nonce: Some(99999), + ..Default::default() + }); + + let headers = synapse_to_headers(&synapse); + + // Verify some headers exist + assert!(headers.contains_key(header_names::NAME)); + assert!(headers.contains_key(header_names::TIMEOUT)); + } + + #[test] + fn test_headers_to_synapse_empty_body() { + let mut headers = HeaderMap::new(); + headers.insert(header_names::NAME, "ParsedSynapse".parse().unwrap()); + headers.insert(header_names::TIMEOUT, "20.0".parse().unwrap()); + + let synapse = headers_to_synapse(&headers, &[]).unwrap(); + + assert_eq!(synapse.name, Some("ParsedSynapse".to_string())); + assert_eq!(synapse.timeout, Some(20.0)); + assert!(synapse.extra.is_empty()); + } + + #[test] + fn test_sign_message() { + // Create a test keypair from seed + let keypair = + sr25519::Pair::from_string("//Alice", None).expect("Failed to create test keypair"); + + let message = b"test message"; + let signature = sign_message(&keypair, message); + + // Signature should be 128 hex characters (64 bytes) + assert_eq!(signature.len(), 128); + } +} diff --git a/src/dendrite/response.rs b/src/dendrite/response.rs new file mode 100644 index 0000000..5eb8528 --- /dev/null +++ b/src/dendrite/response.rs @@ -0,0 +1,304 @@ +//! Response handling for Dendrite HTTP responses +//! +//! This module provides types and utilities for processing responses +//! from Axon servers, including status interpretation and synapse reconstruction. 
+ +use crate::dendrite::request::{headers_to_synapse, RequestError}; +use crate::types::{Synapse, TerminalInfo}; +use http::header::HeaderMap; +use thiserror::Error; + +/// Errors that can occur during response processing +#[derive(Debug, Error)] +pub enum ResponseError { + #[error("HTTP error: status {status}")] + HttpError { status: u16 }, + #[error("Timeout")] + Timeout, + #[error("Deserialization error: {0}")] + Deserialization(String), + #[error("Request error: {0}")] + Request(#[from] RequestError), + #[error("Network error: {0}")] + Network(String), +} + +/// A response from a Dendrite HTTP request +#[derive(Debug, Clone)] +pub struct DendriteResponse { + /// HTTP status code + pub status: u16, + /// Response headers + pub headers: HeaderMap, + /// Response body + pub body: Vec, + /// Time taken to process the request (in seconds) + pub process_time: f64, +} + +impl DendriteResponse { + /// Create a new DendriteResponse + /// + /// # Arguments + /// + /// * `status` - HTTP status code + /// * `headers` - Response headers + /// * `body` - Response body bytes + /// * `process_time` - Processing time in seconds + pub fn new(status: u16, headers: HeaderMap, body: Vec, process_time: f64) -> Self { + Self { + status, + headers, + body, + process_time, + } + } + + /// Convert the response into a Synapse + /// + /// Parses headers and body to reconstruct the full synapse with + /// terminal information and response data. 
+ /// + /// # Returns + /// + /// The reconstructed Synapse or an error + pub fn into_synapse(self) -> Result { + let mut synapse = headers_to_synapse(&self.headers, &self.body)?; + + // Update the dendrite terminal info with response status + if let Some(ref mut dendrite) = synapse.dendrite { + dendrite.status_code = Some(self.status as i32); + dendrite.process_time = Some(self.process_time); + } else { + synapse.dendrite = Some(TerminalInfo { + status_code: Some(self.status as i32), + process_time: Some(self.process_time), + ..Default::default() + }); + } + + Ok(synapse) + } + + /// Check if the response indicates success (2xx status code) + pub fn is_success(&self) -> bool { + (200..300).contains(&self.status) + } + + /// Check if the response indicates a timeout (408 or 504) + pub fn is_timeout(&self) -> bool { + self.status == 408 || self.status == 504 + } + + /// Check if the response indicates a client error (4xx) + pub fn is_client_error(&self) -> bool { + (400..500).contains(&self.status) + } + + /// Check if the response indicates a server error (5xx) + pub fn is_server_error(&self) -> bool { + (500..600).contains(&self.status) + } + + /// Get the axon status code from headers if present + pub fn axon_status_code(&self) -> Option { + self.headers + .get("bt_header_axon_status_code") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()) + } + + /// Get the axon status message from headers if present + pub fn axon_status_message(&self) -> Option { + self.headers + .get("bt_header_axon_status_message") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) + } + + /// Get the axon process time from headers if present + pub fn axon_process_time(&self) -> Option { + self.headers + .get("bt_header_axon_process_time") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()) + } + + /// Get the body as a string (if valid UTF-8) + pub fn body_as_str(&self) -> Option<&str> { + std::str::from_utf8(&self.body).ok() + } + + /// Deserialize the 
body as JSON into a specific type + pub fn json(&self) -> Result { + serde_json::from_slice(&self.body) + .map_err(|e| ResponseError::Deserialization(e.to_string())) + } +} + +/// Build a failed response synapse for error conditions +/// +/// This is used when the request fails before reaching the axon, +/// such as connection errors or timeouts. +/// +/// # Arguments +/// +/// * `original` - The original synapse that was sent +/// * `status_code` - Error status code +/// * `message` - Error message +/// * `process_time` - Time taken before failure +/// +/// # Returns +/// +/// A synapse marked with failure status +pub fn build_error_synapse( + original: &Synapse, + status_code: i32, + message: &str, + process_time: f64, +) -> Synapse { + let mut synapse = original.clone(); + + // Update dendrite terminal info with error status + let dendrite = synapse.dendrite.get_or_insert_with(TerminalInfo::default); + dendrite.status_code = Some(status_code); + dendrite.status_message = Some(message.to_string()); + dendrite.process_time = Some(process_time); + + synapse +} + +/// Standard status codes used in Bittensor protocol +pub mod status_codes { + /// Request successful + pub const SUCCESS: i32 = 200; + /// Request successful, no content + pub const NO_CONTENT: i32 = 204; + /// Bad request (malformed) + pub const BAD_REQUEST: i32 = 400; + /// Unauthorized (invalid signature) + pub const UNAUTHORIZED: i32 = 401; + /// Forbidden (blacklisted) + pub const FORBIDDEN: i32 = 403; + /// Not found (endpoint doesn't exist) + pub const NOT_FOUND: i32 = 404; + /// Request timeout + pub const TIMEOUT: i32 = 408; + /// Too many requests (rate limited) + pub const TOO_MANY_REQUESTS: i32 = 429; + /// Internal server error + pub const INTERNAL_ERROR: i32 = 500; + /// Service unavailable + pub const SERVICE_UNAVAILABLE: i32 = 503; + /// Gateway timeout + pub const GATEWAY_TIMEOUT: i32 = 504; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_dendrite_response_is_success() { + let response = DendriteResponse::new(200, HeaderMap::new(), vec![], 0.5); + assert!(response.is_success()); + assert!(!response.is_timeout()); + assert!(!response.is_client_error()); + assert!(!response.is_server_error()); + } + + #[test] + fn test_dendrite_response_is_timeout() { + let response_408 = DendriteResponse::new(408, HeaderMap::new(), vec![], 12.0); + assert!(response_408.is_timeout()); + assert!(!response_408.is_success()); + + let response_504 = DendriteResponse::new(504, HeaderMap::new(), vec![], 12.0); + assert!(response_504.is_timeout()); + } + + #[test] + fn test_dendrite_response_is_client_error() { + let response = DendriteResponse::new(404, HeaderMap::new(), vec![], 0.1); + assert!(response.is_client_error()); + assert!(!response.is_success()); + assert!(!response.is_server_error()); + } + + #[test] + fn test_dendrite_response_is_server_error() { + let response = DendriteResponse::new(500, HeaderMap::new(), vec![], 0.1); + assert!(response.is_server_error()); + assert!(!response.is_success()); + assert!(!response.is_client_error()); + } + + #[test] + fn test_dendrite_response_json() { + use serde::Deserialize; + + #[derive(Deserialize, Debug, PartialEq)] + struct TestData { + value: i32, + } + + let body = br#"{"value": 42}"#; + let response = DendriteResponse::new(200, HeaderMap::new(), body.to_vec(), 0.1); + + let data: TestData = response.json().unwrap(); + assert_eq!(data.value, 42); + } + + #[test] + fn test_dendrite_response_body_as_str() { + let body = b"Hello, world!"; + let response = DendriteResponse::new(200, HeaderMap::new(), body.to_vec(), 0.1); + + assert_eq!(response.body_as_str(), Some("Hello, world!")); + } + + #[test] + fn test_build_error_synapse() { + let original = Synapse::new().with_name("TestSynapse"); + let error_synapse = build_error_synapse(&original, 408, "Request timeout", 12.5); + + assert_eq!(error_synapse.name, Some("TestSynapse".to_string())); + + let dendrite = 
error_synapse.dendrite.unwrap(); + assert_eq!(dendrite.status_code, Some(408)); + assert_eq!(dendrite.status_message, Some("Request timeout".to_string())); + assert_eq!(dendrite.process_time, Some(12.5)); + } + + #[test] + fn test_into_synapse() { + let mut headers = HeaderMap::new(); + headers.insert("name", "ConvertedSynapse".parse().unwrap()); + headers.insert("bt_header_timeout", "15.0".parse().unwrap()); + headers.insert("bt_header_axon_status_code", "200".parse().unwrap()); + + let body = br#"{}"#; + let response = DendriteResponse::new(200, headers, body.to_vec(), 0.5); + + let synapse = response.into_synapse().unwrap(); + assert_eq!(synapse.name, Some("ConvertedSynapse".to_string())); + + let dendrite = synapse.dendrite.unwrap(); + assert_eq!(dendrite.status_code, Some(200)); + assert_eq!(dendrite.process_time, Some(0.5)); + } + + #[test] + fn test_axon_header_accessors() { + let mut headers = HeaderMap::new(); + headers.insert("bt_header_axon_status_code", "200".parse().unwrap()); + headers.insert("bt_header_axon_status_message", "OK".parse().unwrap()); + headers.insert("bt_header_axon_process_time", "0.123".parse().unwrap()); + + let response = DendriteResponse::new(200, headers, vec![], 0.2); + + assert_eq!(response.axon_status_code(), Some(200)); + assert_eq!(response.axon_status_message(), Some("OK".to_string())); + assert_eq!(response.axon_process_time(), Some(0.123)); + } +} diff --git a/src/dendrite/streaming.rs b/src/dendrite/streaming.rs new file mode 100644 index 0000000..bef8413 --- /dev/null +++ b/src/dendrite/streaming.rs @@ -0,0 +1,455 @@ +//! Streaming support for Dendrite requests +//! +//! This module provides types and utilities for streaming responses +//! from Axon servers, allowing for incremental processing of large +//! or continuous data streams. 
+ +use futures::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Trait for synapses that support streaming responses +/// +/// Types implementing this trait can process response chunks incrementally +/// rather than waiting for the complete response. +pub trait StreamingSynapse: Send { + /// The type of each chunk produced by the stream + type Chunk: Send; + + /// Process a chunk of data from the response stream + /// + /// # Arguments + /// + /// * `chunk` - Raw bytes from the response stream + /// + /// # Returns + /// + /// `Some(chunk)` if a complete chunk was parsed, `None` if more data is needed + fn process_chunk(&mut self, chunk: &[u8]) -> Option; + + /// Check if the stream is complete + /// + /// Returns `true` when no more chunks are expected + fn is_complete(&self) -> bool; + + /// Get the name of this streaming synapse + fn name(&self) -> &str; + + /// Called when the stream ends (either normally or due to error) + /// + /// Default implementation does nothing + fn on_stream_end(&mut self) {} +} + +/// A streaming response wrapper that implements the Stream trait +/// +/// This wraps an async bytes stream and a StreamingSynapse to produce +/// parsed chunks as they arrive. 
+pub struct StreamingResponse +where + S: StreamingSynapse, + B: Stream> + Unpin, +{ + /// The synapse processor + synapse: S, + /// The underlying byte stream + byte_stream: B, + /// Buffer for incomplete chunks + buffer: Vec, + /// Whether the stream has completed + completed: bool, +} + +impl StreamingResponse +where + S: StreamingSynapse, + B: Stream> + Unpin, +{ + /// Create a new StreamingResponse + /// + /// # Arguments + /// + /// * `synapse` - The streaming synapse processor + /// * `byte_stream` - The underlying byte stream from the HTTP response + pub fn new(synapse: S, byte_stream: B) -> Self { + Self { + synapse, + byte_stream, + buffer: Vec::with_capacity(4096), + completed: false, + } + } + + /// Get a reference to the underlying synapse + pub fn synapse(&self) -> &S { + &self.synapse + } + + /// Get a mutable reference to the underlying synapse + pub fn synapse_mut(&mut self) -> &mut S { + &mut self.synapse + } + + /// Check if the stream has completed + pub fn is_completed(&self) -> bool { + self.completed + } +} + +impl Stream for StreamingResponse +where + S: StreamingSynapse + Unpin, + B: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + + if this.completed || this.synapse.is_complete() { + this.synapse.on_stream_end(); + return Poll::Ready(None); + } + + // Try to process any buffered data first + if !this.buffer.is_empty() { + if let Some(chunk) = this.synapse.process_chunk(&this.buffer) { + this.buffer.clear(); + return Poll::Ready(Some(Ok(chunk))); + } + } + + // Poll the underlying stream for more data + match Pin::new(&mut this.byte_stream).poll_next(cx) { + Poll::Ready(Some(Ok(bytes))) => { + // Append to buffer + this.buffer.extend_from_slice(&bytes); + + // Try to process the chunk + if let Some(chunk) = this.synapse.process_chunk(&this.buffer) { + this.buffer.clear(); + Poll::Ready(Some(Ok(chunk))) + } else { + // Need more data, continue 
polling + cx.waker().wake_by_ref(); + Poll::Pending + } + } + Poll::Ready(Some(Err(e))) => { + this.completed = true; + this.synapse.on_stream_end(); + Poll::Ready(Some(Err(StreamError::Network(e.to_string())))) + } + Poll::Ready(None) => { + // Stream ended + this.completed = true; + this.synapse.on_stream_end(); + + // Process any remaining buffered data + if !this.buffer.is_empty() { + if let Some(chunk) = this.synapse.process_chunk(&this.buffer) { + this.buffer.clear(); + return Poll::Ready(Some(Ok(chunk))); + } + } + + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } +} + +/// Errors that can occur during streaming +#[derive(Debug, thiserror::Error)] +pub enum StreamError { + #[error("Network error: {0}")] + Network(String), + #[error("Parse error: {0}")] + Parse(String), + #[error("Stream timeout")] + Timeout, + #[error("Stream cancelled")] + Cancelled, +} + +/// A simple text streaming synapse implementation +/// +/// Processes newline-delimited text chunks +#[derive(Debug)] +pub struct TextStreamingSynapse { + name: String, + complete: bool, + delimiter: u8, +} + +impl TextStreamingSynapse { + /// Create a new text streaming synapse with newline delimiter + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + complete: false, + delimiter: b'\n', + } + } + + /// Set a custom delimiter byte + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.delimiter = delimiter; + self + } +} + +impl StreamingSynapse for TextStreamingSynapse { + type Chunk = String; + + fn process_chunk(&mut self, chunk: &[u8]) -> Option { + // Look for delimiter + if let Some(pos) = chunk.iter().position(|&b| b == self.delimiter) { + let text = String::from_utf8_lossy(&chunk[..pos]).to_string(); + Some(text) + } else { + None + } + } + + fn is_complete(&self) -> bool { + self.complete + } + + fn name(&self) -> &str { + &self.name + } + + fn on_stream_end(&mut self) { + self.complete = true; + } +} + +/// A JSON streaming synapse implementation 
+/// +/// Processes newline-delimited JSON objects (NDJSON/JSON Lines format) +#[derive(Debug)] +pub struct JsonStreamingSynapse { + name: String, + complete: bool, + _phantom: std::marker::PhantomData, +} + +impl JsonStreamingSynapse +where + T: serde::de::DeserializeOwned + Send, +{ + /// Create a new JSON streaming synapse + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + complete: false, + _phantom: std::marker::PhantomData, + } + } +} + +impl StreamingSynapse for JsonStreamingSynapse +where + T: serde::de::DeserializeOwned + Send, +{ + type Chunk = T; + + fn process_chunk(&mut self, chunk: &[u8]) -> Option { + // Look for newline-delimited JSON + if let Some(pos) = chunk.iter().position(|&b| b == b'\n') { + let json_bytes = &chunk[..pos]; + if json_bytes.is_empty() { + return None; + } + serde_json::from_slice(json_bytes).ok() + } else { + // Try to parse the entire buffer as a single JSON object + serde_json::from_slice(chunk).ok() + } + } + + fn is_complete(&self) -> bool { + self.complete + } + + fn name(&self) -> &str { + &self.name + } + + fn on_stream_end(&mut self) { + self.complete = true; + } +} + +/// Server-Sent Events (SSE) streaming synapse +/// +/// Processes SSE format: `data: \n\n` +#[derive(Debug)] +pub struct SseStreamingSynapse { + name: String, + complete: bool, +} + +impl SseStreamingSynapse { + /// Create a new SSE streaming synapse + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + complete: false, + } + } +} + +/// An SSE event +#[derive(Debug, Clone)] +pub struct SseEvent { + /// Event type (from `event:` field) + pub event: Option, + /// Event data (from `data:` field) + pub data: String, + /// Event ID (from `id:` field) + pub id: Option, +} + +impl StreamingSynapse for SseStreamingSynapse { + type Chunk = SseEvent; + + fn process_chunk(&mut self, chunk: &[u8]) -> Option { + let text = std::str::from_utf8(chunk).ok()?; + + // Look for double newline which marks end of event + if let 
Some(pos) = text.find("\n\n") { + let event_text = &text[..pos]; + let mut event = None; + let mut data = String::new(); + let mut id = None; + + for line in event_text.lines() { + if let Some(value) = line.strip_prefix("event:") { + event = Some(value.trim().to_string()); + } else if let Some(value) = line.strip_prefix("data:") { + if !data.is_empty() { + data.push('\n'); + } + data.push_str(value.trim()); + } else if let Some(value) = line.strip_prefix("id:") { + id = Some(value.trim().to_string()); + } + } + + if !data.is_empty() { + return Some(SseEvent { event, data, id }); + } + } + + None + } + + fn is_complete(&self) -> bool { + self.complete + } + + fn name(&self) -> &str { + &self.name + } + + fn on_stream_end(&mut self) { + self.complete = true; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_text_streaming_synapse() { + let mut synapse = TextStreamingSynapse::new("test"); + + // Incomplete chunk (no newline) + assert!(synapse.process_chunk(b"hello").is_none()); + + // Complete chunk + assert_eq!(synapse.process_chunk(b"hello\n"), Some("hello".to_string())); + + // Multiple lines, should only return first + assert_eq!( + synapse.process_chunk(b"line1\nline2\n"), + Some("line1".to_string()) + ); + } + + #[test] + fn test_json_streaming_synapse() { + use serde::Deserialize; + + #[derive(Debug, Deserialize, PartialEq)] + struct TestData { + value: i32, + } + + let mut synapse: JsonStreamingSynapse = JsonStreamingSynapse::new("test"); + + // Valid JSON with newline + let chunk = synapse.process_chunk(br#"{"value": 42}"#.as_slice()); + assert_eq!(chunk, Some(TestData { value: 42 })); + + // NDJSON format + let chunk = synapse.process_chunk(br#"{"value": 100} +"#); + assert_eq!(chunk, Some(TestData { value: 100 })); + } + + #[test] + fn test_sse_streaming_synapse() { + let mut synapse = SseStreamingSynapse::new("test"); + + // Single data event + let chunk = synapse.process_chunk(b"data: hello world\n\n"); + assert!(chunk.is_some()); 
+ let event = chunk.unwrap(); + assert_eq!(event.data, "hello world"); + assert!(event.event.is_none()); + + // Event with type and id + let chunk = synapse.process_chunk(b"event: message\nid: 123\ndata: test data\n\n"); + assert!(chunk.is_some()); + let event = chunk.unwrap(); + assert_eq!(event.event, Some("message".to_string())); + assert_eq!(event.id, Some("123".to_string())); + assert_eq!(event.data, "test data"); + + // Incomplete event (no double newline) + assert!(synapse.process_chunk(b"data: incomplete").is_none()); + } + + #[test] + fn test_sse_multiline_data() { + let mut synapse = SseStreamingSynapse::new("test"); + + // Multi-line data + let chunk = synapse.process_chunk(b"data: line1\ndata: line2\n\n"); + assert!(chunk.is_some()); + let event = chunk.unwrap(); + assert_eq!(event.data, "line1\nline2"); + } + + #[test] + fn test_text_streaming_custom_delimiter() { + let mut synapse = TextStreamingSynapse::new("test").with_delimiter(b'|'); + + assert!(synapse.process_chunk(b"hello\n").is_none()); + assert_eq!(synapse.process_chunk(b"hello|"), Some("hello".to_string())); + } + + #[test] + fn test_streaming_synapse_completion() { + let mut synapse = TextStreamingSynapse::new("test"); + + assert!(!synapse.is_complete()); + synapse.on_stream_end(); + assert!(synapse.is_complete()); + } +} diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..0eea125 --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,1647 @@ +//! Comprehensive error types for Bittensor SDK +//! +//! This module provides error types that match the Python SDK exception hierarchy +//! for compatibility and ease of use when porting code between implementations. 
+ +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +// ============================================================================= +// Chain/Network Errors +// ============================================================================= + +/// Error when connecting to the RPC endpoint fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Chain connection error: {message}")] +pub struct ChainConnectionError { + /// Detailed error message + pub message: String, + /// The RPC URL that failed to connect + pub rpc_url: Option, +} + +impl ChainConnectionError { + /// Create a new chain connection error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + rpc_url: None, + } + } + + /// Create a new chain connection error with RPC URL + pub fn with_url(message: impl Into, rpc_url: impl Into) -> Self { + Self { + message: message.into(), + rpc_url: Some(rpc_url.into()), + } + } +} + +/// Error when querying chain storage fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Chain query error: {message}")] +pub struct ChainQueryError { + /// Detailed error message + pub message: String, + /// The storage module being queried + pub module: Option, + /// The storage entry being queried + pub entry: Option, +} + +impl ChainQueryError { + /// Create a new chain query error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + module: None, + entry: None, + } + } + + /// Create a new chain query error with module and entry info + pub fn with_storage( + message: impl Into, + module: impl Into, + entry: impl Into, + ) -> Self { + Self { + message: message.into(), + module: Some(module.into()), + entry: Some(entry.into()), + } + } +} + +/// Error when submitting an extrinsic fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Extrinsic error: {message}")] +pub struct ExtrinsicError { + /// Detailed error message + pub message: String, + /// The 
pallet/module name + pub pallet: Option, + /// The call/function name + pub call: Option, +} + +impl ExtrinsicError { + /// Create a new extrinsic error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + pallet: None, + call: None, + } + } + + /// Create a new extrinsic error with pallet and call info + pub fn with_call( + message: impl Into, + pallet: impl Into, + call: impl Into, + ) -> Self { + Self { + message: message.into(), + pallet: Some(pallet.into()), + call: Some(call.into()), + } + } +} + +/// Error when a transaction failed during execution on chain +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Transaction failed: {message}")] +pub struct TransactionFailed { + /// Detailed error message + pub message: String, + /// The transaction hash if available + pub tx_hash: Option, + /// The dispatch error from the chain + pub dispatch_error: Option, +} + +impl TransactionFailed { + /// Create a new transaction failed error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + tx_hash: None, + dispatch_error: None, + } + } + + /// Create a new transaction failed error with hash + pub fn with_hash(message: impl Into, tx_hash: impl Into) -> Self { + Self { + message: message.into(), + tx_hash: Some(tx_hash.into()), + dispatch_error: None, + } + } + + /// Create a new transaction failed error with dispatch error + pub fn with_dispatch_error( + message: impl Into, + dispatch_error: impl Into, + ) -> Self { + Self { + message: message.into(), + tx_hash: None, + dispatch_error: Some(dispatch_error.into()), + } + } +} + +/// Error when a block is not found +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Block not found: {message}")] +pub struct BlockNotFound { + /// Detailed error message + pub message: String, + /// The block hash that was not found + pub block_hash: Option, + /// The block number that was not found + pub block_number: Option, +} + +impl BlockNotFound { + /// 
Create a new block not found error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + block_hash: None, + block_number: None, + } + } + + /// Create a new block not found error with hash + pub fn with_hash(message: impl Into, block_hash: impl Into) -> Self { + Self { + message: message.into(), + block_hash: Some(block_hash.into()), + block_number: None, + } + } + + /// Create a new block not found error with number + pub fn with_number(message: impl Into, block_number: u64) -> Self { + Self { + message: message.into(), + block_hash: None, + block_number: Some(block_number), + } + } +} + +/// Error when parsing chain metadata fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Metadata error: {message}")] +pub struct MetadataError { + /// Detailed error message + pub message: String, + /// The metadata version if available + pub metadata_version: Option, +} + +impl MetadataError { + /// Create a new metadata error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + metadata_version: None, + } + } + + /// Create a new metadata error with version + pub fn with_version(message: impl Into, version: u32) -> Self { + Self { + message: message.into(), + metadata_version: Some(version), + } + } +} + +// ============================================================================= +// Wallet Errors +// ============================================================================= + +/// Generic wallet error +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Wallet error: {message}")] +pub struct WalletError { + /// Detailed error message + pub message: String, + /// The wallet name if applicable + pub wallet_name: Option, +} + +impl WalletError { + /// Create a new wallet error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + wallet_name: None, + } + } + + /// Create a new wallet error with wallet name + pub fn with_wallet(message: impl Into, 
wallet_name: impl Into) -> Self { + Self { + message: message.into(), + wallet_name: Some(wallet_name.into()), + } + } +} + +/// Error when a keyfile is not found +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Keyfile not found: {path}")] +pub struct KeyfileNotFound { + /// The path to the keyfile + pub path: String, + /// The key name (hotkey/coldkey) + pub key_name: Option, +} + +impl KeyfileNotFound { + /// Create a new keyfile not found error + pub fn new(path: impl Into) -> Self { + Self { + path: path.into(), + key_name: None, + } + } + + /// Create a new keyfile not found error with key name + pub fn with_key_name(path: impl Into, key_name: impl Into) -> Self { + Self { + path: path.into(), + key_name: Some(key_name.into()), + } + } +} + +/// Error when decrypting a keyfile fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Keyfile decryption error: {message}")] +pub struct KeyfileDecryptionError { + /// Detailed error message + pub message: String, + /// The keyfile path if available + pub path: Option, +} + +impl KeyfileDecryptionError { + /// Create a new keyfile decryption error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + path: None, + } + } + + /// Create a new keyfile decryption error with path + pub fn with_path(message: impl Into, path: impl Into) -> Self { + Self { + message: message.into(), + path: Some(path.into()), + } + } +} + +/// Error when a mnemonic phrase is invalid +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Invalid mnemonic: {message}")] +pub struct InvalidMnemonic { + /// Detailed error message + pub message: String, + /// The word count if applicable + pub word_count: Option, +} + +impl InvalidMnemonic { + /// Create a new invalid mnemonic error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + word_count: None, + } + } + + /// Create a new invalid mnemonic error with word count + pub fn 
with_word_count(message: impl Into, word_count: usize) -> Self { + Self { + message: message.into(), + word_count: Some(word_count), + } + } +} + +/// Error when a keyfile is corrupted or invalid +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Invalid keyfile: {message}")] +pub struct InvalidKeyfile { + /// Detailed error message + pub message: String, + /// The keyfile path if available + pub path: Option, +} + +impl InvalidKeyfile { + /// Create a new invalid keyfile error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + path: None, + } + } + + /// Create a new invalid keyfile error with path + pub fn with_path(message: impl Into, path: impl Into) -> Self { + Self { + message: message.into(), + path: Some(path.into()), + } + } +} + +/// Error when file permissions prevent reading/writing a keyfile +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Keyfile permission error: {message}")] +pub struct KeyfilePermissionError { + /// Detailed error message + pub message: String, + /// The keyfile path if available + pub path: Option, + /// The required permission (read/write) + pub required_permission: Option, +} + +impl KeyfilePermissionError { + /// Create a new keyfile permission error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + path: None, + required_permission: None, + } + } + + /// Create a new keyfile permission error with path + pub fn with_path(message: impl Into, path: impl Into) -> Self { + Self { + message: message.into(), + path: Some(path.into()), + required_permission: None, + } + } + + /// Create a new keyfile permission error with permission info + pub fn with_permission( + message: impl Into, + path: impl Into, + permission: impl Into, + ) -> Self { + Self { + message: message.into(), + path: Some(path.into()), + required_permission: Some(permission.into()), + } + } +} + +/// Error when a key already exists (during create operations) +#[derive(Debug, 
Error, Clone, Serialize, Deserialize)] +#[error("Key already exists: {message}")] +pub struct KeyExists { + /// Detailed error message + pub message: String, + /// The key name + pub key_name: Option, + /// The keyfile path + pub path: Option, +} + +impl KeyExists { + /// Create a new key exists error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + key_name: None, + path: None, + } + } + + /// Create a new key exists error with key name + pub fn with_key_name(message: impl Into, key_name: impl Into) -> Self { + Self { + message: message.into(), + key_name: Some(key_name.into()), + path: None, + } + } + + /// Create a new key exists error with path + pub fn with_path(message: impl Into, path: impl Into) -> Self { + Self { + message: message.into(), + key_name: None, + path: Some(path.into()), + } + } +} + +// ============================================================================= +// Registration Errors +// ============================================================================= + +/// Error when a hotkey is not registered on a subnet +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Not registered: {message}")] +pub struct NotRegistered { + /// Detailed error message + pub message: String, + /// The hotkey SS58 address + pub hotkey: Option, + /// The subnet UID + pub netuid: Option, +} + +impl NotRegistered { + /// Create a new not registered error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + hotkey: None, + netuid: None, + } + } + + /// Create a new not registered error with hotkey and netuid + pub fn with_details( + message: impl Into, + hotkey: impl Into, + netuid: u16, + ) -> Self { + Self { + message: message.into(), + hotkey: Some(hotkey.into()), + netuid: Some(netuid), + } + } +} + +/// Error when a hotkey is already registered +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Already registered: {message}")] +pub struct AlreadyRegistered { + /// 
Detailed error message + pub message: String, + /// The hotkey SS58 address + pub hotkey: Option, + /// The subnet UID + pub netuid: Option, +} + +impl AlreadyRegistered { + /// Create a new already registered error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + hotkey: None, + netuid: None, + } + } + + /// Create a new already registered error with hotkey and netuid + pub fn with_details( + message: impl Into, + hotkey: impl Into, + netuid: u16, + ) -> Self { + Self { + message: message.into(), + hotkey: Some(hotkey.into()), + netuid: Some(netuid), + } + } +} + +/// Error when registration transaction fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Registration failed: {message}")] +pub struct RegistrationFailed { + /// Detailed error message + pub message: String, + /// The subnet UID + pub netuid: Option, + /// The dispatch error if available + pub dispatch_error: Option, +} + +impl RegistrationFailed { + /// Create a new registration failed error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + netuid: None, + dispatch_error: None, + } + } + + /// Create a new registration failed error with netuid + pub fn with_netuid(message: impl Into, netuid: u16) -> Self { + Self { + message: message.into(), + netuid: Some(netuid), + dispatch_error: None, + } + } + + /// Create a new registration failed error with dispatch error + pub fn with_dispatch_error( + message: impl Into, + dispatch_error: impl Into, + ) -> Self { + Self { + message: message.into(), + netuid: None, + dispatch_error: Some(dispatch_error.into()), + } + } +} + +/// Error when PoW solution is not found in time +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("PoW failed: {message}")] +pub struct PowFailed { + /// Detailed error message + pub message: String, + /// The difficulty target + pub difficulty: Option, + /// The number of attempts made + pub attempts: Option, + /// Time spent in seconds + 
pub time_elapsed_secs: Option, +} + +impl PowFailed { + /// Create a new PoW failed error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + difficulty: None, + attempts: None, + time_elapsed_secs: None, + } + } + + /// Create a new PoW failed error with details + pub fn with_details( + message: impl Into, + difficulty: u64, + attempts: u64, + time_elapsed_secs: f64, + ) -> Self { + Self { + message: message.into(), + difficulty: Some(difficulty), + attempts: Some(attempts), + time_elapsed_secs: Some(time_elapsed_secs), + } + } +} + +// ============================================================================= +// Stake Errors +// ============================================================================= + +/// Error when there is insufficient balance for an operation +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Insufficient balance: {message}")] +pub struct InsufficientBalance { + /// Detailed error message + pub message: String, + /// The required amount in RAO + pub required: Option, + /// The available amount in RAO + pub available: Option, +} + +impl InsufficientBalance { + /// Create a new insufficient balance error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + required: None, + available: None, + } + } + + /// Create a new insufficient balance error with amounts + pub fn with_amounts(message: impl Into, required: u64, available: u64) -> Self { + Self { + message: message.into(), + required: Some(required), + available: Some(available), + } + } +} + +/// Error when there is insufficient stake to unstake +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Insufficient stake: {message}")] +pub struct InsufficientStake { + /// Detailed error message + pub message: String, + /// The requested unstake amount in RAO + pub requested: Option, + /// The current stake amount in RAO + pub current_stake: Option, +} + +impl InsufficientStake { + /// Create a new 
insufficient stake error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + requested: None, + current_stake: None, + } + } + + /// Create a new insufficient stake error with amounts + pub fn with_amounts(message: impl Into, requested: u64, current_stake: u64) -> Self { + Self { + message: message.into(), + requested: Some(requested), + current_stake: Some(current_stake), + } + } +} + +/// Error when a stake operation fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Stake failed: {message}")] +pub struct StakeFailed { + /// Detailed error message + pub message: String, + /// The amount attempted in RAO + pub amount: Option, + /// The dispatch error if available + pub dispatch_error: Option, +} + +impl StakeFailed { + /// Create a new stake failed error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + amount: None, + dispatch_error: None, + } + } + + /// Create a new stake failed error with amount + pub fn with_amount(message: impl Into, amount: u64) -> Self { + Self { + message: message.into(), + amount: Some(amount), + dispatch_error: None, + } + } + + /// Create a new stake failed error with dispatch error + pub fn with_dispatch_error( + message: impl Into, + dispatch_error: impl Into, + ) -> Self { + Self { + message: message.into(), + amount: None, + dispatch_error: Some(dispatch_error.into()), + } + } +} + +// ============================================================================= +// Weights Errors +// ============================================================================= + +/// Generic weights error +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Weights error: {message}")] +pub struct WeightsError { + /// Detailed error message + pub message: String, + /// The subnet UID + pub netuid: Option, +} + +impl WeightsError { + /// Create a new weights error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + netuid: None, 
+ } + } + + /// Create a new weights error with netuid + pub fn with_netuid(message: impl Into, netuid: u16) -> Self { + Self { + message: message.into(), + netuid: Some(netuid), + } + } +} + +/// Error when weights don't normalize properly +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Invalid weights: {message}")] +pub struct InvalidWeights { + /// Detailed error message + pub message: String, + /// The weight sum if relevant + pub weight_sum: Option, + /// Expected sum (typically 1.0 or u16::MAX) + pub expected_sum: Option, +} + +impl InvalidWeights { + /// Create a new invalid weights error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + weight_sum: None, + expected_sum: None, + } + } + + /// Create a new invalid weights error with sum info + pub fn with_sums(message: impl Into, weight_sum: f64, expected_sum: f64) -> Self { + Self { + message: message.into(), + weight_sum: Some(weight_sum), + expected_sum: Some(expected_sum), + } + } +} + +/// Error when weight version key doesn't match +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Weight version mismatch: {message}")] +pub struct WeightVersionMismatch { + /// Detailed error message + pub message: String, + /// The expected version + pub expected_version: Option, + /// The provided version + pub provided_version: Option, +} + +impl WeightVersionMismatch { + /// Create a new weight version mismatch error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + expected_version: None, + provided_version: None, + } + } + + /// Create a new weight version mismatch error with versions + pub fn with_versions(message: impl Into, expected: u64, provided: u64) -> Self { + Self { + message: message.into(), + expected_version: Some(expected), + provided_version: Some(provided), + } + } +} + +/// Error when weight count exceeds maximum allowed +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Too many weights: 
{message}")] +pub struct TooManyWeights { + /// Detailed error message + pub message: String, + /// The number of weights provided + pub count: Option, + /// The maximum allowed + pub max_allowed: Option, +} + +impl TooManyWeights { + /// Create a new too many weights error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + count: None, + max_allowed: None, + } + } + + /// Create a new too many weights error with counts + pub fn with_counts(message: impl Into, count: usize, max_allowed: usize) -> Self { + Self { + message: message.into(), + count: Some(count), + max_allowed: Some(max_allowed), + } + } +} + +// ============================================================================= +// Synapse/Communication Errors +// ============================================================================= + +/// Generic synapse error +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Synapse error: {message}")] +pub struct SynapseError { + /// Detailed error message + pub message: String, + /// The synapse name if applicable + pub synapse_name: Option, +} + +impl SynapseError { + /// Create a new synapse error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + synapse_name: None, + } + } + + /// Create a new synapse error with synapse name + pub fn with_synapse_name(message: impl Into, synapse_name: impl Into) -> Self { + Self { + message: message.into(), + synapse_name: Some(synapse_name.into()), + } + } +} + +/// Error when a synapse request times out +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Synapse timeout: {message}")] +pub struct SynapseTimeout { + /// Detailed error message + pub message: String, + /// The timeout duration in seconds + pub timeout_secs: Option, + /// The target axon endpoint + pub endpoint: Option, +} + +impl SynapseTimeout { + /// Create a new synapse timeout error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + 
timeout_secs: None, + endpoint: None, + } + } + + /// Create a new synapse timeout error with details + pub fn with_details( + message: impl Into, + timeout_secs: f64, + endpoint: impl Into, + ) -> Self { + Self { + message: message.into(), + timeout_secs: Some(timeout_secs), + endpoint: Some(endpoint.into()), + } + } +} + +/// Error when synapse signature verification fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Synapse unauthorized: {message}")] +pub struct SynapseUnauthorized { + /// Detailed error message + pub message: String, + /// The hotkey that failed verification + pub hotkey: Option, +} + +impl SynapseUnauthorized { + /// Create a new synapse unauthorized error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + hotkey: None, + } + } + + /// Create a new synapse unauthorized error with hotkey + pub fn with_hotkey(message: impl Into, hotkey: impl Into) -> Self { + Self { + message: message.into(), + hotkey: Some(hotkey.into()), + } + } +} + +/// Error when IP or hotkey is blacklisted +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Synapse blacklisted: {message}")] +pub struct SynapseBlacklisted { + /// Detailed error message + pub message: String, + /// The blacklisted IP if applicable + pub ip: Option, + /// The blacklisted hotkey if applicable + pub hotkey: Option, +} + +impl SynapseBlacklisted { + /// Create a new synapse blacklisted error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + ip: None, + hotkey: None, + } + } + + /// Create a new synapse blacklisted error with IP + pub fn with_ip(message: impl Into, ip: impl Into) -> Self { + Self { + message: message.into(), + ip: Some(ip.into()), + hotkey: None, + } + } + + /// Create a new synapse blacklisted error with hotkey + pub fn with_hotkey(message: impl Into, hotkey: impl Into) -> Self { + Self { + message: message.into(), + ip: None, + hotkey: Some(hotkey.into()), + } + } +} + +/// Error 
when serialization or deserialization fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Serialization error: {message}")] +pub struct SerializationError { + /// Detailed error message + pub message: String, + /// The type name being serialized/deserialized + pub type_name: Option, +} + +impl SerializationError { + /// Create a new serialization error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + type_name: None, + } + } + + /// Create a new serialization error with type name + pub fn with_type(message: impl Into, type_name: impl Into) -> Self { + Self { + message: message.into(), + type_name: Some(type_name.into()), + } + } +} + +// ============================================================================= +// Dendrite Errors +// ============================================================================= + +/// HTTP client error for dendrite operations +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Dendrite error: {message}")] +pub struct DendriteError { + /// Detailed error message + pub message: String, + /// The HTTP status code if applicable + pub status_code: Option, +} + +impl DendriteError { + /// Create a new dendrite error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + status_code: None, + } + } + + /// Create a new dendrite error with status code + pub fn with_status(message: impl Into, status_code: u16) -> Self { + Self { + message: message.into(), + status_code: Some(status_code), + } + } +} + +/// Error when axon endpoint is unreachable +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Axon unreachable: {message}")] +pub struct AxonUnreachable { + /// Detailed error message + pub message: String, + /// The endpoint that was unreachable + pub endpoint: Option, + /// The IP address + pub ip: Option, + /// The port + pub port: Option, +} + +impl AxonUnreachable { + /// Create a new axon unreachable error + pub fn new(message: 
impl Into) -> Self { + Self { + message: message.into(), + endpoint: None, + ip: None, + port: None, + } + } + + /// Create a new axon unreachable error with endpoint + pub fn with_endpoint(message: impl Into, endpoint: impl Into) -> Self { + Self { + message: message.into(), + endpoint: Some(endpoint.into()), + ip: None, + port: None, + } + } + + /// Create a new axon unreachable error with IP and port + pub fn with_ip_port(message: impl Into, ip: impl Into, port: u16) -> Self { + Self { + message: message.into(), + endpoint: None, + ip: Some(ip.into()), + port: Some(port), + } + } +} + +/// Error when response from axon is malformed +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Invalid response: {message}")] +pub struct InvalidResponse { + /// Detailed error message + pub message: String, + /// The raw response data if available + pub raw_response: Option, +} + +impl InvalidResponse { + /// Create a new invalid response error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + raw_response: None, + } + } + + /// Create a new invalid response error with raw response + pub fn with_raw_response(message: impl Into, raw_response: impl Into) -> Self { + Self { + message: message.into(), + raw_response: Some(raw_response.into()), + } + } +} + +// ============================================================================= +// Axon Errors +// ============================================================================= + +/// HTTP server error for axon operations +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Axon error: {message}")] +pub struct AxonError { + /// Detailed error message + pub message: String, + /// The HTTP status code if applicable + pub status_code: Option, +} + +impl AxonError { + /// Create a new axon error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + status_code: None, + } + } + + /// Create a new axon error with status code + pub fn 
with_status(message: impl Into, status_code: u16) -> Self { + Self { + message: message.into(), + status_code: Some(status_code), + } + } +} + +/// Error when axon is not running/serving +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Axon not serving: {message}")] +pub struct AxonNotServing { + /// Detailed error message + pub message: String, + /// The IP the axon should be serving on + pub ip: Option, + /// The port the axon should be serving on + pub port: Option, +} + +impl AxonNotServing { + /// Create a new axon not serving error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + ip: None, + port: None, + } + } + + /// Create a new axon not serving error with IP and port + pub fn with_ip_port(message: impl Into, ip: impl Into, port: u16) -> Self { + Self { + message: message.into(), + ip: Some(ip.into()), + port: Some(port), + } + } +} + +/// Error when axon configuration is invalid +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Axon config error: {message}")] +pub struct AxonConfigError { + /// Detailed error message + pub message: String, + /// The config field that is invalid + pub field: Option, +} + +impl AxonConfigError { + /// Create a new axon config error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + field: None, + } + } + + /// Create a new axon config error with field name + pub fn with_field(message: impl Into, field: impl Into) -> Self { + Self { + message: message.into(), + field: Some(field.into()), + } + } +} + +// ============================================================================= +// Senate/Governance Errors +// ============================================================================= + +/// Error when user is not a senate member +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Not a senate member: {message}")] +pub struct NotSenateMember { + /// Detailed error message + pub message: String, + /// The hotkey 
SS58 address + pub hotkey: Option, +} + +impl NotSenateMember { + /// Create a new not senate member error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + hotkey: None, + } + } + + /// Create a new not senate member error with hotkey + pub fn with_hotkey(message: impl Into, hotkey: impl Into) -> Self { + Self { + message: message.into(), + hotkey: Some(hotkey.into()), + } + } +} + +/// Error when user is already a senate member +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Already a senate member: {message}")] +pub struct AlreadySenateMember { + /// Detailed error message + pub message: String, + /// The hotkey SS58 address + pub hotkey: Option, +} + +impl AlreadySenateMember { + /// Create a new already senate member error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + hotkey: None, + } + } + + /// Create a new already senate member error with hotkey + pub fn with_hotkey(message: impl Into, hotkey: impl Into) -> Self { + Self { + message: message.into(), + hotkey: Some(hotkey.into()), + } + } +} + +/// Error when a vote operation fails +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Vote failed: {message}")] +pub struct VoteFailed { + /// Detailed error message + pub message: String, + /// The proposal index + pub proposal_index: Option, + /// The dispatch error if available + pub dispatch_error: Option, +} + +impl VoteFailed { + /// Create a new vote failed error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + proposal_index: None, + dispatch_error: None, + } + } + + /// Create a new vote failed error with proposal index + pub fn with_proposal(message: impl Into, proposal_index: u32) -> Self { + Self { + message: message.into(), + proposal_index: Some(proposal_index), + dispatch_error: None, + } + } + + /// Create a new vote failed error with dispatch error + pub fn with_dispatch_error( + message: impl Into, + dispatch_error: 
impl Into, + ) -> Self { + Self { + message: message.into(), + proposal_index: None, + dispatch_error: Some(dispatch_error.into()), + } + } +} + +/// Error when a proposal is not found +#[derive(Debug, Error, Clone, Serialize, Deserialize)] +#[error("Proposal not found: {message}")] +pub struct ProposalNotFound { + /// Detailed error message + pub message: String, + /// The proposal index + pub proposal_index: Option, + /// The proposal hash if applicable + pub proposal_hash: Option, +} + +impl ProposalNotFound { + /// Create a new proposal not found error + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + proposal_index: None, + proposal_hash: None, + } + } + + /// Create a new proposal not found error with index + pub fn with_index(message: impl Into, proposal_index: u32) -> Self { + Self { + message: message.into(), + proposal_index: Some(proposal_index), + proposal_hash: None, + } + } + + /// Create a new proposal not found error with hash + pub fn with_hash(message: impl Into, proposal_hash: impl Into) -> Self { + Self { + message: message.into(), + proposal_index: None, + proposal_hash: Some(proposal_hash.into()), + } + } +} + +// ============================================================================= +// Unified Error Enum +// ============================================================================= + +/// Unified error type for all Bittensor SDK operations +/// +/// This enum wraps all specific error types and provides a unified interface +/// for error handling. It matches the Python SDK exception hierarchy. 
+#[derive(Debug, Error, Serialize, Deserialize)] +pub enum BittensorError { + // Chain/Network Errors + #[error(transparent)] + ChainConnection(#[from] ChainConnectionError), + #[error(transparent)] + ChainQuery(#[from] ChainQueryError), + #[error(transparent)] + Extrinsic(#[from] ExtrinsicError), + #[error(transparent)] + TransactionFailed(#[from] TransactionFailed), + #[error(transparent)] + BlockNotFound(#[from] BlockNotFound), + #[error(transparent)] + Metadata(#[from] MetadataError), + + // Wallet Errors + #[error(transparent)] + Wallet(#[from] WalletError), + #[error(transparent)] + KeyfileNotFound(#[from] KeyfileNotFound), + #[error(transparent)] + KeyfileDecryption(#[from] KeyfileDecryptionError), + #[error(transparent)] + InvalidMnemonic(#[from] InvalidMnemonic), + #[error(transparent)] + InvalidKeyfile(#[from] InvalidKeyfile), + #[error(transparent)] + KeyfilePermission(#[from] KeyfilePermissionError), + #[error(transparent)] + KeyExists(#[from] KeyExists), + + // Registration Errors + #[error(transparent)] + NotRegistered(#[from] NotRegistered), + #[error(transparent)] + AlreadyRegistered(#[from] AlreadyRegistered), + #[error(transparent)] + RegistrationFailed(#[from] RegistrationFailed), + #[error(transparent)] + PowFailed(#[from] PowFailed), + + // Stake Errors + #[error(transparent)] + InsufficientBalance(#[from] InsufficientBalance), + #[error(transparent)] + InsufficientStake(#[from] InsufficientStake), + #[error(transparent)] + StakeFailed(#[from] StakeFailed), + + // Weights Errors + #[error(transparent)] + Weights(#[from] WeightsError), + #[error(transparent)] + InvalidWeights(#[from] InvalidWeights), + #[error(transparent)] + WeightVersionMismatch(#[from] WeightVersionMismatch), + #[error(transparent)] + TooManyWeights(#[from] TooManyWeights), + + // Synapse/Communication Errors + #[error(transparent)] + Synapse(#[from] SynapseError), + #[error(transparent)] + SynapseTimeout(#[from] SynapseTimeout), + #[error(transparent)] + 
SynapseUnauthorized(#[from] SynapseUnauthorized), + #[error(transparent)] + SynapseBlacklisted(#[from] SynapseBlacklisted), + #[error(transparent)] + Serialization(#[from] SerializationError), + + // Dendrite Errors + #[error(transparent)] + Dendrite(#[from] DendriteError), + #[error(transparent)] + AxonUnreachable(#[from] AxonUnreachable), + #[error(transparent)] + InvalidResponse(#[from] InvalidResponse), + + // Axon Errors + #[error(transparent)] + Axon(#[from] AxonError), + #[error(transparent)] + AxonNotServing(#[from] AxonNotServing), + #[error(transparent)] + AxonConfig(#[from] AxonConfigError), + + // Senate/Governance Errors + #[error(transparent)] + NotSenateMember(#[from] NotSenateMember), + #[error(transparent)] + AlreadySenateMember(#[from] AlreadySenateMember), + #[error(transparent)] + VoteFailed(#[from] VoteFailed), + #[error(transparent)] + ProposalNotFound(#[from] ProposalNotFound), + + // External library errors (converted to String for Serialize/Deserialize) + #[error("Subxt error: {0}")] + Subxt(String), + #[error("IO error: {0}")] + Io(String), + #[error("JSON error: {0}")] + Json(String), + #[error("Hex decode error: {0}")] + Hex(String), + #[error("Unknown error: {0}")] + Unknown(String), +} + +// ============================================================================= +// From implementations for external error types +// ============================================================================= + +impl From for BittensorError { + fn from(err: subxt::Error) -> Self { + BittensorError::Subxt(err.to_string()) + } +} + +impl From for BittensorError { + fn from(err: std::io::Error) -> Self { + BittensorError::Io(err.to_string()) + } +} + +impl From for BittensorError { + fn from(err: serde_json::Error) -> Self { + BittensorError::Json(err.to_string()) + } +} + +impl From for BittensorError { + fn from(err: hex::FromHexError) -> Self { + BittensorError::Hex(err.to_string()) + } +} + +// 
============================================================================= +// Convenience type alias +// ============================================================================= + +/// Result type alias for Bittensor SDK operations +pub type BittensorResult = Result; + +// ============================================================================= +// Utility functions for error construction +// ============================================================================= + +impl BittensorError { + /// Create an unknown error from any error type + pub fn unknown(err: impl std::fmt::Display) -> Self { + BittensorError::Unknown(err.to_string()) + } + + /// Check if this is a chain connection error + pub fn is_connection_error(&self) -> bool { + matches!(self, BittensorError::ChainConnection(_)) + } + + /// Check if this is an insufficient balance error + pub fn is_insufficient_balance(&self) -> bool { + matches!(self, BittensorError::InsufficientBalance(_)) + } + + /// Check if this is a not registered error + pub fn is_not_registered(&self) -> bool { + matches!(self, BittensorError::NotRegistered(_)) + } + + /// Check if this is a timeout error + pub fn is_timeout(&self) -> bool { + matches!(self, BittensorError::SynapseTimeout(_)) + } + + /// Check if this is an authorization error + pub fn is_unauthorized(&self) -> bool { + matches!( + self, + BittensorError::SynapseUnauthorized(_) | BittensorError::SynapseBlacklisted(_) + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_chain_connection_error() { + let err = ChainConnectionError::new("Failed to connect"); + assert_eq!(err.message, "Failed to connect"); + assert!(err.rpc_url.is_none()); + + let err_with_url = + ChainConnectionError::with_url("Connection refused", "wss://example.com:9944"); + assert_eq!(err_with_url.rpc_url, Some("wss://example.com:9944".to_string())); + } + + #[test] + fn test_insufficient_balance_error() { + let err = InsufficientBalance::with_amounts("Not 
enough TAO", 1000, 500); + assert_eq!(err.required, Some(1000)); + assert_eq!(err.available, Some(500)); + } + + #[test] + fn test_not_registered_error() { + let err = NotRegistered::with_details( + "Hotkey not registered", + "5FHneW46xGXgs5mUiveU4sbTyGBzmstUspZC92UhjJM694ty", + 1, + ); + assert_eq!(err.netuid, Some(1)); + assert!(err.hotkey.is_some()); + } + + #[test] + fn test_bittensor_error_from_chain_connection() { + let err = ChainConnectionError::new("Connection failed"); + let bt_err: BittensorError = err.into(); + assert!(bt_err.is_connection_error()); + } + + #[test] + fn test_bittensor_error_from_subxt() { + // Use a simple approach to verify the conversion works + let bt_err = BittensorError::Subxt("test subxt error".to_string()); + assert!(matches!(bt_err, BittensorError::Subxt(_))); + } + + #[test] + fn test_bittensor_error_from_io() { + let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found"); + let bt_err: BittensorError = io_err.into(); + assert!(matches!(bt_err, BittensorError::Io(_))); + } + + #[test] + fn test_bittensor_error_helper_methods() { + let balance_err = BittensorError::InsufficientBalance(InsufficientBalance::new("test")); + assert!(balance_err.is_insufficient_balance()); + assert!(!balance_err.is_connection_error()); + + let reg_err = BittensorError::NotRegistered(NotRegistered::new("test")); + assert!(reg_err.is_not_registered()); + + let timeout_err = BittensorError::SynapseTimeout(SynapseTimeout::new("test")); + assert!(timeout_err.is_timeout()); + + let unauth_err = BittensorError::SynapseUnauthorized(SynapseUnauthorized::new("test")); + assert!(unauth_err.is_unauthorized()); + } + + #[test] + fn test_error_serialization() { + let err = ChainQueryError::with_storage("Query failed", "SubtensorModule", "TotalStake"); + let serialized = serde_json::to_string(&err).expect("Should serialize"); + let deserialized: ChainQueryError = + serde_json::from_str(&serialized).expect("Should deserialize"); + 
assert_eq!(err.message, deserialized.message); + assert_eq!(err.module, deserialized.module); + assert_eq!(err.entry, deserialized.entry); + } + + #[test] + fn test_bittensor_error_serialization() { + let err = BittensorError::ChainConnection(ChainConnectionError::new("test")); + let serialized = serde_json::to_string(&err).expect("Should serialize"); + let deserialized: BittensorError = + serde_json::from_str(&serialized).expect("Should deserialize"); + assert!(deserialized.is_connection_error()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 4a5e468..995f5e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,21 +1,32 @@ -#![allow(dead_code, unused_variables, unused_imports)] +pub mod axon; pub mod blocks; pub mod chain; +pub mod cli; pub mod config; pub mod core; pub mod crv4; +pub mod dendrite; +pub mod errors; +pub mod logging; pub mod metagraph; pub mod queries; pub mod subtensor; pub mod types; pub mod utils; pub mod validator; +pub mod wallet; pub use chain::{BittensorClient, Error as ChainError}; -pub use config::{AxonConfig, Config, LoggingConfig, SubtensorConfig}; +pub use config::{AxonConfig, Config, LoggingConfig as ConfigLoggingConfig, SubtensorConfig}; pub use metagraph::{sync_metagraph, Metagraph}; +// Re-export logging module +pub use logging::{ + init_default_logging, init_logging, is_initialized, BittensorFormatter, CompactFormatter, + JsonFormatter, LogFormat, LoggingConfig, +}; + // Re-export types first (includes liquidity types) pub use types::*; @@ -85,3 +96,47 @@ pub use subtensor::{ pub use validator::mechanism::{ commit_mechanism_weights, reveal_mechanism_weights, set_mechanism_weights, }; + +// Re-export Dendrite HTTP client +pub use dendrite::{ + Dendrite, DendriteRequest, DendriteResponse, StreamingResponse, StreamingSynapse, +}; + +// Re-export Axon HTTP server +pub use axon::{ + Axon, AxonConfig as AxonServerConfig, AxonState, HandlerContext, RequestPriority, + VerifiedRequest, AXON_VERSION, +}; + +// Re-export wallet module for key 
management +pub use wallet::{ + default_wallet_path, is_legacy_format, list_wallets, list_wallets_at, migrate_legacy_keyfile, + wallet_path, Keyfile, KeyfileData, KeyfileError, Keypair, KeypairError, Mnemonic, + MnemonicError, Wallet, WalletError as WalletModuleError, BITTENSOR_SS58_FORMAT, KEYFILE_VERSION, +}; + +// Re-export comprehensive error types +pub use errors::{ + // Unified error type and result alias + BittensorError, BittensorResult, + // Chain/Network Errors + BlockNotFound, ChainConnectionError, ChainQueryError, ExtrinsicError, MetadataError, + TransactionFailed, + // Wallet Errors + InvalidKeyfile, InvalidMnemonic, KeyExists, KeyfileDecryptionError, KeyfileNotFound, + KeyfilePermissionError, WalletError, + // Registration Errors + AlreadyRegistered, NotRegistered, PowFailed, RegistrationFailed, + // Stake Errors + InsufficientBalance, InsufficientStake, StakeFailed, + // Weights Errors + InvalidWeights, TooManyWeights, WeightVersionMismatch, WeightsError, + // Synapse/Communication Errors + SerializationError, SynapseBlacklisted, SynapseError, SynapseTimeout, SynapseUnauthorized, + // Dendrite Errors + AxonUnreachable, DendriteError, InvalidResponse, + // Axon Errors + AxonConfigError, AxonError, AxonNotServing, + // Senate/Governance Errors + AlreadySenateMember, NotSenateMember, ProposalNotFound, VoteFailed, +}; diff --git a/src/logging/format.rs b/src/logging/format.rs new file mode 100644 index 0000000..65d78dc --- /dev/null +++ b/src/logging/format.rs @@ -0,0 +1,255 @@ +//! Custom log formatters for Bittensor SDK +//! +//! Provides log formatting that matches Python SDK output style. + +use std::fmt; +use tracing::{Event, Level, Subscriber}; +use tracing_subscriber::fmt::format::{self, FormatEvent, FormatFields}; +use tracing_subscriber::fmt::FmtContext; +use tracing_subscriber::registry::LookupSpan; + +/// Bittensor-style log formatter matching Python SDK output format. 
+/// +/// Output format: `YYYY-MM-DD HH:MM:SS | LEVEL | target | message` +/// +/// # Example Output +/// ```text +/// 2024-01-15 10:30:45 | INFO | bittensor::subtensor | Connected to network +/// 2024-01-15 10:30:46 | DEBUG | bittensor::metagraph | Syncing metagraph for netuid 1 +/// ``` +pub struct BittensorFormatter; + +impl FormatEvent for BittensorFormatter +where + S: Subscriber + for<'a> LookupSpan<'a>, + N: for<'a> FormatFields<'a> + 'static, +{ + fn format_event( + &self, + ctx: &FmtContext<'_, S, N>, + mut writer: format::Writer<'_>, + event: &Event<'_>, + ) -> fmt::Result { + // Format: YYYY-MM-DD HH:MM:SS | LEVEL | target | message + let now = chrono::Local::now(); + let level = event.metadata().level(); + let target = event.metadata().target(); + + // Use colored level representation for better readability + let level_str = format_level(*level); + + write!( + writer, + "{} | {} | {} | ", + now.format("%Y-%m-%d %H:%M:%S"), + level_str, + target + )?; + + ctx.field_format().format_fields(writer.by_ref(), event)?; + writeln!(writer) + } +} + +/// Format log level with fixed width for alignment +fn format_level(level: Level) -> &'static str { + match level { + Level::TRACE => "TRACE", + Level::DEBUG => "DEBUG", + Level::INFO => "INFO ", + Level::WARN => "WARN ", + Level::ERROR => "ERROR", + } +} + +/// JSON log formatter for structured logging. +/// +/// Produces newline-delimited JSON (NDJSON) suitable for log aggregation systems. 
+/// +/// # Example Output +/// ```json +/// {"timestamp":"2024-01-15T10:30:45.123456Z","level":"INFO","target":"bittensor::subtensor","message":"Connected to network"} +/// ``` +pub struct JsonFormatter; + +impl FormatEvent for JsonFormatter +where + S: Subscriber + for<'a> LookupSpan<'a>, + N: for<'a> FormatFields<'a> + 'static, +{ + fn format_event( + &self, + _ctx: &FmtContext<'_, S, N>, + mut writer: format::Writer<'_>, + event: &Event<'_>, + ) -> fmt::Result { + let now = chrono::Utc::now(); + let level = event.metadata().level(); + let target = event.metadata().target(); + + write!( + writer, + "{{\"timestamp\":\"{}\",\"level\":\"{}\",\"target\":\"{}\",\"message\":\"", + now.format("%Y-%m-%dT%H:%M:%S%.6fZ"), + level, + escape_json_string(target) + )?; + + // Capture fields into a string buffer for JSON escaping + let mut field_visitor = JsonFieldVisitor::new(); + event.record(&mut field_visitor); + + write!(writer, "{}", escape_json_string(&field_visitor.message))?; + write!(writer, "\"")?; + + // Add additional fields if present + if !field_visitor.fields.is_empty() { + for (key, value) in &field_visitor.fields { + write!( + writer, + ",\"{}\":\"{}\"", + escape_json_string(key), + escape_json_string(value) + )?; + } + } + + writeln!(writer, "}}") + } +} + +/// Visitor to extract fields from tracing events for JSON formatting +struct JsonFieldVisitor { + message: String, + fields: Vec<(String, String)>, +} + +impl JsonFieldVisitor { + fn new() -> Self { + Self { + message: String::new(), + fields: Vec::new(), + } + } +} + +impl tracing::field::Visit for JsonFieldVisitor { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn fmt::Debug) { + if field.name() == "message" { + self.message = format!("{:?}", value); + // Remove surrounding quotes if present + if self.message.starts_with('"') && self.message.ends_with('"') { + self.message = self.message[1..self.message.len() - 1].to_string(); + } + } else { + self.fields + 
.push((field.name().to_string(), format!("{:?}", value))); + } + } + + fn record_str(&mut self, field: &tracing::field::Field, value: &str) { + if field.name() == "message" { + self.message = value.to_string(); + } else { + self.fields + .push((field.name().to_string(), value.to_string())); + } + } + + fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { + self.fields + .push((field.name().to_string(), value.to_string())); + } + + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + self.fields + .push((field.name().to_string(), value.to_string())); + } + + fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { + self.fields + .push((field.name().to_string(), value.to_string())); + } +} + +/// Escape special characters for JSON string values +fn escape_json_string(s: &str) -> String { + let mut result = String::with_capacity(s.len()); + for c in s.chars() { + match c { + '"' => result.push_str("\\\""), + '\\' => result.push_str("\\\\"), + '\n' => result.push_str("\\n"), + '\r' => result.push_str("\\r"), + '\t' => result.push_str("\\t"), + c if c.is_control() => { + result.push_str(&format!("\\u{:04x}", c as u32)); + } + c => result.push(c), + } + } + result +} + +/// Compact log formatter with minimal output. +/// +/// Output format: `[LEVEL] message` +/// +/// Useful for development and quick debugging where timestamps aren't needed. 
+/// +/// # Example Output +/// ```text +/// [INFO] Connected to network +/// [DEBUG] Syncing metagraph +/// ``` +pub struct CompactFormatter; + +impl FormatEvent for CompactFormatter +where + S: Subscriber + for<'a> LookupSpan<'a>, + N: for<'a> FormatFields<'a> + 'static, +{ + fn format_event( + &self, + ctx: &FmtContext<'_, S, N>, + mut writer: format::Writer<'_>, + event: &Event<'_>, + ) -> fmt::Result { + let level = event.metadata().level(); + + write!(writer, "[{}] ", format_level(*level).trim())?; + + ctx.field_format().format_fields(writer.by_ref(), event)?; + writeln!(writer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_level() { + assert_eq!(format_level(Level::TRACE), "TRACE"); + assert_eq!(format_level(Level::DEBUG), "DEBUG"); + assert_eq!(format_level(Level::INFO), "INFO "); + assert_eq!(format_level(Level::WARN), "WARN "); + assert_eq!(format_level(Level::ERROR), "ERROR"); + } + + #[test] + fn test_escape_json_string() { + assert_eq!(escape_json_string("hello"), "hello"); + assert_eq!(escape_json_string("hello\"world"), "hello\\\"world"); + assert_eq!(escape_json_string("line1\nline2"), "line1\\nline2"); + assert_eq!(escape_json_string("path\\to\\file"), "path\\\\to\\\\file"); + assert_eq!(escape_json_string("tab\there"), "tab\\there"); + } + + #[test] + fn test_json_field_visitor() { + let visitor = JsonFieldVisitor::new(); + assert!(visitor.message.is_empty()); + assert!(visitor.fields.is_empty()); + } +} diff --git a/src/logging/mod.rs b/src/logging/mod.rs new file mode 100644 index 0000000..7029543 --- /dev/null +++ b/src/logging/mod.rs @@ -0,0 +1,536 @@ +//! Logging system for Bittensor SDK +//! +//! Provides structured logging matching the Python SDK format with support for +//! multiple output formats (text, JSON, compact) and file logging. +//! +//! # Quick Start +//! +//! ```rust,no_run +//! use bittensor_rs::logging::{init_default_logging, init_logging, LoggingConfig, LogFormat}; +//! +//! 
// Initialize with defaults (INFO level, text format) +//! init_default_logging(); +//! +//! // Or configure logging explicitly +//! let config = LoggingConfig { +//! debug: true, +//! format: LogFormat::Json, +//! ..Default::default() +//! }; +//! init_logging(&config); +//! ``` +//! +//! # Logging Macros +//! +//! Use the provided macros for consistent logging: +//! +//! ```rust,ignore +//! bt_info!("Connected to network"); +//! bt_debug!(netuid = 1, "Syncing metagraph"); +//! bt_warn!("Connection unstable"); +//! bt_error!(error = %e, "Failed to submit extrinsic"); +//! ``` + +pub mod format; + +use std::io; +use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Once; + +use tracing::Level; +use tracing_appender::non_blocking::WorkerGuard; + +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{fmt, EnvFilter}; + +pub use format::{BittensorFormatter, CompactFormatter, JsonFormatter}; + +/// Static initialization guard to ensure logging is only initialized once +static INIT: Once = Once::new(); + +/// Flag indicating whether logging has been initialized +static INITIALIZED: AtomicBool = AtomicBool::new(false); + +/// Guard for non-blocking file writer (must be kept alive for duration of program) +static mut FILE_GUARD: Option = None; + +/// Log output format +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum LogFormat { + /// Human-readable text format with timestamps + /// Format: `YYYY-MM-DD HH:MM:SS | LEVEL | target | message` + #[default] + Text, + /// JSON format for structured logging and log aggregation + Json, + /// Compact format for development: `[LEVEL] message` + Compact, +} + +impl std::fmt::Display for LogFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LogFormat::Text => write!(f, "text"), + LogFormat::Json => write!(f, "json"), + LogFormat::Compact => write!(f, "compact"), + } + } +} + 
+impl std::str::FromStr for LogFormat { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "text" => Ok(LogFormat::Text), + "json" => Ok(LogFormat::Json), + "compact" => Ok(LogFormat::Compact), + _ => Err(format!( + "Invalid log format '{}'. Valid options: text, json, compact", + s + )), + } + } +} + +/// Logging configuration for the Bittensor SDK +/// +/// This configuration extends the basic `config::LoggingConfig` with format options +/// and provides the initialization logic for the tracing subscriber. +#[derive(Debug, Clone)] +pub struct LoggingConfig { + /// Enable debug-level logging (sets minimum level to DEBUG) + pub debug: bool, + /// Enable trace-level logging (sets minimum level to TRACE, overrides debug) + pub trace: bool, + /// Enable logging to file in addition to stdout + pub record_log: bool, + /// Directory for log files (supports ~ for home directory) + pub logging_dir: String, + /// Output format for log messages + pub format: LogFormat, +} + +impl Default for LoggingConfig { + fn default() -> Self { + Self { + debug: false, + trace: false, + record_log: false, + logging_dir: "~/.bittensor/logs".to_string(), + format: LogFormat::Text, + } + } +} + +impl LoggingConfig { + /// Create a new LoggingConfig with default values + pub fn new() -> Self { + Self::default() + } + + /// Enable debug logging + pub fn with_debug(mut self, debug: bool) -> Self { + self.debug = debug; + self + } + + /// Enable trace logging + pub fn with_trace(mut self, trace: bool) -> Self { + self.trace = trace; + self + } + + /// Enable file logging + pub fn with_file_logging(mut self, enabled: bool) -> Self { + self.record_log = enabled; + self + } + + /// Set the logging directory + pub fn with_logging_dir(mut self, dir: impl Into) -> Self { + self.logging_dir = dir.into(); + self + } + + /// Set the log format + pub fn with_format(mut self, format: LogFormat) -> Self { + self.format = format; + self + } + + /// Load 
configuration from environment variables + /// + /// Supported environment variables: + /// - `BITTENSOR_LOG_LEVEL`: Set log level (trace, debug, info, warn, error) + /// - `BITTENSOR_LOG_FORMAT`: Set format (text, json, compact) + /// - `BITTENSOR_LOG_DIR`: Set logging directory + /// - `BITTENSOR_DEBUG`: Enable debug mode (any value) + /// - `RUST_LOG`: Standard tracing filter (takes precedence if set) + pub fn from_env() -> Self { + let mut config = Self::default(); + + if std::env::var("BITTENSOR_DEBUG").is_ok() || std::env::var("BITTENSOR_TRACE").is_ok() { + config.debug = true; + } + + if std::env::var("BITTENSOR_TRACE").is_ok() { + config.trace = true; + } + + if let Ok(format) = std::env::var("BITTENSOR_LOG_FORMAT") { + if let Ok(f) = format.parse() { + config.format = f; + } + } + + if let Ok(dir) = std::env::var("BITTENSOR_LOG_DIR") { + config.logging_dir = dir; + config.record_log = true; + } + + config + } + + /// Get the effective log level based on configuration + fn get_level(&self) -> Level { + if self.trace { + Level::TRACE + } else if self.debug { + Level::DEBUG + } else { + Level::INFO + } + } + + /// Expand ~ to home directory in paths + fn expand_path(&self) -> PathBuf { + let path = &self.logging_dir; + if let Some(stripped) = path.strip_prefix("~/") { + if let Some(home) = dirs::home_dir() { + return home.join(stripped); + } + } + PathBuf::from(path) + } +} + +/// Initialize the logging system with the given configuration. +/// +/// This function can only be called once; subsequent calls will be ignored. +/// The logging system uses the `tracing` crate internally and integrates with +/// the standard Rust logging ecosystem. 
+/// +/// # Arguments +/// +/// * `config` - The logging configuration specifying level, format, and output options +/// +/// # Example +/// +/// ```rust,no_run +/// use bittensor_rs::logging::{init_logging, LoggingConfig, LogFormat}; +/// +/// let config = LoggingConfig { +/// debug: true, +/// format: LogFormat::Text, +/// ..Default::default() +/// }; +/// init_logging(&config); +/// ``` +pub fn init_logging(config: &LoggingConfig) { + INIT.call_once(|| { + init_logging_internal(config); + INITIALIZED.store(true, Ordering::SeqCst); + }); +} + +/// Initialize logging with default configuration (INFO level, text format). +/// +/// This is a convenience function for quick setup. For production use, +/// consider using `init_logging` with explicit configuration. +/// +/// # Example +/// +/// ```rust,no_run +/// use bittensor_rs::logging::init_default_logging; +/// +/// init_default_logging(); +/// ``` +pub fn init_default_logging() { + init_logging(&LoggingConfig::default()); +} + +/// Check if logging has been initialized +pub fn is_initialized() -> bool { + INITIALIZED.load(Ordering::SeqCst) +} + +/// Internal initialization logic +fn init_logging_internal(config: &LoggingConfig) { + // Build environment filter + // Allow RUST_LOG to override, otherwise use config-based level + let env_filter = if std::env::var("RUST_LOG").is_ok() { + EnvFilter::from_default_env() + } else { + let level = config.get_level(); + EnvFilter::new(format!("{},hyper=warn,reqwest=warn,h2=warn", level)) + }; + + // Setup file appender if configured + let file_appender = if config.record_log { + let log_dir = config.expand_path(); + if let Err(e) = std::fs::create_dir_all(&log_dir) { + eprintln!( + "Warning: Failed to create log directory {:?}: {}", + log_dir, e + ); + None + } else { + let file_appender = tracing_appender::rolling::daily(&log_dir, "bittensor.log"); + let (non_blocking, guard) = tracing_appender::non_blocking(file_appender); + // Store the guard to keep the writer alive + // 
SAFETY: This is only called once via Once::call_once, so there's no race condition + unsafe { + FILE_GUARD = Some(guard); + } + Some(non_blocking) + } + } else { + None + }; + + // Build and initialize the subscriber based on format + match config.format { + LogFormat::Text => { + if let Some(file_writer) = file_appender { + let subscriber = tracing_subscriber::registry() + .with(env_filter) + .with( + fmt::layer() + .event_format(BittensorFormatter) + .with_writer(io::stdout), + ) + .with( + fmt::layer() + .event_format(BittensorFormatter) + .with_writer(file_writer) + .with_ansi(false), + ); + subscriber.init(); + } else { + let subscriber = tracing_subscriber::registry().with(env_filter).with( + fmt::layer() + .event_format(BittensorFormatter) + .with_writer(io::stdout), + ); + subscriber.init(); + } + } + LogFormat::Json => { + if let Some(file_writer) = file_appender { + let subscriber = tracing_subscriber::registry() + .with(env_filter) + .with(fmt::layer().json().with_writer(io::stdout)) + .with( + fmt::layer() + .json() + .with_writer(file_writer) + .with_ansi(false), + ); + subscriber.init(); + } else { + let subscriber = tracing_subscriber::registry() + .with(env_filter) + .with(fmt::layer().json().with_writer(io::stdout)); + subscriber.init(); + } + } + LogFormat::Compact => { + if let Some(file_writer) = file_appender { + let subscriber = tracing_subscriber::registry() + .with(env_filter) + .with( + fmt::layer() + .event_format(CompactFormatter) + .with_writer(io::stdout), + ) + .with( + fmt::layer() + .event_format(CompactFormatter) + .with_writer(file_writer) + .with_ansi(false), + ); + subscriber.init(); + } else { + let subscriber = tracing_subscriber::registry().with(env_filter).with( + fmt::layer() + .event_format(CompactFormatter) + .with_writer(io::stdout), + ); + subscriber.init(); + } + } + } +} + +/// Log a debug message. +/// +/// This macro wraps `tracing::debug!` for consistent usage across the SDK. 
+/// +/// # Example +/// +/// ```rust,ignore +/// bt_debug!("Processing request"); +/// bt_debug!(netuid = 1, hotkey = %hotkey, "Syncing neuron"); +/// ``` +#[macro_export] +macro_rules! bt_debug { + ($($arg:tt)*) => { + tracing::debug!($($arg)*) + }; +} + +/// Log an info message. +/// +/// This macro wraps `tracing::info!` for consistent usage across the SDK. +/// +/// # Example +/// +/// ```rust,ignore +/// bt_info!("Connected to network"); +/// bt_info!(network = "finney", "Initialized subtensor"); +/// ``` +#[macro_export] +macro_rules! bt_info { + ($($arg:tt)*) => { + tracing::info!($($arg)*) + }; +} + +/// Log a warning message. +/// +/// This macro wraps `tracing::warn!` for consistent usage across the SDK. +/// +/// # Example +/// +/// ```rust,ignore +/// bt_warn!("Connection unstable"); +/// bt_warn!(retries = 3, "Request failed, retrying"); +/// ``` +#[macro_export] +macro_rules! bt_warn { + ($($arg:tt)*) => { + tracing::warn!($($arg)*) + }; +} + +/// Log an error message. +/// +/// This macro wraps `tracing::error!` for consistent usage across the SDK. +/// +/// # Example +/// +/// ```rust,ignore +/// bt_error!("Failed to connect"); +/// bt_error!(error = %e, "Transaction failed"); +/// ``` +#[macro_export] +macro_rules! bt_error { + ($($arg:tt)*) => { + tracing::error!($($arg)*) + }; +} + +/// Log a trace message. +/// +/// This macro wraps `tracing::trace!` for consistent usage across the SDK. +/// Trace-level messages are only emitted when trace logging is enabled. +/// +/// # Example +/// +/// ```rust,ignore +/// bt_trace!("Entering function"); +/// bt_trace!(bytes = data.len(), "Received data"); +/// ``` +#[macro_export] +macro_rules! 
bt_trace { + ($($arg:tt)*) => { + tracing::trace!($($arg)*) + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_logging_config_default() { + let config = LoggingConfig::default(); + assert!(!config.debug); + assert!(!config.trace); + assert!(!config.record_log); + assert_eq!(config.logging_dir, "~/.bittensor/logs"); + assert_eq!(config.format, LogFormat::Text); + } + + #[test] + fn test_logging_config_builder() { + let config = LoggingConfig::new() + .with_debug(true) + .with_format(LogFormat::Json) + .with_logging_dir("/tmp/logs"); + + assert!(config.debug); + assert_eq!(config.format, LogFormat::Json); + assert_eq!(config.logging_dir, "/tmp/logs"); + } + + #[test] + fn test_log_format_display() { + assert_eq!(format!("{}", LogFormat::Text), "text"); + assert_eq!(format!("{}", LogFormat::Json), "json"); + assert_eq!(format!("{}", LogFormat::Compact), "compact"); + } + + #[test] + fn test_log_format_from_str() { + assert_eq!("text".parse::().unwrap(), LogFormat::Text); + assert_eq!("json".parse::().unwrap(), LogFormat::Json); + assert_eq!("compact".parse::().unwrap(), LogFormat::Compact); + assert_eq!("TEXT".parse::().unwrap(), LogFormat::Text); + assert!("invalid".parse::().is_err()); + } + + #[test] + fn test_get_level() { + let config = LoggingConfig::default(); + assert_eq!(config.get_level(), Level::INFO); + + let config = LoggingConfig::default().with_debug(true); + assert_eq!(config.get_level(), Level::DEBUG); + + let config = LoggingConfig::default().with_trace(true); + assert_eq!(config.get_level(), Level::TRACE); + + // trace takes precedence over debug + let config = LoggingConfig::default().with_debug(true).with_trace(true); + assert_eq!(config.get_level(), Level::TRACE); + } + + #[test] + fn test_expand_path_tilde() { + let config = LoggingConfig::default(); + let path = config.expand_path(); + // Should have expanded ~ to home dir + assert!(!path.to_string_lossy().starts_with('~')); + } + + #[test] + fn 
test_expand_path_absolute() { + let config = LoggingConfig::default().with_logging_dir("/var/log/bittensor"); + let path = config.expand_path(); + assert_eq!(path.to_string_lossy(), "/var/log/bittensor"); + } +} diff --git a/src/queries/associated_ips.rs b/src/queries/associated_ips.rs new file mode 100644 index 0000000..98d8bcb --- /dev/null +++ b/src/queries/associated_ips.rs @@ -0,0 +1,222 @@ +//! Associated IP address queries +//! +//! This module provides functions to query IP addresses associated with hotkeys +//! on the Bittensor network. + +use crate::chain::BittensorClient; +use crate::errors::{BittensorError, BittensorResult, ChainQueryError}; +use crate::utils::decoders::utils::{extract_u128, extract_u8, parse_ip_addr}; +use parity_scale_codec::Encode; +use sp_core::crypto::AccountId32; +use std::net::IpAddr; +use subxt::dynamic::Value; + +const SUBTENSOR_MODULE: &str = "SubtensorModule"; + +/// IP information associated with a hotkey +#[derive(Debug, Clone)] +pub struct IpInfo { + /// The IP address + pub ip: IpAddr, + /// IP type: 4 for IPv4, 6 for IPv6 + pub ip_type: u8, + /// Protocol identifier + pub protocol: u8, +} + +impl IpInfo { + /// Create a new IpInfo instance + pub fn new(ip: IpAddr, ip_type: u8, protocol: u8) -> Self { + Self { + ip, + ip_type, + protocol, + } + } + + /// Create IpInfo from raw chain data + pub fn from_chain_data(ip_u128: u128, ip_type: u8, protocol: u8) -> Self { + let ip = parse_ip_addr(ip_u128, ip_type); + Self { + ip, + ip_type, + protocol, + } + } + + /// Check if this is an IPv4 address + pub fn is_ipv4(&self) -> bool { + self.ip_type == 4 + } + + /// Check if this is an IPv6 address + pub fn is_ipv6(&self) -> bool { + self.ip_type == 6 + } +} + +/// Decode IpInfo from a Value +/// Chain stores: { ip: u128, ip_type: u8, protocol: u8 } +#[allow(dead_code)] +fn decode_ip_info(value: &Value) -> Option { + let s = format!("{:?}", value); + + // Extract ip (u128), ip_type (u8), protocol (u8) + let ip_u128 = 
extract_u128(&s, 0)?; + let ip_type = extract_u8(&s, ip_u128.1)?; + let protocol = extract_u8(&s, ip_type.1)?; + + Some(IpInfo::from_chain_data(ip_u128.0, ip_type.0, protocol.0)) +} + +/// Get associated IPs for a hotkey +/// +/// Queries the AssociatedIps storage map which stores a list of IP addresses +/// associated with a given hotkey. +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `hotkey` - The hotkey account to query IPs for +/// +/// # Returns +/// A vector of IpInfo structures containing the associated IP addresses +pub async fn get_associated_ips( + client: &BittensorClient, + hotkey: &AccountId32, +) -> BittensorResult> { + let keys = vec![Value::from_bytes(hotkey.encode())]; + + match client + .storage_with_keys(SUBTENSOR_MODULE, "AssociatedIps", keys) + .await + { + Ok(Some(val)) => { + let ips = decode_ip_info_vec(&val); + Ok(ips) + } + Ok(None) => Ok(Vec::new()), + Err(e) => Err(BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query AssociatedIps: {}", e), + SUBTENSOR_MODULE, + "AssociatedIps", + ))), + } +} + +/// Decode a vector of IpInfo from a Value +fn decode_ip_info_vec(value: &Value) -> Vec { + let s = format!("{:?}", value); + let mut result = Vec::new(); + let mut pos = 0; + + // The storage is Vec, so we need to iterate through + // Each IpInfo contains: ip (u128), ip_type (u8), protocol (u8) + loop { + if let Some(ip_u128) = extract_u128(&s, pos) { + if let Some(ip_type) = extract_u8(&s, ip_u128.1) { + if let Some(protocol) = extract_u8(&s, ip_type.1) { + let ip_info = IpInfo::from_chain_data(ip_u128.0, ip_type.0, protocol.0); + result.push(ip_info); + pos = protocol.1; + continue; + } + } + } + break; + } + + result +} + +/// Get the number of associated IPs for a hotkey +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `hotkey` - The hotkey account to query +/// +/// # Returns +/// The count of associated IP addresses +pub async fn get_associated_ip_count( + client: 
&BittensorClient, + hotkey: &AccountId32, +) -> BittensorResult { + let ips = get_associated_ips(client, hotkey).await?; + Ok(ips.len()) +} + +/// Check if a hotkey has any associated IPs +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `hotkey` - The hotkey account to check +/// +/// # Returns +/// true if the hotkey has at least one associated IP +pub async fn has_associated_ips( + client: &BittensorClient, + hotkey: &AccountId32, +) -> BittensorResult { + let ips = get_associated_ips(client, hotkey).await?; + Ok(!ips.is_empty()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + + #[test] + fn test_ip_info_from_chain_data_ipv4() { + // IPv4 address 192.168.1.1 in u128 representation + let ip_u128: u128 = (192u128 << 24) | (168u128 << 16) | (1u128 << 8) | 1u128; + let ip_info = IpInfo::from_chain_data(ip_u128, 4, 0); + + assert!(ip_info.is_ipv4()); + assert!(!ip_info.is_ipv6()); + assert_eq!(ip_info.ip_type, 4); + assert_eq!(ip_info.protocol, 0); + assert!(matches!(ip_info.ip, IpAddr::V4(_))); + } + + #[test] + fn test_ip_info_from_chain_data_ipv6() { + // IPv6 ::1 in u128 + let ip_u128: u128 = 1; + let ip_info = IpInfo::from_chain_data(ip_u128, 6, 0); + + assert!(!ip_info.is_ipv4()); + assert!(ip_info.is_ipv6()); + assert_eq!(ip_info.ip_type, 6); + assert_eq!(ip_info.protocol, 0); + assert!(matches!(ip_info.ip, IpAddr::V6(_))); + } + + #[test] + fn test_ip_info_new() { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); + let ip_info = IpInfo::new(ip, 4, 1); + + assert_eq!(ip_info.ip, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); + assert_eq!(ip_info.ip_type, 4); + assert_eq!(ip_info.protocol, 1); + } + + #[test] + fn test_ip_info_clone() { + let ip_info = IpInfo::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 4, 0); + let cloned = ip_info.clone(); + + assert_eq!(cloned.ip, ip_info.ip); + assert_eq!(cloned.ip_type, ip_info.ip_type); + assert_eq!(cloned.protocol, ip_info.protocol); + } + + #[test] + fn 
test_ip_info_debug() { + let ip_info = IpInfo::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 6, 0); + let debug_str = format!("{:?}", ip_info); + + assert!(debug_str.contains("IpInfo")); + assert!(debug_str.contains("::1")); + } +} diff --git a/src/queries/commitments.rs b/src/queries/commitments.rs index 56e655f..c79fa5a 100644 --- a/src/queries/commitments.rs +++ b/src/queries/commitments.rs @@ -1,4 +1,5 @@ use crate::chain::BittensorClient; +use crate::errors::{BittensorError, BittensorResult, ChainQueryError}; use anyhow::Result; use parity_scale_codec::Encode; use sp_core::crypto::AccountId32; @@ -7,6 +8,33 @@ use subxt::dynamic::Value; const SUBTENSOR_MODULE: &str = "SubtensorModule"; const COMMITMENTS_PALLET: &str = "Commitments"; +/// Weight commitment information stored on chain +#[derive(Debug, Clone)] +pub struct WeightCommitInfo { + /// The block number when the commitment was made + pub block: u64, + /// The committed data (typically hash of weights) + pub commit_hash: Vec, + /// The reveal round number + pub reveal_round: u64, +} + +impl WeightCommitInfo { + /// Create a new WeightCommitInfo + pub fn new(block: u64, commit_hash: Vec, reveal_round: u64) -> Self { + Self { + block, + commit_hash, + reveal_round, + } + } + + /// Get the commit hash as a hex string + pub fn commit_hash_hex(&self) -> String { + hex::encode(&self.commit_hash) + } +} + /// Get commitment: SubtensorModule.Commits[(netuid, block, uid)] -> bytes pub async fn get_commitment( client: &BittensorClient, @@ -389,3 +417,254 @@ fn extract_first_account_from_str(s: &str) -> Option { } None } + +/// Get weight commitment for a hotkey on a subnet +/// +/// Queries the CRV3WeightCommits storage to get the weight commitment +/// information for a specific hotkey on a subnet. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `netuid` - The subnet ID +/// * `hotkey` - The hotkey account to query +/// +/// # Returns +/// The WeightCommitInfo if found, None otherwise +pub async fn get_weight_commitment( + client: &BittensorClient, + netuid: u16, + hotkey: &AccountId32, +) -> BittensorResult> { + let keys = vec![ + Value::u128(netuid as u128), + Value::from_bytes(hotkey.encode()), + ]; + + match client + .storage_with_keys(SUBTENSOR_MODULE, "CRV3WeightCommits", keys) + .await + { + Ok(Some(val)) => { + let info = decode_weight_commit_info(&val); + Ok(info) + } + Ok(None) => Ok(None), + Err(e) => Err(BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query CRV3WeightCommits: {}", e), + SUBTENSOR_MODULE, + "CRV3WeightCommits", + ))), + } +} + +/// Decode WeightCommitInfo from a Value +fn decode_weight_commit_info(value: &Value) -> Option { + let s = format!("{:?}", value); + + // Extract block (u64), commit_hash (bytes), reveal_round (u64) + let block = extract_first_u64_from_str(&s).unwrap_or(0); + + // Extract bytes for commit_hash + let mut commit_hash = Vec::new(); + let mut rem = s.as_str(); + while let Some(pos) = rem.find("U128(") { + let after = &rem[pos + 5..]; + if let Some(end) = after.find(')') { + if let Ok(n) = after[..end].trim().parse::() { + if n <= 255 { + commit_hash.push(n as u8); + } + } + rem = &after[end + 1..]; + } else { + break; + } + } + + let reveal_round = extract_last_u64_from_str(&s).unwrap_or(0); + + if block > 0 || !commit_hash.is_empty() || reveal_round > 0 { + Some(WeightCommitInfo::new(block, commit_hash, reveal_round)) + } else { + None + } +} + +/// Get all weight commitments for a subnet +/// +/// Queries all weight commitments from all neurons registered on the subnet. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `netuid` - The subnet ID +/// +/// # Returns +/// A vector of (AccountId32, WeightCommitInfo) tuples for all commitments +pub async fn get_all_weight_commitments( + client: &BittensorClient, + netuid: u16, +) -> BittensorResult> { + // Get the number of neurons in the subnet + let n_val = client + .storage_with_keys( + SUBTENSOR_MODULE, + "SubnetworkN", + vec![Value::u128(netuid as u128)], + ) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query SubnetworkN: {}", e), + SUBTENSOR_MODULE, + "SubnetworkN", + )) + })?; + + let n = n_val + .and_then(|v| crate::utils::decoders::decode_u64(&v).ok()) + .unwrap_or(0); + + let mut commitments = Vec::new(); + + for uid in 0..n { + // Get the hotkey for this UID + let hk_val = client + .storage_with_keys( + SUBTENSOR_MODULE, + "Keys", + vec![Value::u128(netuid as u128), Value::u128(uid as u128)], + ) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query Keys: {}", e), + SUBTENSOR_MODULE, + "Keys", + )) + })?; + + if let Some(hk_val) = hk_val { + if let Ok(hotkey) = crate::utils::decoders::decode_account_id32(&hk_val) { + // Get the commitment for this hotkey + if let Ok(Some(commit_info)) = get_weight_commitment(client, netuid, &hotkey).await + { + commitments.push((hotkey, commit_info)); + } + } + } + } + + Ok(commitments) +} + +/// Get pending weight commits for a subnet +/// +/// Returns weight commits that have been submitted but not yet revealed. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `netuid` - The subnet ID +/// +/// # Returns +/// A vector of (AccountId32, WeightCommitInfo) for pending commits +pub async fn get_pending_weight_commits( + client: &BittensorClient, + netuid: u16, +) -> BittensorResult> { + // Query the V2 commits storage which contains pending commits + let commits_v2 = get_current_weight_commit_info_v2(client, netuid) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query CRV3WeightCommitsV2: {}", e), + SUBTENSOR_MODULE, + "CRV3WeightCommitsV2", + )) + })?; + + let mut result = Vec::new(); + for (hotkey, block, msg, reveal_round) in commits_v2 { + let commit_hash = msg.into_bytes(); + result.push(( + hotkey, + WeightCommitInfo::new(block, commit_hash, reveal_round), + )); + } + + Ok(result) +} + +/// Check if a hotkey has a pending weight commitment on a subnet +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `netuid` - The subnet ID +/// * `hotkey` - The hotkey to check +/// +/// # Returns +/// true if the hotkey has a pending commitment +pub async fn has_pending_commitment( + client: &BittensorClient, + netuid: u16, + hotkey: &AccountId32, +) -> BittensorResult { + let commitment = get_weight_commitment(client, netuid, hotkey).await?; + Ok(commitment.is_some()) +} + +/// Get the last commit block for a hotkey on a subnet +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `netuid` - The subnet ID +/// * `hotkey` - The hotkey to query +/// +/// # Returns +/// The block number of the last commitment, or None if no commitment exists +pub async fn get_last_commit_block( + client: &BittensorClient, + netuid: u16, + hotkey: &AccountId32, +) -> BittensorResult> { + match get_weight_commitment(client, netuid, hotkey).await? 
{ + Some(info) => Ok(Some(info.block)), + None => Ok(None), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_weight_commit_info_new() { + let info = WeightCommitInfo::new(100, vec![1, 2, 3, 4], 5); + assert_eq!(info.block, 100); + assert_eq!(info.commit_hash, vec![1, 2, 3, 4]); + assert_eq!(info.reveal_round, 5); + } + + #[test] + fn test_weight_commit_info_commit_hash_hex() { + let info = WeightCommitInfo::new(100, vec![0xde, 0xad, 0xbe, 0xef], 5); + assert_eq!(info.commit_hash_hex(), "deadbeef"); + } + + #[test] + fn test_weight_commit_info_clone() { + let info = WeightCommitInfo::new(100, vec![1, 2, 3], 5); + let cloned = info.clone(); + assert_eq!(cloned.block, info.block); + assert_eq!(cloned.commit_hash, info.commit_hash); + assert_eq!(cloned.reveal_round, info.reveal_round); + } + + #[test] + fn test_weight_commit_info_debug() { + let info = WeightCommitInfo::new(100, vec![1, 2, 3], 5); + let debug_str = format!("{:?}", info); + assert!(debug_str.contains("WeightCommitInfo")); + assert!(debug_str.contains("100")); + assert!(debug_str.contains("5")); + } +} diff --git a/src/queries/hyperparameters.rs b/src/queries/hyperparameters.rs new file mode 100644 index 0000000..1827826 --- /dev/null +++ b/src/queries/hyperparameters.rs @@ -0,0 +1,420 @@ +//! Subnet hyperparameter queries with proper SCALE decoding +//! +//! This module provides functions to query individual subnet hyperparameters +//! from the Bittensor chain, matching the Python SDK SubnetHyperparameters structure. 
+ +use crate::chain::BittensorClient; +use crate::errors::{BittensorError, BittensorResult, ChainQueryError}; +use crate::utils::decoders::{decode_bool, decode_u16, decode_u64}; +use subxt::dynamic::Value; + +const SUBTENSOR_MODULE: &str = "SubtensorModule"; + +/// Complete subnet hyperparameters (matches Python SDK SubnetHyperparameters) +#[derive(Debug, Clone, Default)] +pub struct SubnetHyperparameters { + pub rho: u16, + pub kappa: u16, + pub immunity_period: u16, + pub min_allowed_weights: u16, + pub max_weights_limit: u16, + pub tempo: u16, + pub min_difficulty: u64, + pub max_difficulty: u64, + pub weights_version: u64, + pub weights_rate_limit: u64, + pub adjustment_interval: u16, + pub activity_cutoff: u16, + pub registration_allowed: bool, + pub target_regs_per_interval: u16, + pub min_burn: u64, + pub max_burn: u64, + pub bonds_moving_avg: u64, + pub max_regs_per_block: u16, + pub serving_rate_limit: u64, + pub max_validators: u16, + pub adjustment_alpha: u64, + pub difficulty: u64, + pub commit_reveal_weights_interval: u64, + pub commit_reveal_weights_enabled: bool, + pub alpha_high: u16, + pub alpha_low: u16, + pub liquid_alpha_enabled: bool, +} + +/// Helper to fetch a u16 storage value for a subnet +async fn fetch_u16_param( + client: &BittensorClient, + entry: &str, + netuid: u16, +) -> BittensorResult { + let keys = vec![Value::u128(netuid as u128)]; + match client.storage_with_keys(SUBTENSOR_MODULE, entry, keys).await { + Ok(Some(val)) => decode_u16(&val).map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to decode {} as u16: {}", entry, e), + SUBTENSOR_MODULE, + entry, + )) + }), + Ok(None) => Ok(0), + Err(e) => Err(BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query {}: {}", entry, e), + SUBTENSOR_MODULE, + entry, + ))), + } +} + +/// Helper to fetch a u64 storage value for a subnet +async fn fetch_u64_param( + client: &BittensorClient, + entry: &str, + netuid: u16, +) -> 
BittensorResult { + let keys = vec![Value::u128(netuid as u128)]; + match client.storage_with_keys(SUBTENSOR_MODULE, entry, keys).await { + Ok(Some(val)) => decode_u64(&val).map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to decode {} as u64: {}", entry, e), + SUBTENSOR_MODULE, + entry, + )) + }), + Ok(None) => Ok(0), + Err(e) => Err(BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query {}: {}", entry, e), + SUBTENSOR_MODULE, + entry, + ))), + } +} + +/// Helper to fetch a bool storage value for a subnet +async fn fetch_bool_param( + client: &BittensorClient, + entry: &str, + netuid: u16, +) -> BittensorResult { + let keys = vec![Value::u128(netuid as u128)]; + match client.storage_with_keys(SUBTENSOR_MODULE, entry, keys).await { + Ok(Some(val)) => decode_bool(&val).map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to decode {} as bool: {}", entry, e), + SUBTENSOR_MODULE, + entry, + )) + }), + Ok(None) => Ok(false), + Err(e) => Err(BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query {}: {}", entry, e), + SUBTENSOR_MODULE, + entry, + ))), + } +} + +/// Get all hyperparameters for a subnet +pub async fn get_subnet_hyperparameters( + client: &BittensorClient, + netuid: u16, +) -> BittensorResult { + // Fetch all hyperparameters in parallel for efficiency + let ( + rho, + kappa, + immunity_period, + min_allowed_weights, + max_weights_limit, + tempo, + min_difficulty, + max_difficulty, + weights_version, + weights_rate_limit, + adjustment_interval, + activity_cutoff, + registration_allowed, + target_regs_per_interval, + min_burn, + max_burn, + bonds_moving_avg, + max_regs_per_block, + serving_rate_limit, + max_validators, + adjustment_alpha, + difficulty, + commit_reveal_weights_interval, + commit_reveal_weights_enabled, + alpha_high, + alpha_low, + liquid_alpha_enabled, + ) = tokio::join!( + get_rho(client, netuid), + 
get_kappa(client, netuid), + get_immunity_period(client, netuid), + get_min_allowed_weights(client, netuid), + get_max_weights_limit(client, netuid), + get_tempo(client, netuid), + get_min_difficulty(client, netuid), + get_max_difficulty(client, netuid), + get_weights_version_key(client, netuid), + get_weights_rate_limit(client, netuid), + get_adjustment_interval(client, netuid), + get_activity_cutoff(client, netuid), + get_registration_allowed(client, netuid), + get_target_regs_per_interval(client, netuid), + get_min_burn(client, netuid), + get_max_burn(client, netuid), + get_bonds_moving_average(client, netuid), + get_max_regs_per_block(client, netuid), + get_serving_rate_limit(client, netuid), + get_max_validators(client, netuid), + get_adjustment_alpha(client, netuid), + get_difficulty(client, netuid), + get_commit_reveal_weights_interval(client, netuid), + get_commit_reveal_weights_enabled(client, netuid), + get_alpha_high(client, netuid), + get_alpha_low(client, netuid), + get_liquid_alpha_enabled(client, netuid), + ); + + Ok(SubnetHyperparameters { + rho: rho.unwrap_or(0), + kappa: kappa.unwrap_or(0), + immunity_period: immunity_period.unwrap_or(0), + min_allowed_weights: min_allowed_weights.unwrap_or(0), + max_weights_limit: max_weights_limit.unwrap_or(0), + tempo: tempo.unwrap_or(0), + min_difficulty: min_difficulty.unwrap_or(0), + max_difficulty: max_difficulty.unwrap_or(0), + weights_version: weights_version.unwrap_or(0), + weights_rate_limit: weights_rate_limit.unwrap_or(0), + adjustment_interval: adjustment_interval.unwrap_or(0), + activity_cutoff: activity_cutoff.unwrap_or(0), + registration_allowed: registration_allowed.unwrap_or(false), + target_regs_per_interval: target_regs_per_interval.unwrap_or(0), + min_burn: min_burn.unwrap_or(0), + max_burn: max_burn.unwrap_or(0), + bonds_moving_avg: bonds_moving_avg.unwrap_or(0), + max_regs_per_block: max_regs_per_block.unwrap_or(0), + serving_rate_limit: serving_rate_limit.unwrap_or(0), + max_validators: 
max_validators.unwrap_or(0), + adjustment_alpha: adjustment_alpha.unwrap_or(0), + difficulty: difficulty.unwrap_or(0), + commit_reveal_weights_interval: commit_reveal_weights_interval.unwrap_or(0), + commit_reveal_weights_enabled: commit_reveal_weights_enabled.unwrap_or(false), + alpha_high: alpha_high.unwrap_or(0), + alpha_low: alpha_low.unwrap_or(0), + liquid_alpha_enabled: liquid_alpha_enabled.unwrap_or(false), + }) +} + +/// Get Rho parameter for a subnet +/// Rho is the ratio for calculating the weights to set +pub async fn get_rho(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "Rho", netuid).await +} + +/// Get Kappa parameter for a subnet +/// Kappa is used in the Yuma Consensus algorithm +pub async fn get_kappa(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "Kappa", netuid).await +} + +/// Get immunity period for a subnet +/// Number of blocks a neuron is protected from deregistration after registration +pub async fn get_immunity_period(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "ImmunityPeriod", netuid).await +} + +/// Get minimum allowed weights for a subnet +/// Minimum number of weights each validator must set +pub async fn get_min_allowed_weights(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "MinAllowedWeights", netuid).await +} + +/// Get maximum weights limit for a subnet +/// Maximum weight value that can be assigned (normalized to u16 range) +pub async fn get_max_weights_limit(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "MaxWeightsLimit", netuid).await +} + +/// Get tempo for a subnet +/// Number of blocks between weight setting epochs +pub async fn get_tempo(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "Tempo", netuid).await +} + +/// Get minimum difficulty for a subnet +/// Minimum PoW difficulty for 
registration +pub async fn get_min_difficulty(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "MinDifficulty", netuid).await +} + +/// Get maximum difficulty for a subnet +/// Maximum PoW difficulty for registration +pub async fn get_max_difficulty(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "MaxDifficulty", netuid).await +} + +/// Get current difficulty for a subnet +/// Current PoW difficulty for registration +pub async fn get_difficulty(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "Difficulty", netuid).await +} + +/// Get weights version key for a subnet +/// Version number for weight format compatibility +pub async fn get_weights_version_key(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "WeightsVersionKey", netuid).await +} + +/// Get weights rate limit for a subnet +/// Minimum blocks between weight setting transactions +pub async fn get_weights_rate_limit(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "WeightsSetRateLimit", netuid).await +} + +/// Get adjustment interval for a subnet +/// Number of blocks between difficulty adjustments +pub async fn get_adjustment_interval(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "AdjustmentInterval", netuid).await +} + +/// Get activity cutoff for a subnet +/// Number of blocks of inactivity before a neuron becomes inactive +pub async fn get_activity_cutoff(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "ActivityCutoff", netuid).await +} + +/// Check if registration is allowed for a subnet +/// Whether new neurons can register on this subnet +pub async fn get_registration_allowed(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_bool_param(client, "NetworkRegistrationAllowed", netuid).await +} + +/// Get target registrations per 
interval for a subnet +/// Target number of registrations per adjustment interval +pub async fn get_target_regs_per_interval( + client: &BittensorClient, + netuid: u16, +) -> BittensorResult { + fetch_u16_param(client, "TargetRegistrationsPerInterval", netuid).await +} + +/// Get minimum burn amount for a subnet (in RAO) +/// Minimum amount of TAO to burn for registration +pub async fn get_min_burn(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "MinBurn", netuid).await +} + +/// Get maximum burn amount for a subnet (in RAO) +/// Maximum amount of TAO to burn for registration +pub async fn get_max_burn(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "MaxBurn", netuid).await +} + +/// Get bonds moving average for a subnet +/// Rate at which bonds update (higher = faster updates) +pub async fn get_bonds_moving_average(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "BondsMovingAverage", netuid).await +} + +/// Get maximum registrations per block for a subnet +/// Maximum number of neurons that can register in a single block +pub async fn get_max_regs_per_block(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "MaxRegistrationsPerBlock", netuid).await +} + +/// Get serving rate limit for a subnet +/// Minimum blocks between axon serving info updates +pub async fn get_serving_rate_limit(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u64_param(client, "ServingRateLimit", netuid).await +} + +/// Get maximum validators for a subnet +/// Maximum number of validators allowed on the subnet +pub async fn get_max_validators(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "MaxAllowedValidators", netuid).await +} + +/// Get adjustment alpha for a subnet +/// Alpha parameter for difficulty adjustment algorithm +pub async fn get_adjustment_alpha(client: &BittensorClient, 
netuid: u16) -> BittensorResult { + fetch_u64_param(client, "AdjustmentAlpha", netuid).await +} + +/// Get commit reveal weights interval for a subnet +/// Number of blocks for commit-reveal weight setting cycle +pub async fn get_commit_reveal_weights_interval( + client: &BittensorClient, + netuid: u16, +) -> BittensorResult { + fetch_u64_param(client, "CommitRevealWeightsInterval", netuid).await +} + +/// Check if commit-reveal weights mechanism is enabled for a subnet +pub async fn get_commit_reveal_weights_enabled( + client: &BittensorClient, + netuid: u16, +) -> BittensorResult { + fetch_bool_param(client, "CommitRevealWeightsEnabled", netuid).await +} + +/// Get alpha high parameter for liquid alpha +/// Upper bound for alpha in liquid alpha mechanism +pub async fn get_alpha_high(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "AlphaHigh", netuid).await +} + +/// Get alpha low parameter for liquid alpha +/// Lower bound for alpha in liquid alpha mechanism +pub async fn get_alpha_low(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_u16_param(client, "AlphaLow", netuid).await +} + +/// Check if liquid alpha mechanism is enabled for a subnet +pub async fn get_liquid_alpha_enabled(client: &BittensorClient, netuid: u16) -> BittensorResult { + fetch_bool_param(client, "LiquidAlphaOn", netuid).await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_subnet_hyperparameters_default() { + let params = SubnetHyperparameters::default(); + assert_eq!(params.rho, 0); + assert_eq!(params.kappa, 0); + assert_eq!(params.tempo, 0); + assert!(!params.registration_allowed); + assert!(!params.commit_reveal_weights_enabled); + assert!(!params.liquid_alpha_enabled); + } + + #[test] + fn test_subnet_hyperparameters_clone() { + let params = SubnetHyperparameters { + rho: 10, + kappa: 32767, + tempo: 360, + registration_allowed: true, + ..Default::default() + }; + let cloned = params.clone(); + 
assert_eq!(cloned.rho, 10); + assert_eq!(cloned.kappa, 32767); + assert_eq!(cloned.tempo, 360); + assert!(cloned.registration_allowed); + } + + #[test] + fn test_subnet_hyperparameters_debug() { + let params = SubnetHyperparameters::default(); + let debug_str = format!("{:?}", params); + assert!(debug_str.contains("SubnetHyperparameters")); + assert!(debug_str.contains("rho")); + assert!(debug_str.contains("tempo")); + } +} diff --git a/src/queries/mod.rs b/src/queries/mod.rs index d859bff..da5b4eb 100644 --- a/src/queries/mod.rs +++ b/src/queries/mod.rs @@ -1,8 +1,10 @@ +pub mod associated_ips; pub mod balances; pub mod bonds; pub mod chain_info; pub mod commitments; pub mod delegates; +pub mod hyperparameters; pub mod identity; pub mod liquidity; pub mod metagraph_queries; @@ -31,3 +33,24 @@ pub use subnets::{ commit_reveal_enabled, get_mechanism_count, get_subnet_reveal_period_epochs, is_subnet_active, recycle, }; + +// Re-export hyperparameters +pub use hyperparameters::{ + get_activity_cutoff, get_adjustment_alpha, get_adjustment_interval, get_alpha_high, + get_alpha_low, get_bonds_moving_average, get_commit_reveal_weights_enabled, + get_commit_reveal_weights_interval, get_difficulty, get_immunity_period, get_kappa, + get_liquid_alpha_enabled, get_max_burn, get_max_difficulty, get_max_regs_per_block, + get_max_validators, get_max_weights_limit, get_min_allowed_weights, get_min_burn, + get_min_difficulty, get_registration_allowed, get_rho, get_serving_rate_limit, + get_subnet_hyperparameters, get_target_regs_per_interval, get_tempo, get_weights_rate_limit, + get_weights_version_key, SubnetHyperparameters, +}; + +// Re-export commitment types and functions +pub use commitments::{ + get_all_weight_commitments, get_last_commit_block, get_pending_weight_commits, + get_weight_commitment, has_pending_commitment, WeightCommitInfo, +}; + +// Re-export associated IPs +pub use associated_ips::{get_associated_ip_count, get_associated_ips, has_associated_ips, IpInfo}; diff 
--git a/src/queries/neurons.rs b/src/queries/neurons.rs index 68b49ea..a83748e 100644 --- a/src/queries/neurons.rs +++ b/src/queries/neurons.rs @@ -14,6 +14,7 @@ const SUBTENSOR_MODULE: &str = "SubtensorModule"; /// SubnetState structure matching the on-chain SCALE encoding from subtensor /// Used to decode the response from SubnetInfoRuntimeApi.get_subnet_state #[derive(Decode, Clone, Debug)] +#[allow(dead_code)] struct SubnetStateRaw { netuid: Compact, hotkeys: Vec, diff --git a/src/subtensor.rs b/src/subtensor.rs index 8c752c0..4a6b001 100644 --- a/src/subtensor.rs +++ b/src/subtensor.rs @@ -22,15 +22,13 @@ use crate::chain::{BittensorClient, BittensorSigner, ExtrinsicWait}; use crate::crv4::{ calculate_reveal_round, commit_timelocked_mechanism_weights, commit_timelocked_weights, - get_mechid_storage_index, prepare_crv4_commit, Crv4CommitData, DEFAULT_COMMIT_REVEAL_VERSION, -}; -use crate::queries::subnets::{ - blocks_since_last_step, commit_reveal_enabled, tempo, weights_rate_limit, + get_mechid_storage_index, prepare_crv4_commit, DEFAULT_COMMIT_REVEAL_VERSION, }; +use crate::queries::subnets::{commit_reveal_enabled, tempo, weights_rate_limit}; use crate::utils::weights::{normalize_weights, U16_MAX}; use crate::validator::weights::{ commit_weights as raw_commit_weights, reveal_weights as raw_reveal_weights, - set_weights as raw_set_weights, CommitRevealData, + set_weights as raw_set_weights, }; use anyhow::Result; use serde::{Deserialize, Serialize}; @@ -38,7 +36,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; -use tracing::{debug, error, info, warn}; +use tracing::{error, info}; const SUBTENSOR_MODULE: &str = "SubtensorModule"; @@ -668,7 +666,9 @@ impl Subtensor { let mut state = self.state.write().await; state.pending_commits.insert(key, pending); if let Some(ref path) = self.state_path { - let _ = state.save(path); + if let Err(e) = state.save(path) { + tracing::warn!("Failed to save state: {}", e); + 
} } } @@ -703,31 +703,34 @@ impl Subtensor { let uids_u64: Vec = pending.uids.iter().map(|u| *u as u64).collect(); - let tx_hash = if pending.mechanism_id.is_none() || pending.mechanism_id == Some(0) { - raw_reveal_weights( - &self.client, - signer, - pending.netuid, - &uids_u64, - &pending.weights, - &pending.salt, - pending.version_key, - wait_for, - ) - .await? - } else { - crate::reveal_mechanism_weights( - &self.client, - signer, - pending.netuid, - pending.mechanism_id.unwrap(), - &pending.uids, - &pending.weights, - &pending.salt, - pending.version_key, - wait_for, - ) - .await? + let tx_hash = match pending.mechanism_id { + None | Some(0) => { + raw_reveal_weights( + &self.client, + signer, + pending.netuid, + &uids_u64, + &pending.weights, + &pending.salt, + pending.version_key, + wait_for, + ) + .await? + } + Some(mechanism_id) => { + crate::reveal_mechanism_weights( + &self.client, + signer, + pending.netuid, + mechanism_id, + &pending.uids, + &pending.weights, + &pending.salt, + pending.version_key, + wait_for, + ) + .await? 
+ } }; // Remove pending commit @@ -737,7 +740,9 @@ impl Subtensor { state.pending_commits.remove(&key); state.last_revealed.insert(key, pending.epoch); if let Some(ref path) = self.state_path { - let _ = state.save(path); + if let Err(e) = state.save(path) { + tracing::warn!("Failed to save state: {}", e); + } } } @@ -858,7 +863,9 @@ impl Subtensor { .retain(|_, commit| commit.epoch >= cutoff); if let Some(ref path) = self.state_path { - let _ = state.save(path); + if let Err(e) = state.save(path) { + tracing::warn!("Failed to save state: {}", e); + } } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 90c1506..fe0b582 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -25,7 +25,10 @@ pub use neuron_lite::NeuronInfoLite; pub use prometheus::PrometheusInfo; pub use proposal_vote::ProposalVoteData; pub use subnet::{SubnetHyperparameters, SubnetIdentity, SubnetInfo}; -pub use synapse::{Synapse, SynapseHeaders, TerminalInfo}; +pub use synapse::{ + headers, Message, StreamingSynapse, StreamingTextPromptSynapse, Synapse, SynapseHeaders, + SynapseType, TerminalInfo, TextPromptSynapse, +}; /// Chain identity for delegates #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)] diff --git a/src/types/synapse.rs b/src/types/synapse.rs index bdb8b33..0d059e7 100644 --- a/src/types/synapse.rs +++ b/src/types/synapse.rs @@ -1,10 +1,300 @@ //! Synapse types for Bittensor communication -//! These are read-only type definitions that match the Python Synapse class +//! +//! This module provides synapse types that match the Python SDK's Synapse class, +//! supporting both standard request/response patterns and streaming communication. +//! +//! # Features +//! +//! - `Synapse` - Base synapse structure with body hash computation +//! - `SynapseType` - Trait for custom synapse implementations +//! - `StreamingSynapse` - Trait for streaming synapse implementations +//! - `TextPromptSynapse` - Built-in text prompt synapse +//! 
- `headers` module - Header constants matching Python SDK +//! +//! # Example +//! +//! ```ignore +//! use bittensor_rs::types::synapse::{Synapse, SynapseType, TextPromptSynapse, Message}; +//! +//! // Create a basic synapse +//! let synapse = Synapse::new().with_name("MyQuery").with_timeout(30.0); +//! +//! // Create a text prompt synapse +//! let prompt = TextPromptSynapse::new(vec![ +//! Message::user("Hello, how are you?"), +//! ]); +//! ``` +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; use std::collections::HashMap; +use std::time::Duration; + +use crate::errors::{BittensorError, SerializationError}; + +/// Header constants matching the Python SDK exactly +/// +/// These constants define the HTTP header names used for Bittensor network +/// communication between Dendrite clients and Axon servers. +pub mod headers { + // Dendrite (client) headers + /// Dendrite IP address header + pub const BT_HEADER_DENDRITE_IP: &str = "bt_header_dendrite_ip"; + /// Dendrite port header + pub const BT_HEADER_DENDRITE_PORT: &str = "bt_header_dendrite_port"; + /// Dendrite version header + pub const BT_HEADER_DENDRITE_VERSION: &str = "bt_header_dendrite_version"; + /// Dendrite nonce for replay protection + pub const BT_HEADER_DENDRITE_NONCE: &str = "bt_header_dendrite_nonce"; + /// Dendrite UUID header + pub const BT_HEADER_DENDRITE_UUID: &str = "bt_header_dendrite_uuid"; + /// Dendrite hotkey SS58 address + pub const BT_HEADER_DENDRITE_HOTKEY: &str = "bt_header_dendrite_hotkey"; + /// Dendrite signature for authentication + pub const BT_HEADER_DENDRITE_SIGNATURE: &str = "bt_header_dendrite_signature"; + + // Axon (server) headers + /// Axon IP address header + pub const BT_HEADER_AXON_IP: &str = "bt_header_axon_ip"; + /// Axon port header + pub const BT_HEADER_AXON_PORT: &str = "bt_header_axon_port"; + /// Axon version header + pub const BT_HEADER_AXON_VERSION: &str = "bt_header_axon_version"; + /// Axon nonce for replay 
protection + pub const BT_HEADER_AXON_NONCE: &str = "bt_header_axon_nonce"; + /// Axon UUID header + pub const BT_HEADER_AXON_UUID: &str = "bt_header_axon_uuid"; + /// Axon hotkey SS58 address + pub const BT_HEADER_AXON_HOTKEY: &str = "bt_header_axon_hotkey"; + /// Axon signature for authentication + pub const BT_HEADER_AXON_SIGNATURE: &str = "bt_header_axon_signature"; + /// Axon status code from response + pub const BT_HEADER_AXON_STATUS_CODE: &str = "bt_header_axon_status_code"; + /// Axon status message from response + pub const BT_HEADER_AXON_STATUS_MESSAGE: &str = "bt_header_axon_status_message"; + /// Axon processing time in seconds + pub const BT_HEADER_AXON_PROCESS_TIME: &str = "bt_header_axon_process_time"; + + // Synapse metadata headers + /// Input object header (serialized synapse input) + pub const BT_HEADER_INPUT_OBJ: &str = "bt_header_input_obj"; + /// Output object header (serialized synapse output) + pub const BT_HEADER_OUTPUT_OBJ: &str = "bt_header_output_obj"; + /// Computed body hash for verification + pub const COMPUTED_BODY_HASH: &str = "computed_body_hash"; + /// Synapse name/route + pub const NAME: &str = "name"; + /// Request timeout in seconds + pub const TIMEOUT: &str = "timeout"; + + /// Get all dendrite header names as a slice + pub fn dendrite_headers() -> &'static [&'static str] { + &[ + BT_HEADER_DENDRITE_IP, + BT_HEADER_DENDRITE_PORT, + BT_HEADER_DENDRITE_VERSION, + BT_HEADER_DENDRITE_NONCE, + BT_HEADER_DENDRITE_UUID, + BT_HEADER_DENDRITE_HOTKEY, + BT_HEADER_DENDRITE_SIGNATURE, + ] + } + + /// Get all axon header names as a slice + pub fn axon_headers() -> &'static [&'static str] { + &[ + BT_HEADER_AXON_IP, + BT_HEADER_AXON_PORT, + BT_HEADER_AXON_VERSION, + BT_HEADER_AXON_NONCE, + BT_HEADER_AXON_UUID, + BT_HEADER_AXON_HOTKEY, + BT_HEADER_AXON_SIGNATURE, + BT_HEADER_AXON_STATUS_CODE, + BT_HEADER_AXON_STATUS_MESSAGE, + BT_HEADER_AXON_PROCESS_TIME, + ] + } + + /// Get all metadata header names as a slice + pub fn metadata_headers() -> 
&'static [&'static str] { + &[ + BT_HEADER_INPUT_OBJ, + BT_HEADER_OUTPUT_OBJ, + COMPUTED_BODY_HASH, + NAME, + TIMEOUT, + ] + } +} + +// ============================================================================= +// SynapseType Trait +// ============================================================================= + +/// Trait for custom synapse types +/// +/// Types implementing this trait can be used with the Dendrite client +/// for type-safe communication with Axon servers. +/// +/// # Example +/// +/// ```ignore +/// use bittensor_rs::types::synapse::{SynapseType, Synapse}; +/// use serde::{Serialize, Deserialize}; +/// use std::time::Duration; +/// +/// #[derive(Debug, Clone, Serialize, Deserialize)] +/// struct MyCustomSynapse { +/// #[serde(flatten)] +/// pub base: Synapse, +/// pub query: String, +/// pub result: Option, +/// } +/// +/// impl SynapseType for MyCustomSynapse { +/// fn name() -> &'static str { "MyCustomSynapse" } +/// fn required_hash_fields() -> Vec<&'static str> { vec!["query"] } +/// } +/// ``` +pub trait SynapseType: Serialize + DeserializeOwned + Send + Sync + 'static { + /// Get the synapse name/route + /// + /// This should return a unique identifier for this synapse type, + /// typically used as the HTTP endpoint path. + fn name() -> &'static str; + + /// Get the timeout for this synapse + /// + /// Default is 12 seconds, matching the Python SDK default. + fn timeout(&self) -> Duration { + Duration::from_secs(12) + } + + /// Get the fields that must be included in body hash computation + /// + /// Returns a list of field names that should be hashed for signature + /// verification. The order matters for consistent hashing. 
+ fn required_hash_fields() -> Vec<&'static str> { + vec![] + } + + /// Deserialize from JSON bytes + /// + /// # Arguments + /// + /// * `data` - JSON-encoded bytes + /// + /// # Returns + /// + /// The deserialized synapse or an error + fn from_json(data: &[u8]) -> Result { + serde_json::from_slice(data).map_err(|e| { + BittensorError::Serialization(SerializationError::with_type( + format!("Failed to deserialize {}: {}", Self::name(), e), + Self::name(), + )) + }) + } + + /// Serialize to JSON bytes + /// + /// # Returns + /// + /// JSON-encoded bytes or an error + fn to_json(&self) -> Result, BittensorError> { + serde_json::to_vec(self).map_err(|e| { + BittensorError::Serialization(SerializationError::with_type( + format!("Failed to serialize {}: {}", Self::name(), e), + Self::name(), + )) + }) + } +} + +// ============================================================================= +// StreamingSynapse Trait +// ============================================================================= + +/// Trait for streaming synapse types that process data in chunks +/// +/// This trait extends `SynapseType` to support streaming responses, +/// allowing incremental processing of large or continuous data streams. 
+/// +/// # Example +/// +/// ```ignore +/// use bittensor_rs::types::synapse::{SynapseType, StreamingSynapse, Synapse}; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Debug, Clone, Serialize, Deserialize)] +/// struct StreamingTextSynapse { +/// #[serde(flatten)] +/// pub base: Synapse, +/// pub prompt: String, +/// #[serde(skip)] +/// accumulated_response: String, +/// #[serde(skip)] +/// complete: bool, +/// } +/// +/// impl SynapseType for StreamingTextSynapse { +/// fn name() -> &'static str { "StreamingTextSynapse" } +/// } +/// +/// impl StreamingSynapse for StreamingTextSynapse { +/// type Chunk = String; +/// +/// fn process_chunk(&mut self, chunk: &[u8]) -> Option { +/// String::from_utf8(chunk.to_vec()).ok() +/// } +/// +/// fn is_complete(&self) -> bool { +/// self.complete +/// } +/// +/// fn finalize(&mut self) -> Result<(), crate::errors::BittensorError> { +/// self.complete = true; +/// Ok(()) +/// } +/// } +/// ``` +pub trait StreamingSynapse: SynapseType { + /// The type of each chunk produced by the stream + type Chunk: Send; + + /// Process a chunk of data from the response stream + /// + /// # Arguments + /// + /// * `chunk` - Raw bytes from the response stream + /// + /// # Returns + /// + /// `Some(chunk)` if a complete chunk was parsed, `None` if more data is needed + fn process_chunk(&mut self, chunk: &[u8]) -> Option; + + /// Check if the stream is complete + /// + /// Returns `true` when no more chunks are expected + fn is_complete(&self) -> bool; + + /// Finalize the stream + /// + /// Called when the stream ends (either normally or due to completion). + /// Implementations should clean up any resources and mark the stream as complete. 
+ fn finalize(&mut self) -> Result<(), BittensorError>; +} + +// ============================================================================= +// TerminalInfo +// ============================================================================= /// Terminal information about a network endpoint +/// +/// Contains metadata about either the Dendrite (client) or Axon (server) +/// side of a communication. Used for authentication and debugging. #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct TerminalInfo { /// HTTP status code @@ -30,19 +320,89 @@ pub struct TerminalInfo { } impl TerminalInfo { + /// Create a new empty TerminalInfo pub fn new() -> Self { Self::default() } + /// Set the status code and message pub fn with_status(mut self, code: i32, message: &str) -> Self { self.status_code = Some(code); self.status_message = Some(message.to_string()); self } + + /// Set the IP address + pub fn with_ip(mut self, ip: impl Into) -> Self { + self.ip = Some(ip.into()); + self + } + + /// Set the port number + pub fn with_port(mut self, port: u16) -> Self { + self.port = Some(port); + self + } + + /// Set the version + pub fn with_version(mut self, version: u64) -> Self { + self.version = Some(version); + self + } + + /// Set the nonce + pub fn with_nonce(mut self, nonce: u64) -> Self { + self.nonce = Some(nonce); + self + } + + /// Set the UUID + pub fn with_uuid(mut self, uuid: impl Into) -> Self { + self.uuid = Some(uuid.into()); + self + } + + /// Set the hotkey + pub fn with_hotkey(mut self, hotkey: impl Into) -> Self { + self.hotkey = Some(hotkey.into()); + self + } + + /// Set the signature + pub fn with_signature(mut self, signature: impl Into) -> Self { + self.signature = Some(signature.into()); + self + } + + /// Set the process time + pub fn with_process_time(mut self, time: f64) -> Self { + self.process_time = Some(time); + self + } } +// ============================================================================= +// Synapse +// 
============================================================================= + /// Base Synapse structure for network communication -/// This represents the core message format in Bittensor +/// +/// This represents the core message format in Bittensor, containing +/// metadata for authentication, routing, and debugging, as well as +/// support for custom fields via the `extra` map. +/// +/// # Example +/// +/// ``` +/// use bittensor_rs::types::synapse::Synapse; +/// +/// let synapse = Synapse::new() +/// .with_name("MyQuery") +/// .with_timeout(30.0); +/// +/// assert_eq!(synapse.name, Some("MyQuery".to_string())); +/// assert!(!synapse.is_success()); // No response yet +/// ``` #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Synapse { /// Name of the synapse (route name) @@ -57,7 +417,7 @@ pub struct Synapse { pub dendrite: Option, /// Axon (receiver) terminal information pub axon: Option, - /// Computed body hash + /// Computed body hash for signature verification pub computed_body_hash: Option, /// Additional fields for custom data #[serde(flatten)] @@ -80,21 +440,42 @@ impl Default for Synapse { } impl Synapse { + /// Create a new Synapse with default values pub fn new() -> Self { Self::default() } + /// Set the synapse name/route pub fn with_name(mut self, name: &str) -> Self { self.name = Some(name.to_string()); self } + /// Set the request timeout in seconds pub fn with_timeout(mut self, timeout: f64) -> Self { self.timeout = Some(timeout); self } - /// Check if request was successful + /// Set the dendrite terminal info + pub fn with_dendrite(mut self, dendrite: TerminalInfo) -> Self { + self.dendrite = Some(dendrite); + self + } + + /// Set the axon terminal info + pub fn with_axon(mut self, axon: TerminalInfo) -> Self { + self.axon = Some(axon); + self + } + + /// Set the computed body hash + pub fn with_body_hash(mut self, hash: impl Into) -> Self { + self.computed_body_hash = Some(hash.into()); + self + } + + /// Check if request was 
successful (status code 200) pub fn is_success(&self) -> bool { if let Some(ref dendrite) = self.dendrite { if let Some(code) = dendrite.status_code { @@ -104,7 +485,7 @@ impl Synapse { false } - /// Check if request failed + /// Check if request failed (status code != 200 or no status) pub fn is_failure(&self) -> bool { if let Some(ref dendrite) = self.dendrite { if let Some(code) = dendrite.status_code { @@ -114,7 +495,7 @@ impl Synapse { true } - /// Check if request timed out + /// Check if request timed out (status code 408) pub fn is_timeout(&self) -> bool { if let Some(ref dendrite) = self.dendrite { if let Some(code) = dendrite.status_code { @@ -124,22 +505,382 @@ impl Synapse { false } - /// Get total size of the synapse + /// Get total size of the synapse (body + header) pub fn get_total_size(&self) -> u64 { self.total_size.unwrap_or(0) + self.header_size.unwrap_or(0) } - /// Set a custom field + /// Set a custom field in the extra map pub fn set_field(&mut self, key: &str, value: serde_json::Value) { self.extra.insert(key.to_string(), value); } - /// Get a custom field + /// Get a custom field from the extra map pub fn get_field(&self, key: &str) -> Option<&serde_json::Value> { self.extra.get(key) } + + /// Compute the body hash for signature verification + /// + /// This method computes a SHA256 hash of specified fields for use in + /// signature verification. The hash matches the Python SDK's body_hash. 
+ /// + /// # Arguments + /// + /// * `fields` - Field names to include in the hash computation + /// + /// # Returns + /// + /// Hex-encoded SHA256 hash string + pub fn compute_body_hash(&self, fields: &[&str]) -> String { + let mut hasher = Sha256::new(); + + // Sort fields for consistent ordering + let mut sorted_fields: Vec<&str> = fields.to_vec(); + sorted_fields.sort(); + + for field in sorted_fields { + if let Some(value) = self.extra.get(field) { + // Serialize the value to JSON for hashing + if let Ok(json_bytes) = serde_json::to_vec(value) { + hasher.update(&json_bytes); + } + } + } + + // Return hex-encoded hash + hex::encode(hasher.finalize()) + } + + /// Compute the body hash using all extra fields + /// + /// Convenience method that hashes all fields in the extra map. + pub fn compute_full_body_hash(&self) -> String { + let fields: Vec<&str> = self.extra.keys().map(|s| s.as_str()).collect(); + self.compute_body_hash(&fields) + } + + /// Verify that the stored body hash matches a computed hash + /// + /// # Arguments + /// + /// * `fields` - Field names to include in the hash computation + /// + /// # Returns + /// + /// `true` if the computed hash matches the stored `computed_body_hash` + pub fn verify_body_hash(&self, fields: &[&str]) -> bool { + if let Some(ref stored_hash) = self.computed_body_hash { + let computed = self.compute_body_hash(fields); + constant_time_compare(stored_hash.as_bytes(), computed.as_bytes()) + } else { + // No stored hash means nothing to verify + false + } + } + + /// Update the stored body hash with a freshly computed value + /// + /// # Arguments + /// + /// * `fields` - Field names to include in the hash computation + pub fn update_body_hash(&mut self, fields: &[&str]) { + self.computed_body_hash = Some(self.compute_body_hash(fields)); + } + + /// Get the timeout as a Duration + pub fn timeout_duration(&self) -> Duration { + Duration::from_secs_f64(self.timeout.unwrap_or(12.0)) + } +} + +/// Constant-time string 
comparison to prevent timing attacks +fn constant_time_compare(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + + let mut result = 0u8; + for (x, y) in a.iter().zip(b.iter()) { + result |= x ^ y; + } + result == 0 +} + +impl SynapseType for Synapse { + fn name() -> &'static str { + "Synapse" + } + + fn timeout(&self) -> Duration { + self.timeout_duration() + } +} + +// ============================================================================= +// Message +// ============================================================================= + +/// A message in a text prompt conversation +/// +/// Represents a single message in a conversation, with a role (user, assistant, system) +/// and content. This matches the common chat API format used by many LLMs. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct Message { + /// The role of the message author (e.g., "user", "assistant", "system") + pub role: String, + /// The content of the message + pub content: String, +} + +impl Message { + /// Create a new message with the given role and content + pub fn new(role: impl Into, content: impl Into) -> Self { + Self { + role: role.into(), + content: content.into(), + } + } + + /// Create a user message + pub fn user(content: impl Into) -> Self { + Self::new("user", content) + } + + /// Create an assistant message + pub fn assistant(content: impl Into) -> Self { + Self::new("assistant", content) + } + + /// Create a system message + pub fn system(content: impl Into) -> Self { + Self::new("system", content) + } +} + +// ============================================================================= +// TextPromptSynapse +// ============================================================================= + +/// Text prompt synapse for chat/completion requests +/// +/// A pre-built synapse type for text-based LLM interactions, supporting +/// multi-turn conversations with role-based messages. 
+/// +/// # Example +/// +/// ``` +/// use bittensor_rs::types::synapse::{TextPromptSynapse, Message}; +/// +/// let synapse = TextPromptSynapse::new(vec![ +/// Message::system("You are a helpful assistant."), +/// Message::user("Hello, how are you?"), +/// ]); +/// +/// assert_eq!(synapse.messages.len(), 2); +/// assert!(synapse.response.is_none()); +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TextPromptSynapse { + /// Base synapse fields + #[serde(flatten)] + pub base: Synapse, + /// The conversation messages + pub messages: Vec, + /// The response from the axon (filled by server) + pub response: Option, +} + +impl TextPromptSynapse { + /// Create a new TextPromptSynapse with the given messages + pub fn new(messages: Vec) -> Self { + let mut base = Synapse::new(); + base.name = Some(Self::name().to_string()); + Self { + base, + messages, + response: None, + } + } + + /// Create a new TextPromptSynapse with a single user message + pub fn from_prompt(prompt: impl Into) -> Self { + Self::new(vec![Message::user(prompt)]) + } + + /// Create a new TextPromptSynapse with a system message and user message + pub fn with_system_prompt( + system_prompt: impl Into, + user_prompt: impl Into, + ) -> Self { + Self::new(vec![ + Message::system(system_prompt), + Message::user(user_prompt), + ]) + } + + /// Add a message to the conversation + pub fn add_message(&mut self, message: Message) { + self.messages.push(message); + } + + /// Add a user message to the conversation + pub fn add_user_message(&mut self, content: impl Into) { + self.messages.push(Message::user(content)); + } + + /// Add an assistant message to the conversation + pub fn add_assistant_message(&mut self, content: impl Into) { + self.messages.push(Message::assistant(content)); + } + + /// Set the response + pub fn set_response(&mut self, response: impl Into) { + self.response = Some(response.into()); + } + + /// Get the response if available + pub fn get_response(&self) -> Option<&str> { + 
self.response.as_deref() + } + + /// Set the timeout + pub fn with_timeout(mut self, timeout: f64) -> Self { + self.base.timeout = Some(timeout); + self + } + + /// Compute the body hash for this synapse + pub fn compute_body_hash(&self) -> String { + let mut hasher = Sha256::new(); + + // Hash the messages array + if let Ok(json_bytes) = serde_json::to_vec(&self.messages) { + hasher.update(&json_bytes); + } + + hex::encode(hasher.finalize()) + } + + /// Update the base synapse body hash + pub fn update_body_hash(&mut self) { + self.base.computed_body_hash = Some(self.compute_body_hash()); + } +} + +impl SynapseType for TextPromptSynapse { + fn name() -> &'static str { + "TextPromptSynapse" + } + + fn timeout(&self) -> Duration { + self.base.timeout_duration() + } + + fn required_hash_fields() -> Vec<&'static str> { + vec!["messages"] + } +} + +// ============================================================================= +// StreamingTextPromptSynapse +// ============================================================================= + +/// Streaming version of TextPromptSynapse for incremental response processing +/// +/// This synapse accumulates text chunks as they arrive, allowing for +/// real-time display of generated text. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamingTextPromptSynapse { + /// Base synapse fields + #[serde(flatten)] + pub base: Synapse, + /// The conversation messages + pub messages: Vec, + /// Accumulated response text + #[serde(default)] + pub accumulated_response: String, + /// Whether the stream is complete + #[serde(skip, default)] + complete: bool, +} + +impl StreamingTextPromptSynapse { + /// Create a new StreamingTextPromptSynapse with the given messages + pub fn new(messages: Vec) -> Self { + let mut base = Synapse::new(); + base.name = Some(Self::name().to_string()); + Self { + base, + messages, + accumulated_response: String::new(), + complete: false, + } + } + + /// Create from a single user prompt + pub fn from_prompt(prompt: impl Into) -> Self { + Self::new(vec![Message::user(prompt)]) + } + + /// Get the current accumulated response + pub fn response(&self) -> &str { + &self.accumulated_response + } + + /// Set the timeout + pub fn with_timeout(mut self, timeout: f64) -> Self { + self.base.timeout = Some(timeout); + self + } +} + +impl SynapseType for StreamingTextPromptSynapse { + fn name() -> &'static str { + "StreamingTextPromptSynapse" + } + + fn timeout(&self) -> Duration { + self.base.timeout_duration() + } + + fn required_hash_fields() -> Vec<&'static str> { + vec!["messages"] + } } +impl StreamingSynapse for StreamingTextPromptSynapse { + type Chunk = String; + + fn process_chunk(&mut self, chunk: &[u8]) -> Option { + // Try to parse the chunk as UTF-8 text + match std::str::from_utf8(chunk) { + Ok(text) => { + if !text.is_empty() { + self.accumulated_response.push_str(text); + Some(text.to_string()) + } else { + None + } + } + Err(_) => None, + } + } + + fn is_complete(&self) -> bool { + self.complete + } + + fn finalize(&mut self) -> Result<(), BittensorError> { + self.complete = true; + Ok(()) + } +} + +// ============================================================================= +// SynapseHeaders +// 
============================================================================= + /// HTTP headers for synapse transmission #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct SynapseHeaders { @@ -218,6 +959,10 @@ impl Synapse { mod tests { use super::*; + // ========================================================================= + // Synapse Tests + // ========================================================================= + #[test] fn test_synapse_creation() { let synapse = Synapse::new().with_name("TestSynapse").with_timeout(30.0); @@ -252,4 +997,420 @@ mod tests { assert_eq!(synapse.get_field("input"), Some(&serde_json::json!(42))); } + + #[test] + fn test_synapse_builder_pattern() { + let dendrite = TerminalInfo::new() + .with_ip("192.168.1.1") + .with_port(8080) + .with_hotkey("5FHneW46xGXgs5mUiveU4sbTyGBzmstUspZC92UhjJM694ty"); + + let synapse = Synapse::new() + .with_name("TestSynapse") + .with_timeout(30.0) + .with_dendrite(dendrite); + + assert_eq!(synapse.name, Some("TestSynapse".to_string())); + assert_eq!(synapse.timeout, Some(30.0)); + assert!(synapse.dendrite.is_some()); + let d = synapse.dendrite.as_ref().unwrap(); + assert_eq!(d.ip, Some("192.168.1.1".to_string())); + assert_eq!(d.port, Some(8080)); + } + + #[test] + fn test_timeout_duration() { + let synapse = Synapse::new().with_timeout(30.5); + let duration = synapse.timeout_duration(); + assert_eq!(duration, Duration::from_secs_f64(30.5)); + + let default_synapse = Synapse::new(); + let default_duration = default_synapse.timeout_duration(); + assert_eq!(default_duration, Duration::from_secs(12)); + } + + // ========================================================================= + // Body Hash Tests + // ========================================================================= + + #[test] + fn test_body_hash_computation() { + let mut synapse = Synapse::new(); + synapse.set_field("input", serde_json::json!("hello")); + synapse.set_field("value", serde_json::json!(42)); + + let 
hash1 = synapse.compute_body_hash(&["input"]); + let hash2 = synapse.compute_body_hash(&["input"]); + + // Same input should produce same hash + assert_eq!(hash1, hash2); + + // Different fields should produce different hash + let hash3 = synapse.compute_body_hash(&["value"]); + assert_ne!(hash1, hash3); + + // Hash should be hex-encoded SHA256 (64 chars) + assert_eq!(hash1.len(), 64); + assert!(hash1.chars().all(|c| c.is_ascii_hexdigit())); + } + + #[test] + fn test_body_hash_field_ordering() { + let mut synapse = Synapse::new(); + synapse.set_field("a", serde_json::json!(1)); + synapse.set_field("b", serde_json::json!(2)); + + // Fields should be sorted, so order shouldn't matter + let hash1 = synapse.compute_body_hash(&["a", "b"]); + let hash2 = synapse.compute_body_hash(&["b", "a"]); + assert_eq!(hash1, hash2); + } + + #[test] + fn test_body_hash_verification() { + let mut synapse = Synapse::new(); + synapse.set_field("query", serde_json::json!("test query")); + + // Store the hash + synapse.update_body_hash(&["query"]); + + // Verification should pass + assert!(synapse.verify_body_hash(&["query"])); + + // Modify the field + synapse.set_field("query", serde_json::json!("modified query")); + + // Verification should fail after modification + assert!(!synapse.verify_body_hash(&["query"])); + } + + #[test] + fn test_body_hash_empty_fields() { + let synapse = Synapse::new(); + + // Empty fields should still produce a valid hash + let hash = synapse.compute_body_hash(&[]); + assert_eq!(hash.len(), 64); + + // Non-existent fields should be ignored + let hash2 = synapse.compute_body_hash(&["nonexistent"]); + assert_eq!(hash, hash2); + } + + #[test] + fn test_verify_body_hash_no_stored_hash() { + let synapse = Synapse::new(); + + // No stored hash should return false + assert!(!synapse.verify_body_hash(&[])); + } + + // ========================================================================= + // Message Tests + // 
========================================================================= + + #[test] + fn test_message_creation() { + let msg = Message::new("user", "Hello, world!"); + assert_eq!(msg.role, "user"); + assert_eq!(msg.content, "Hello, world!"); + } + + #[test] + fn test_message_convenience_constructors() { + let user_msg = Message::user("User message"); + assert_eq!(user_msg.role, "user"); + + let assistant_msg = Message::assistant("Assistant message"); + assert_eq!(assistant_msg.role, "assistant"); + + let system_msg = Message::system("System message"); + assert_eq!(system_msg.role, "system"); + } + + #[test] + fn test_message_serialization() { + let msg = Message::user("Hello"); + let json = serde_json::to_string(&msg).expect("Failed to serialize"); + let deserialized: Message = serde_json::from_str(&json).expect("Failed to deserialize"); + assert_eq!(msg, deserialized); + } + + // ========================================================================= + // TextPromptSynapse Tests + // ========================================================================= + + #[test] + fn test_text_prompt_synapse_creation() { + let synapse = TextPromptSynapse::new(vec![ + Message::system("You are helpful."), + Message::user("Hello!"), + ]); + + assert_eq!(synapse.messages.len(), 2); + assert_eq!(synapse.messages[0].role, "system"); + assert_eq!(synapse.messages[1].role, "user"); + assert!(synapse.response.is_none()); + assert_eq!(synapse.base.name, Some("TextPromptSynapse".to_string())); + } + + #[test] + fn test_text_prompt_synapse_from_prompt() { + let synapse = TextPromptSynapse::from_prompt("What is 2+2?"); + + assert_eq!(synapse.messages.len(), 1); + assert_eq!(synapse.messages[0].role, "user"); + assert_eq!(synapse.messages[0].content, "What is 2+2?"); + } + + #[test] + fn test_text_prompt_synapse_with_system_prompt() { + let synapse = TextPromptSynapse::with_system_prompt( + "You are a math tutor.", + "What is 2+2?", + ); + + assert_eq!(synapse.messages.len(), 2); + 
assert_eq!(synapse.messages[0].role, "system"); + assert_eq!(synapse.messages[0].content, "You are a math tutor."); + assert_eq!(synapse.messages[1].role, "user"); + } + + #[test] + fn test_text_prompt_synapse_add_messages() { + let mut synapse = TextPromptSynapse::new(vec![]); + + synapse.add_user_message("Hello"); + synapse.add_assistant_message("Hi there!"); + synapse.add_message(Message::new("custom", "Custom role")); + + assert_eq!(synapse.messages.len(), 3); + assert_eq!(synapse.messages[0].role, "user"); + assert_eq!(synapse.messages[1].role, "assistant"); + assert_eq!(synapse.messages[2].role, "custom"); + } + + #[test] + fn test_text_prompt_synapse_response() { + let mut synapse = TextPromptSynapse::from_prompt("Hello"); + + assert!(synapse.get_response().is_none()); + + synapse.set_response("Hello to you too!"); + + assert_eq!(synapse.get_response(), Some("Hello to you too!")); + } + + #[test] + fn test_text_prompt_synapse_body_hash() { + let mut synapse = TextPromptSynapse::new(vec![ + Message::user("Hello"), + ]); + + let hash = synapse.compute_body_hash(); + + // Hash should be valid hex + assert_eq!(hash.len(), 64); + assert!(hash.chars().all(|c| c.is_ascii_hexdigit())); + + // Same messages should produce same hash + let synapse2 = TextPromptSynapse::new(vec![ + Message::user("Hello"), + ]); + assert_eq!(hash, synapse2.compute_body_hash()); + + // Different messages should produce different hash + let synapse3 = TextPromptSynapse::new(vec![ + Message::user("Goodbye"), + ]); + assert_ne!(hash, synapse3.compute_body_hash()); + + // Update body hash should store it + synapse.update_body_hash(); + assert!(synapse.base.computed_body_hash.is_some()); + } + + #[test] + fn test_text_prompt_synapse_serialization() { + let synapse = TextPromptSynapse::new(vec![ + Message::system("Be helpful"), + Message::user("Hello"), + ]).with_timeout(30.0); + + let json = serde_json::to_string(&synapse).expect("Failed to serialize"); + let deserialized: TextPromptSynapse = 
+ serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(synapse.messages.len(), deserialized.messages.len()); + assert_eq!(synapse.base.timeout, deserialized.base.timeout); + } + + #[test] + fn test_text_prompt_synapse_type_trait() { + assert_eq!(TextPromptSynapse::name(), "TextPromptSynapse"); + assert_eq!( + TextPromptSynapse::required_hash_fields(), + vec!["messages"] + ); + + let synapse = TextPromptSynapse::from_prompt("test").with_timeout(45.0); + assert_eq!(synapse.timeout(), Duration::from_secs_f64(45.0)); + } + + // ========================================================================= + // StreamingTextPromptSynapse Tests + // ========================================================================= + + #[test] + fn test_streaming_text_prompt_synapse_creation() { + let synapse = StreamingTextPromptSynapse::new(vec![ + Message::user("Generate text"), + ]); + + assert_eq!(synapse.messages.len(), 1); + assert!(synapse.accumulated_response.is_empty()); + assert!(!synapse.is_complete()); + } + + #[test] + fn test_streaming_text_prompt_synapse_process_chunk() { + let mut synapse = StreamingTextPromptSynapse::from_prompt("test"); + + // Process a chunk + let chunk = synapse.process_chunk(b"Hello "); + assert_eq!(chunk, Some("Hello ".to_string())); + assert_eq!(synapse.response(), "Hello "); + + // Process another chunk + let chunk = synapse.process_chunk(b"World!"); + assert_eq!(chunk, Some("World!".to_string())); + assert_eq!(synapse.response(), "Hello World!"); + } + + #[test] + fn test_streaming_text_prompt_synapse_empty_chunk() { + let mut synapse = StreamingTextPromptSynapse::from_prompt("test"); + + // Empty chunk should return None + let chunk = synapse.process_chunk(b""); + assert!(chunk.is_none()); + } + + #[test] + fn test_streaming_text_prompt_synapse_invalid_utf8() { + let mut synapse = StreamingTextPromptSynapse::from_prompt("test"); + + // Invalid UTF-8 should return None + let chunk = synapse.process_chunk(&[0xFF, 0xFE]); 
+ assert!(chunk.is_none()); + } + + #[test] + fn test_streaming_text_prompt_synapse_finalize() { + let mut synapse = StreamingTextPromptSynapse::from_prompt("test"); + + assert!(!synapse.is_complete()); + + synapse.finalize().expect("Finalize failed"); + + assert!(synapse.is_complete()); + } + + #[test] + fn test_streaming_text_prompt_synapse_type_trait() { + assert_eq!(StreamingTextPromptSynapse::name(), "StreamingTextPromptSynapse"); + assert_eq!( + StreamingTextPromptSynapse::required_hash_fields(), + vec!["messages"] + ); + } + + // ========================================================================= + // TerminalInfo Tests + // ========================================================================= + + #[test] + fn test_terminal_info_builder() { + let info = TerminalInfo::new() + .with_ip("10.0.0.1") + .with_port(9090) + .with_version(123) + .with_nonce(456) + .with_uuid("test-uuid") + .with_hotkey("5FHneW46xGXgs5mUiveU4sbTyGBzmstUspZC92UhjJM694ty") + .with_signature("0x123abc") + .with_process_time(0.5); + + assert_eq!(info.ip, Some("10.0.0.1".to_string())); + assert_eq!(info.port, Some(9090)); + assert_eq!(info.version, Some(123)); + assert_eq!(info.nonce, Some(456)); + assert_eq!(info.uuid, Some("test-uuid".to_string())); + assert_eq!( + info.hotkey, + Some("5FHneW46xGXgs5mUiveU4sbTyGBzmstUspZC92UhjJM694ty".to_string()) + ); + assert_eq!(info.signature, Some("0x123abc".to_string())); + assert_eq!(info.process_time, Some(0.5)); + } + + // ========================================================================= + // Header Constants Tests + // ========================================================================= + + #[test] + fn test_header_constants() { + // Verify header names match expected format + assert!(headers::BT_HEADER_DENDRITE_IP.starts_with("bt_header_")); + assert!(headers::BT_HEADER_AXON_IP.starts_with("bt_header_")); + + // Verify helper functions return correct headers + let dendrite_headers = headers::dendrite_headers(); + 
assert!(dendrite_headers.contains(&headers::BT_HEADER_DENDRITE_IP)); + assert!(dendrite_headers.contains(&headers::BT_HEADER_DENDRITE_SIGNATURE)); + + let axon_headers = headers::axon_headers(); + assert!(axon_headers.contains(&headers::BT_HEADER_AXON_IP)); + assert!(axon_headers.contains(&headers::BT_HEADER_AXON_PROCESS_TIME)); + + let metadata_headers = headers::metadata_headers(); + assert!(metadata_headers.contains(&headers::COMPUTED_BODY_HASH)); + assert!(metadata_headers.contains(&headers::NAME)); + } + + // ========================================================================= + // SynapseType Trait Tests + // ========================================================================= + + #[test] + fn test_synapse_type_json_roundtrip() { + let synapse = Synapse::new().with_name("Test").with_timeout(30.0); + + // Use SynapseType methods + let json_bytes = synapse.to_json().expect("Serialization failed"); + let deserialized = Synapse::from_json(&json_bytes).expect("Deserialization failed"); + + assert_eq!(synapse.name, deserialized.name); + assert_eq!(synapse.timeout, deserialized.timeout); + } + + #[test] + fn test_synapse_type_from_json_error() { + let invalid_json = b"not valid json {"; + let result = Synapse::from_json(invalid_json); + assert!(result.is_err()); + } + + // ========================================================================= + // Constant Time Compare Tests + // ========================================================================= + + #[test] + fn test_constant_time_compare() { + assert!(constant_time_compare(b"hello", b"hello")); + assert!(!constant_time_compare(b"hello", b"world")); + assert!(!constant_time_compare(b"hello", b"hell")); + assert!(!constant_time_compare(b"", b"a")); + assert!(constant_time_compare(b"", b"")); + } } diff --git a/src/utils/decoders/composite.rs b/src/utils/decoders/composite.rs index 9fa799d..d73d666 100644 --- a/src/utils/decoders/composite.rs +++ b/src/utils/decoders/composite.rs @@ -1,24 +1,68 @@ 
use super::utils; use anyhow::{anyhow, Result}; use std::collections::HashMap; -use subxt::dynamic::Value; +use std::sync::OnceLock; +use subxt::dynamic::{At, Value}; +use subxt::ext::scale_value::{Composite, ValueDef}; + +/// Get a static regex for parsing identity data +fn get_identity_regex() -> &'static regex::Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| { + regex::Regex::new(r#"(\w+):\s*(?:Some\()?"([^"]*)"(?:\))?"#) + .expect("Invalid regex pattern for identity parsing") + }) +} + +/// Extract a u128 primitive from a Value using the proper API +fn extract_u128_from_value(value: &Value) -> Option { + value.as_u128() +} /// Decode PrometheusInfo from Value /// Subtensor PrometheusInfo: { block: u64, version: u32, ip: u128, port: u16, ip_type: u8 } pub fn decode_prometheus_info(value: &Value) -> Result { + // Try using Value's .at() API first for named/unnamed composite access + // PrometheusInfo fields in order: block, version, ip, port, ip_type + if let (Some(block_val), Some(version_val), Some(ip_val), Some(port_val), Some(ip_type_val)) = ( + value.at(0), + value.at(1), + value.at(2), + value.at(3), + value.at(4), + ) { + if let (Some(block), Some(version), Some(ip_u128), Some(port), Some(ip_type)) = ( + extract_u128_from_value(block_val), + extract_u128_from_value(version_val), + extract_u128_from_value(ip_val), + extract_u128_from_value(port_val), + extract_u128_from_value(ip_type_val), + ) { + let ip = utils::parse_ip_addr(ip_u128, ip_type as u8); + return Ok(crate::types::PrometheusInfo::from_chain_data( + block as u64, + version as u32, + ip.to_string(), + port as u16, + ip_type as u8, + )); + } + } + + // Fall back to debug string parsing for compatibility with older formats let s = format!("{:?}", value); // Extract exactly: U64, U32, U128, U16, U8 - let block = - utils::extract_u64(&s, 0).ok_or_else(|| anyhow!("PrometheusInfo: missing block (u64)"))?; + let block = utils::extract_u64(&s, 0) + .ok_or_else(|| 
anyhow!("PrometheusInfo: missing block (u64) in value: {}", s))?; let version = utils::extract_u32(&s, block.1) - .ok_or_else(|| anyhow!("PrometheusInfo: missing version (u32)"))?; + .ok_or_else(|| anyhow!("PrometheusInfo: missing version (u32) in value: {}", s))?; let ip_u128 = utils::extract_u128(&s, version.1) - .ok_or_else(|| anyhow!("PrometheusInfo: missing ip (u128)"))?; + .ok_or_else(|| anyhow!("PrometheusInfo: missing ip (u128) in value: {}", s))?; let port = utils::extract_u16(&s, ip_u128.1) - .ok_or_else(|| anyhow!("PrometheusInfo: missing port (u16)"))?; + .ok_or_else(|| anyhow!("PrometheusInfo: missing port (u16) in value: {}", s))?; let ip_type = utils::extract_u8(&s, port.1) - .ok_or_else(|| anyhow!("PrometheusInfo: missing ip_type (u8)"))?; + .ok_or_else(|| anyhow!("PrometheusInfo: missing ip_type (u8) in value: {}", s))?; let ip = utils::parse_ip_addr(ip_u128.0, ip_type.0); @@ -34,25 +78,80 @@ pub fn decode_prometheus_info(value: &Value) -> Result Result { + // Try using Value's .at() API first for composite access + // AxonInfo fields in order: block, version, ip, port, ip_type, protocol, placeholder1, placeholder2 + if let ( + Some(block_val), + Some(version_val), + Some(ip_val), + Some(port_val), + Some(ip_type_val), + Some(protocol_val), + Some(placeholder1_val), + Some(placeholder2_val), + ) = ( + value.at(0), + value.at(1), + value.at(2), + value.at(3), + value.at(4), + value.at(5), + value.at(6), + value.at(7), + ) { + if let ( + Some(block), + Some(version), + Some(ip_u128), + Some(port), + Some(ip_type), + Some(protocol), + Some(placeholder1), + Some(placeholder2), + ) = ( + extract_u128_from_value(block_val), + extract_u128_from_value(version_val), + extract_u128_from_value(ip_val), + extract_u128_from_value(port_val), + extract_u128_from_value(ip_type_val), + extract_u128_from_value(protocol_val), + extract_u128_from_value(placeholder1_val), + extract_u128_from_value(placeholder2_val), + ) { + let ip = utils::parse_ip_addr(ip_u128, 
ip_type as u8); + return Ok(crate::types::AxonInfo::from_chain_data( + block as u64, + version as u32, + ip, + port as u16, + ip_type as u8, + protocol as u8, + placeholder1 as u8, + placeholder2 as u8, + )); + } + } + + // Fall back to debug string parsing for compatibility let s = format!("{:?}", value); // Extract exactly: U64, U32, U128, U16, U8, U8, U8, U8 - let block = - utils::extract_u64(&s, 0).ok_or_else(|| anyhow!("AxonInfo: missing block (u64)"))?; + let block = utils::extract_u64(&s, 0) + .ok_or_else(|| anyhow!("AxonInfo: missing block (u64) in value: {}", s))?; let version = utils::extract_u32(&s, block.1) - .ok_or_else(|| anyhow!("AxonInfo: missing version (u32)"))?; - let ip_u128 = - utils::extract_u128(&s, version.1).ok_or_else(|| anyhow!("AxonInfo: missing ip (u128)"))?; - let port = - utils::extract_u16(&s, ip_u128.1).ok_or_else(|| anyhow!("AxonInfo: missing port (u16)"))?; - let ip_type = - utils::extract_u8(&s, port.1).ok_or_else(|| anyhow!("AxonInfo: missing ip_type (u8)"))?; + .ok_or_else(|| anyhow!("AxonInfo: missing version (u32) in value: {}", s))?; + let ip_u128 = utils::extract_u128(&s, version.1) + .ok_or_else(|| anyhow!("AxonInfo: missing ip (u128) in value: {}", s))?; + let port = utils::extract_u16(&s, ip_u128.1) + .ok_or_else(|| anyhow!("AxonInfo: missing port (u16) in value: {}", s))?; + let ip_type = utils::extract_u8(&s, port.1) + .ok_or_else(|| anyhow!("AxonInfo: missing ip_type (u8) in value: {}", s))?; let protocol = utils::extract_u8(&s, ip_type.1) - .ok_or_else(|| anyhow!("AxonInfo: missing protocol (u8)"))?; + .ok_or_else(|| anyhow!("AxonInfo: missing protocol (u8) in value: {}", s))?; let placeholder1 = utils::extract_u8(&s, protocol.1) - .ok_or_else(|| anyhow!("AxonInfo: missing placeholder1 (u8)"))?; + .ok_or_else(|| anyhow!("AxonInfo: missing placeholder1 (u8) in value: {}", s))?; let placeholder2 = utils::extract_u8(&s, placeholder1.1) - .ok_or_else(|| anyhow!("AxonInfo: missing placeholder2 (u8)"))?; + .ok_or_else(|| 
anyhow!("AxonInfo: missing placeholder2 (u8) in value: {}", s))?; let ip = utils::parse_ip_addr(ip_u128.0, ip_type.0); @@ -69,14 +168,283 @@ pub fn decode_axon_info(value: &Value) -> Result { } /// Helper to decode identity data from a map structure -/// TODO: Implement actual field extraction -pub fn decode_identity_map(_value: &Value) -> Result> { - Ok(HashMap::new()) +/// Extracts key-value pairs from composite values +pub fn decode_identity_map(value: &Value) -> Result> { + let mut result = HashMap::new(); + + // Try to extract named fields first using the proper Value API + // Check for common identity fields by name + let identity_fields = [ + "name", + "display", + "legal", + "web", + "riot", + "email", + "image", + "twitter", + "pgp_fingerprint", + "additional", + ]; + + for field_name in identity_fields { + if let Some(field_val) = value.at(field_name) { + // Try to get the string value directly + if let Some(s) = field_val.as_str() { + if !s.is_empty() { + result.insert(field_name.to_string(), s.to_string()); + } + } else if let Some(inner) = field_val.at(0) { + // Handle Option wrapped values (Some variant with inner value) + if let Some(s) = inner.as_str() { + if !s.is_empty() { + result.insert(field_name.to_string(), s.to_string()); + } + } + } + } + } + + // If we found fields using the API, return early + if !result.is_empty() { + return Ok(result); + } + + // Fall back to regex parsing for compatibility with debug format + let value_str = format!("{:?}", value); + let re = get_identity_regex(); + + for cap in re.captures_iter(&value_str) { + if let (Some(key), Some(val)) = (cap.get(1), cap.get(2)) { + let key_str = key.as_str().to_string(); + let val_str = val.as_str().to_string(); + // Only insert non-empty values + if !val_str.is_empty() { + result.insert(key_str, val_str); + } + } + } + + Ok(result) } /// Decode a named composite (struct) from a Value -/// Returns empty HashMap if value is not a named composite -/// TODO: Implement actual field 
extraction -pub fn decode_named_composite(_value: &Value) -> Result> { - Ok(HashMap::new()) +/// Extracts field names and values from composite structures +/// For named composites: returns HashMap of field_name -> cloned Value +/// For unnamed composites: returns HashMap of index (as string) -> cloned Value +pub fn decode_named_composite(value: &Value) -> Result> { + let mut result = HashMap::new(); + + // Inspect the ValueDef to determine if it's a composite + match &value.value { + ValueDef::Composite(composite) => match composite { + Composite::Named(fields) => { + // Named composite: extract all field names and values + for (name, val) in fields { + result.insert(name.clone(), val.clone()); + } + } + Composite::Unnamed(values) => { + // Unnamed composite: use index as key + for (idx, val) in values.iter().enumerate() { + result.insert(idx.to_string(), val.clone()); + } + } + }, + ValueDef::Variant(variant) => { + // For variants, extract the inner composite values + match &variant.values { + Composite::Named(fields) => { + for (name, val) in fields { + result.insert(name.clone(), val.clone()); + } + } + Composite::Unnamed(values) => { + for (idx, val) in values.iter().enumerate() { + result.insert(idx.to_string(), val.clone()); + } + } + } + } + _ => { + // Not a composite or variant, return empty map + // This is not an error - the caller may expect this for non-composite values + } + } + + Ok(result) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decode_prometheus_info() { + // Create a test Value representing PrometheusInfo + // Fields: block (u64), version (u32), ip (u128), port (u16), ip_type (u8) + // Using unnamed composite as that's how subxt returns it + let prometheus_value = Value::unnamed_composite([ + Value::u128(12345), // block + Value::u128(1), // version + Value::u128(2130706433), // ip: 127.0.0.1 as u128 + Value::u128(9933), // port + Value::u128(4), // ip_type: IPv4 + ]); + + let result = 
decode_prometheus_info(&prometheus_value); + assert!(result.is_ok(), "Failed to decode PrometheusInfo: {:?}", result); + + let info = result.unwrap(); + assert_eq!(info.block, 12345); + assert_eq!(info.version, 1); + assert_eq!(info.port, 9933); + assert_eq!(info.ip_type, 4); + } + + #[test] + fn test_decode_axon_info() { + // Create a test Value representing AxonInfo + // Fields: block, version, ip, port, ip_type, protocol, placeholder1, placeholder2 + let axon_value = Value::unnamed_composite([ + Value::u128(54321), // block + Value::u128(2), // version + Value::u128(2130706433), // ip: 127.0.0.1 as u128 + Value::u128(8080), // port + Value::u128(4), // ip_type: IPv4 + Value::u128(1), // protocol + Value::u128(0), // placeholder1 + Value::u128(0), // placeholder2 + ]); + + let result = decode_axon_info(&axon_value); + assert!(result.is_ok(), "Failed to decode AxonInfo: {:?}", result); + + let info = result.unwrap(); + assert_eq!(info.block, 54321); + assert_eq!(info.version, 2); + assert_eq!(info.port, 8080); + assert_eq!(info.ip_type, 4); + assert_eq!(info.protocol, 1); + assert_eq!(info.placeholder1, 0); + assert_eq!(info.placeholder2, 0); + } + + #[test] + fn test_decode_identity_map_with_named_fields() { + // Create a test Value with named identity fields + let identity_value = Value::named_composite([ + ("name", Value::string("TestValidator")), + ("web", Value::string("https://example.com")), + ("email", Value::string("test@example.com")), + ]); + + let result = decode_identity_map(&identity_value); + assert!(result.is_ok(), "Failed to decode identity map: {:?}", result); + + let map = result.unwrap(); + assert_eq!(map.get("name"), Some(&"TestValidator".to_string())); + assert_eq!(map.get("web"), Some(&"https://example.com".to_string())); + assert_eq!(map.get("email"), Some(&"test@example.com".to_string())); + } + + #[test] + fn test_decode_identity_map_empty() { + // Test with a primitive value (should return empty map) + let primitive_value = 
Value::u128(42); + + let result = decode_identity_map(&primitive_value); + assert!(result.is_ok()); + + let map = result.unwrap(); + assert!(map.is_empty()); + } + + #[test] + fn test_decode_named_composite_with_named_fields() { + // Create a named composite value + let composite = Value::named_composite([ + ("foo", Value::u128(123)), + ("bar", Value::string("hello")), + ("baz", Value::bool(true)), + ]); + + let result = decode_named_composite(&composite); + assert!(result.is_ok(), "Failed to decode named composite: {:?}", result); + + let map = result.unwrap(); + assert_eq!(map.len(), 3); + assert!(map.contains_key("foo")); + assert!(map.contains_key("bar")); + assert!(map.contains_key("baz")); + + // Verify field values using as_u128, as_str, as_bool + assert_eq!(map.get("foo").and_then(|v| v.as_u128()), Some(123)); + assert_eq!(map.get("bar").and_then(|v| v.as_str()), Some("hello")); + assert_eq!(map.get("baz").and_then(|v| v.as_bool()), Some(true)); + } + + #[test] + fn test_decode_named_composite_with_unnamed_fields() { + // Create an unnamed composite value + let composite = Value::unnamed_composite([ + Value::u128(1), + Value::u128(2), + Value::u128(3), + ]); + + let result = decode_named_composite(&composite); + assert!(result.is_ok()); + + let map = result.unwrap(); + assert_eq!(map.len(), 3); + + // Unnamed composites use index as key + assert!(map.contains_key("0")); + assert!(map.contains_key("1")); + assert!(map.contains_key("2")); + + assert_eq!(map.get("0").and_then(|v| v.as_u128()), Some(1)); + assert_eq!(map.get("1").and_then(|v| v.as_u128()), Some(2)); + assert_eq!(map.get("2").and_then(|v| v.as_u128()), Some(3)); + } + + #[test] + fn test_decode_named_composite_with_primitive() { + // Test with a primitive value (should return empty map) + let primitive = Value::u128(42); + + let result = decode_named_composite(&primitive); + assert!(result.is_ok()); + + let map = result.unwrap(); + assert!(map.is_empty()); + } + + #[test] + fn 
test_decode_named_composite_with_variant() { + // Create a variant value with named fields + let variant = Value::named_variant("SomeVariant", [ + ("field1", Value::u128(100)), + ("field2", Value::string("variant_data")), + ]); + + let result = decode_named_composite(&variant); + assert!(result.is_ok()); + + let map = result.unwrap(); + assert_eq!(map.len(), 2); + assert!(map.contains_key("field1")); + assert!(map.contains_key("field2")); + } + + #[test] + fn test_static_regex_reuse() { + // Call get_identity_regex multiple times to ensure it's properly cached + let re1 = get_identity_regex(); + let re2 = get_identity_regex(); + + // Both should point to the same regex instance + assert!(std::ptr::eq(re1, re2), "Regex should be cached via OnceLock"); + } } diff --git a/src/utils/decoders/mod.rs b/src/utils/decoders/mod.rs index 4705102..bb08256 100644 --- a/src/utils/decoders/mod.rs +++ b/src/utils/decoders/mod.rs @@ -2,10 +2,11 @@ pub mod composite; pub mod fixed; pub mod primitive; -mod utils; +pub mod utils; pub mod vec; pub use composite::*; pub use fixed::*; pub use primitive::*; +pub use utils::*; pub use vec::*; diff --git a/src/utils/decoders/vec.rs b/src/utils/decoders/vec.rs index 9013e76..cb3f9fe 100644 --- a/src/utils/decoders/vec.rs +++ b/src/utils/decoders/vec.rs @@ -1,6 +1,7 @@ use anyhow::Result; use sp_core::crypto::AccountId32; use subxt::dynamic::Value; +use subxt::ext::scale_value::{Composite, ValueDef}; /// Helper to get the length of a numeric tag (U16, U32, U64, U128) fn get_tag_len(s: &str, pos: usize) -> usize { @@ -13,16 +14,44 @@ fn get_tag_len(s: &str, pos: usize) -> usize { /// Decode a Vec from a Value /// Returns empty Vec if value cannot be decoded as a vector -pub fn decode_vec(value: &Value, _decoder: F) -> Result> +pub fn decode_vec(value: &Value, decoder: F) -> Result> where F: Fn(&Value) -> Result, { - // Parse from debug representation - let s = format!("{:?}", value); + let mut results = Vec::new(); + + // Inspect the ValueDef 
to extract values from composite/sequence + match &value.value { + ValueDef::Composite(composite) => { + // Handle both named and unnamed composites + let values: Vec<&Value> = match composite { + Composite::Named(fields) => fields.iter().map(|(_, v)| v).collect(), + Composite::Unnamed(vals) => vals.iter().collect(), + }; + for inner_value in values { + if let Ok(decoded) = decoder(inner_value) { + results.push(decoded); + } + } + } + ValueDef::Variant(variant) => { + // For variants, extract inner composite values + let values: Vec<&Value> = match &variant.values { + Composite::Named(fields) => fields.iter().map(|(_, v)| v).collect(), + Composite::Unnamed(vals) => vals.iter().collect(), + }; + for inner_value in values { + if let Ok(decoded) = decoder(inner_value) { + results.push(decoded); + } + } + } + _ => { + // Not a composite or variant, return empty vec + } + } - // Check if it looks like a vector/sequence - return empty vec regardless - let _ = s.contains("Composite(Unnamed([") || s.contains("Sequence(["); - Ok(Vec::new()) + Ok(results) } /// Decode a vector of u16 from Value diff --git a/src/utils/ss58.rs b/src/utils/ss58.rs index 134da1d..a6ca655 100644 --- a/src/utils/ss58.rs +++ b/src/utils/ss58.rs @@ -1,10 +1,29 @@ use anyhow::Result; use sp_core::crypto::{AccountId32, Ss58AddressFormat, Ss58Codec}; +use sp_core::sr25519; use std::str::FromStr; /// SS58 format constant for Bittensor (42 = "bt") pub const SS58_FORMAT: u16 = 42; +/// Trait for converting types to SS58 address format +pub trait AccountId32ToSS58 { + /// Convert to SS58 address string + fn to_ss58(&self) -> String; +} + +impl AccountId32ToSS58 for AccountId32 { + fn to_ss58(&self) -> String { + encode_ss58(self) + } +} + +impl AccountId32ToSS58 for sr25519::Public { + fn to_ss58(&self) -> String { + self.to_ss58check_with_version(Ss58AddressFormat::custom(SS58_FORMAT)) + } +} + /// Encode AccountId32 to SS58 string pub fn encode_ss58(account: &AccountId32) -> String { 
account.to_ss58check_with_version(Ss58AddressFormat::custom(SS58_FORMAT)) diff --git a/src/utils/weights.rs b/src/utils/weights.rs index 53a1417..8b56dce 100644 --- a/src/utils/weights.rs +++ b/src/utils/weights.rs @@ -47,7 +47,10 @@ pub fn normalize_weights(uids: &[u64], weights: &[f32]) -> Result<(Vec, Vec for (uid, val) in uids.iter().zip(weight_vals.iter()) { if *val > 0 { // Convert uid from u64 to u16 (Subtensor expects Vec) - filtered_uids.push(*uid as u16); + let uid_u16 = u16::try_from(*uid).map_err(|_| { + anyhow::anyhow!("UID {} exceeds u16 max value {}", uid, u16::MAX) + })?; + filtered_uids.push(uid_u16); filtered_vals.push(*val); } } diff --git a/src/validator/mod.rs b/src/validator/mod.rs index 76e3ba7..c965cf2 100644 --- a/src/validator/mod.rs +++ b/src/validator/mod.rs @@ -1,8 +1,10 @@ pub mod children; pub mod liquidity; pub mod mechanism; +pub mod proxy; pub mod registration; pub mod root; +pub mod senate; pub mod serving; pub mod staking; pub mod take; @@ -17,9 +19,11 @@ pub use crate::queries::stakes::get_stake; pub use children::*; pub use liquidity::*; pub use mechanism::*; +pub use proxy::*; pub use registration::{is_registered, register}; pub use root::*; pub use serving::{serve_axon, serve_axon_tls}; pub use take::*; pub use transfer::{transfer, transfer_stake}; pub use utility::*; +pub use senate::*; diff --git a/src/validator/proxy.rs b/src/validator/proxy.rs new file mode 100644 index 0000000..0f228b1 --- /dev/null +++ b/src/validator/proxy.rs @@ -0,0 +1,802 @@ +//! Proxy account operations for Bittensor +//! 
Allows delegating permissions to other accounts + +use crate::chain::{BittensorClient, BittensorSigner, ExtrinsicWait}; +use crate::errors::{BittensorError, BittensorResult, ChainQueryError, ExtrinsicError}; +use crate::utils::decoders::{decode_account_id32, decode_u128}; +use parity_scale_codec::{Decode, Encode}; +use sp_core::crypto::AccountId32; +use subxt::dynamic::Value; + +const PROXY_MODULE: &str = "Proxy"; + +/// Proxy types for Bittensor +#[derive(Debug, Clone, Copy, PartialEq, Eq, Encode, Decode, Default)] +#[repr(u8)] +pub enum ProxyType { + /// Full permissions + #[default] + Any = 0, + /// Non-transfer related permissions + NonTransfer = 1, + /// Governance related permissions + Governance = 2, + /// Staking related permissions + Staking = 3, + /// Registration related permissions + Registration = 4, + /// Transfer related permissions (like SudoUncheckedSetBalance) + Transfer = 5, + /// Subnet owner specific permissions + Owner = 6, + /// Non-critical validator permissions + NonCritical = 7, + /// Triumvirate/Senate permissions + Triumvirate = 8, + /// Subnet-related permissions + Subnet = 9, + /// Childkey permissions + Childkey = 10, + /// Senate permissions + Senate = 11, +} + +impl ProxyType { + /// Convert proxy type to Value for extrinsic submission + fn to_value(self) -> Value { + let variant_name = match self { + ProxyType::Any => "Any", + ProxyType::NonTransfer => "NonTransfer", + ProxyType::Governance => "Governance", + ProxyType::Staking => "Staking", + ProxyType::Registration => "Registration", + ProxyType::Transfer => "Transfer", + ProxyType::Owner => "Owner", + ProxyType::NonCritical => "NonCritical", + ProxyType::Triumvirate => "Triumvirate", + ProxyType::Subnet => "Subnet", + ProxyType::Childkey => "Childkey", + ProxyType::Senate => "Senate", + }; + Value::named_variant(variant_name, Vec::<(&str, Value)>::new()) + } + + /// Try to parse proxy type from a string representation + fn from_str(s: &str) -> Option { + match s { + "Any" => 
Some(ProxyType::Any), + "NonTransfer" => Some(ProxyType::NonTransfer), + "Governance" => Some(ProxyType::Governance), + "Staking" => Some(ProxyType::Staking), + "Registration" => Some(ProxyType::Registration), + "Transfer" => Some(ProxyType::Transfer), + "Owner" => Some(ProxyType::Owner), + "NonCritical" => Some(ProxyType::NonCritical), + "Triumvirate" => Some(ProxyType::Triumvirate), + "Subnet" => Some(ProxyType::Subnet), + "Childkey" => Some(ProxyType::Childkey), + "Senate" => Some(ProxyType::Senate), + _ => None, + } + } +} + +/// Proxy account information +#[derive(Debug, Clone)] +pub struct ProxyInfo { + /// The delegate account that has been granted proxy permissions + pub delegate: AccountId32, + /// The type of proxy permissions granted + pub proxy_type: ProxyType, + /// Delay in blocks before the proxy can execute calls + pub delay: u32, +} + +// ============================================================================= +// Proxy Management +// ============================================================================= + +/// Add a proxy account +/// +/// Allows the signer to grant proxy permissions to the delegate account. +/// The delegate can then execute calls on behalf of the signer. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `signer` - The account granting proxy permissions +/// * `delegate` - The account receiving proxy permissions +/// * `proxy_type` - The type of proxy permissions to grant +/// * `delay` - Delay in blocks before the proxy can execute calls +/// * `wait_for` - How long to wait for the extrinsic +/// +/// # Returns +/// The transaction hash on success +pub async fn add_proxy( + client: &BittensorClient, + signer: &BittensorSigner, + delegate: &AccountId32, + proxy_type: ProxyType, + delay: u32, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let args = vec![ + Value::from_bytes(delegate.encode()), + proxy_type.to_value(), + Value::u128(delay as u128), + ]; + + client + .submit_extrinsic(PROXY_MODULE, "add_proxy", args, signer, wait_for) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to add proxy: {}", e), + PROXY_MODULE, + "add_proxy", + )) + }) +} + +/// Remove a proxy account +/// +/// Revokes proxy permissions from the delegate account. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `signer` - The account revoking proxy permissions +/// * `delegate` - The account losing proxy permissions +/// * `proxy_type` - The type of proxy permissions to revoke +/// * `delay` - The delay that was set when adding the proxy +/// * `wait_for` - How long to wait for the extrinsic +/// +/// # Returns +/// The transaction hash on success +pub async fn remove_proxy( + client: &BittensorClient, + signer: &BittensorSigner, + delegate: &AccountId32, + proxy_type: ProxyType, + delay: u32, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let args = vec![ + Value::from_bytes(delegate.encode()), + proxy_type.to_value(), + Value::u128(delay as u128), + ]; + + client + .submit_extrinsic(PROXY_MODULE, "remove_proxy", args, signer, wait_for) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to remove proxy: {}", e), + PROXY_MODULE, + "remove_proxy", + )) + }) +} + +/// Remove all proxies +/// +/// Revokes all proxy permissions granted by the signer. +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `signer` - The account revoking all proxy permissions +/// * `wait_for` - How long to wait for the extrinsic +/// +/// # Returns +/// The transaction hash on success +pub async fn remove_proxies( + client: &BittensorClient, + signer: &BittensorSigner, + wait_for: ExtrinsicWait, +) -> BittensorResult { + client + .submit_extrinsic(PROXY_MODULE, "remove_proxies", Vec::new(), signer, wait_for) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to remove all proxies: {}", e), + PROXY_MODULE, + "remove_proxies", + )) + }) +} + +/// Execute call as proxy +/// +/// Allows a proxy account to execute a call on behalf of the real account. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `signer` - The proxy account executing the call +/// * `real` - The account on whose behalf the call is being made +/// * `force_proxy_type` - Optional: force a specific proxy type check +/// * `call` - The encoded call data to execute +/// * `wait_for` - How long to wait for the extrinsic +/// +/// # Returns +/// The transaction hash on success +pub async fn proxy( + client: &BittensorClient, + signer: &BittensorSigner, + real: &AccountId32, + force_proxy_type: Option, + call: Vec, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let force_proxy_type_value = match force_proxy_type { + Some(pt) => Value::named_variant("Some", [("value", pt.to_value())]), + None => Value::named_variant("None", Vec::<(&str, Value)>::new()), + }; + + let args = vec![ + Value::from_bytes(real.encode()), + force_proxy_type_value, + Value::from_bytes(&call), + ]; + + client + .submit_extrinsic(PROXY_MODULE, "proxy", args, signer, wait_for) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to execute proxy call: {}", e), + PROXY_MODULE, + "proxy", + )) + }) +} + +/// Create a pure (anonymous) proxy +/// +/// Creates a new account that can only be controlled by the spawner via proxy calls. +/// This is useful for creating accounts that cannot directly sign transactions. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `signer` - The spawner account that will control the pure proxy +/// * `proxy_type` - The type of proxy permissions for the pure proxy +/// * `delay` - Delay in blocks before the proxy can execute calls +/// * `index` - A disambiguation index to allow creating multiple pure proxies with the same parameters +/// * `wait_for` - How long to wait for the extrinsic +/// +/// # Returns +/// The transaction hash on success +pub async fn create_pure( + client: &BittensorClient, + signer: &BittensorSigner, + proxy_type: ProxyType, + delay: u32, + index: u16, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let args = vec![ + proxy_type.to_value(), + Value::u128(delay as u128), + Value::u128(index as u128), + ]; + + client + .submit_extrinsic(PROXY_MODULE, "create_pure", args, signer, wait_for) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to create pure proxy: {}", e), + PROXY_MODULE, + "create_pure", + )) + }) +} + +/// Kill a pure proxy +/// +/// Removes a pure proxy account that was created by the spawner. +/// This can only be called by a proxy of the pure account. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `signer` - The proxy account that controls the pure proxy +/// * `spawner` - The account that originally created the pure proxy +/// * `proxy_type` - The proxy type used when creating the pure proxy +/// * `index` - The disambiguation index used when creating the pure proxy +/// * `height` - The block height at which the pure proxy was created +/// * `ext_index` - The extrinsic index in that block +/// * `wait_for` - How long to wait for the extrinsic +/// +/// # Returns +/// The transaction hash on success +#[allow(clippy::too_many_arguments)] +pub async fn kill_pure( + client: &BittensorClient, + signer: &BittensorSigner, + spawner: &AccountId32, + proxy_type: ProxyType, + index: u16, + height: u32, + ext_index: u32, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let args = vec![ + Value::from_bytes(spawner.encode()), + proxy_type.to_value(), + Value::u128(index as u128), + Value::u128(height as u128), + Value::u128(ext_index as u128), + ]; + + client + .submit_extrinsic(PROXY_MODULE, "kill_pure", args, signer, wait_for) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to kill pure proxy: {}", e), + PROXY_MODULE, + "kill_pure", + )) + }) +} + +// ============================================================================= +// Proxy Queries +// ============================================================================= + +/// Get all proxies for an account +/// +/// Returns a list of all proxy definitions for the given account. 
+/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `account` - The account to query proxies for +/// +/// # Returns +/// A list of ProxyInfo containing delegate, proxy type, and delay +pub async fn get_proxies( + client: &BittensorClient, + account: &AccountId32, +) -> BittensorResult> { + let result = client + .storage_with_keys( + PROXY_MODULE, + "Proxies", + vec![Value::from_bytes(account.encode())], + ) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query proxies: {}", e), + PROXY_MODULE, + "Proxies", + )) + })?; + + match result { + Some(value) => parse_proxies_storage(&value), + None => Ok(Vec::new()), + } +} + +/// Check if an account is a proxy for another +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `real` - The account that may have granted proxy permissions +/// * `delegate` - The account to check if it has proxy permissions +/// * `proxy_type` - Optional: check for a specific proxy type +/// +/// # Returns +/// True if the delegate is a proxy for the real account +pub async fn is_proxy( + client: &BittensorClient, + real: &AccountId32, + delegate: &AccountId32, + proxy_type: Option, +) -> BittensorResult { + let proxies = get_proxies(client, real).await?; + + for proxy_info in proxies { + if proxy_info.delegate == *delegate { + match proxy_type { + Some(pt) => { + if proxy_info.proxy_type == pt { + return Ok(true); + } + } + None => return Ok(true), + } + } + } + + Ok(false) +} + +/// Parse the Proxies storage value into a list of ProxyInfo +/// +/// The Proxies storage returns a tuple of (BoundedVec, Balance) +/// We parse this by examining the debug representation of the Value +fn parse_proxies_storage(value: &Value) -> BittensorResult> { + let mut proxies = Vec::new(); + let value_str = format!("{:?}", value); + + // Parse proxy definitions from the storage value + // Each ProxyDefinition has format: { delegate: AccountId, proxy_type: ProxyType, delay: 
u32 } + // The storage format is a tuple: ((proxy_list), deposit) + + // Extract account IDs from the value string + let account_ids = extract_account_ids_from_debug(&value_str); + + // Extract proxy types from the value string + let proxy_types = extract_proxy_types_from_debug(&value_str); + + // Extract delays from the value string + let delays = extract_delays_from_debug(&value_str); + + // Match up the extracted values into ProxyInfo structs + // The first account ID in the tuple is usually followed by proxy type and delay + let num_proxies = account_ids + .len() + .min(proxy_types.len()) + .min(delays.len()); + + for i in 0..num_proxies { + proxies.push(ProxyInfo { + delegate: account_ids[i].clone(), + proxy_type: proxy_types[i], + delay: delays[i], + }); + } + + Ok(proxies) +} + +/// Extract AccountId32 values from a debug string representation +fn extract_account_ids_from_debug(s: &str) -> Vec { + let mut accounts = Vec::new(); + + // Look for 0x-prefixed hex strings that are 64 characters (32 bytes) + let mut search_start = 0; + while let Some(start) = s[search_start..].find("0x") { + let abs_start = search_start + start; + let hex_start = abs_start + 2; + + if hex_start < s.len() { + let hex_chars: String = s[hex_start..] 
+ .chars() + .take_while(|c| c.is_ascii_hexdigit()) + .take(64) + .collect(); + + if hex_chars.len() == 64 { + if let Ok(bytes) = hex::decode(&hex_chars) { + if bytes.len() == 32 { + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + accounts.push(AccountId32::from(arr)); + } + } + } + + search_start = hex_start + hex_chars.len(); + } else { + break; + } + } + + accounts +} + +/// Extract ProxyType values from a debug string representation +fn extract_proxy_types_from_debug(s: &str) -> Vec { + let mut types = Vec::new(); + + // Look for variant names in the debug output + let variant_names = [ + "Any", + "NonTransfer", + "Governance", + "Staking", + "Registration", + "Transfer", + "Owner", + "NonCritical", + "Triumvirate", + "Subnet", + "Childkey", + "Senate", + ]; + + // Find all variant patterns like Variant("TypeName" or name: "TypeName" + for variant_name in variant_names { + let mut search_start = 0; + while search_start < s.len() { + // Check for Variant("TypeName" pattern + let pattern1 = format!("Variant(\"{}", variant_name); + let pattern2 = format!("\"{}\"", variant_name); + + if let Some(pos) = s[search_start..].find(&pattern1) { + if let Some(pt) = ProxyType::from_str(variant_name) { + types.push(pt); + } + search_start = search_start + pos + pattern1.len(); + } else if let Some(pos) = s[search_start..].find(&pattern2) { + // Only count if it looks like a variant context + let context_start = pos.saturating_sub(10); + let context = &s[search_start + context_start..search_start + pos + pattern2.len()]; + if context.contains("Variant") || context.contains("name:") { + if let Some(pt) = ProxyType::from_str(variant_name) { + types.push(pt); + } + } + search_start = search_start + pos + pattern2.len(); + } else { + break; + } + } + } + + // If no variants found, try numeric parsing + if types.is_empty() { + // Look for U8/U16/U32 patterns that could be proxy type indices + let mut search_start = 0; + while let Some(pos) = s[search_start..].find("U8(") 
{ + let abs_pos = search_start + pos; + let num_start = abs_pos + 3; + if let Some(end) = s[num_start..].find(')') { + if let Ok(num) = s[num_start..num_start + end].trim().parse::() { + if num <= 11 { + if let Some(pt) = proxy_type_from_u8(num) { + types.push(pt); + } + } + } + } + search_start = num_start; + } + } + + types +} + +/// Extract delay values (u32) from a debug string representation +fn extract_delays_from_debug(s: &str) -> Vec { + let mut delays = Vec::new(); + + // Look for U32 patterns + let mut search_start = 0; + while let Some(pos) = s[search_start..].find("U32(") { + let abs_pos = search_start + pos; + let num_start = abs_pos + 4; + if let Some(end) = s[num_start..].find(')') { + if let Ok(num) = s[num_start..num_start + end].trim().parse::() { + delays.push(num); + } + } + search_start = num_start; + } + + // Also try U64 in case delays are stored as larger integers + if delays.is_empty() { + search_start = 0; + while let Some(pos) = s[search_start..].find("U64(") { + let abs_pos = search_start + pos; + let num_start = abs_pos + 4; + if let Some(end) = s[num_start..].find(')') { + if let Ok(num) = s[num_start..num_start + end].trim().parse::() { + delays.push(num as u32); + } + } + search_start = num_start; + } + } + + delays +} + +/// Convert a u8 value to ProxyType +fn proxy_type_from_u8(value: u8) -> Option { + match value { + 0 => Some(ProxyType::Any), + 1 => Some(ProxyType::NonTransfer), + 2 => Some(ProxyType::Governance), + 3 => Some(ProxyType::Staking), + 4 => Some(ProxyType::Registration), + 5 => Some(ProxyType::Transfer), + 6 => Some(ProxyType::Owner), + 7 => Some(ProxyType::NonCritical), + 8 => Some(ProxyType::Triumvirate), + 9 => Some(ProxyType::Subnet), + 10 => Some(ProxyType::Childkey), + 11 => Some(ProxyType::Senate), + _ => None, + } +} + +/// Parse a ProxyType from a Value using debug string parsing +#[allow(dead_code)] +fn parse_proxy_type(value: &Value) -> Option { + let s = format!("{:?}", value); + + // Try variant pattern 
first + for (variant_name, pt) in [ + ("Any", ProxyType::Any), + ("NonTransfer", ProxyType::NonTransfer), + ("Governance", ProxyType::Governance), + ("Staking", ProxyType::Staking), + ("Registration", ProxyType::Registration), + ("Transfer", ProxyType::Transfer), + ("Owner", ProxyType::Owner), + ("NonCritical", ProxyType::NonCritical), + ("Triumvirate", ProxyType::Triumvirate), + ("Subnet", ProxyType::Subnet), + ("Childkey", ProxyType::Childkey), + ("Senate", ProxyType::Senate), + ] { + if s.contains(&format!("\"{}\"", variant_name)) { + return Some(pt); + } + } + + // Try numeric parsing + if let Ok(num) = decode_u128(value) { + return proxy_type_from_u8(num as u8); + } + + None +} + +/// Parse a u32 from a Value +#[allow(dead_code)] +fn parse_u32(value: &Value) -> Option { + decode_u128(value).ok().map(|v| v as u32) +} + +/// Parse an AccountId32 from a Value +#[allow(dead_code)] +fn parse_account_id(value: &Value) -> Option { + decode_account_id32(value).ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_proxy_type_encoding() { + assert_eq!(ProxyType::Any as u8, 0); + assert_eq!(ProxyType::Staking as u8, 3); + assert_eq!(ProxyType::Senate as u8, 11); + } + + #[test] + fn test_proxy_type_default() { + assert_eq!(ProxyType::default(), ProxyType::Any); + } + + #[test] + fn test_proxy_type_from_str() { + assert_eq!(ProxyType::from_str("Any"), Some(ProxyType::Any)); + assert_eq!(ProxyType::from_str("Staking"), Some(ProxyType::Staking)); + assert_eq!(ProxyType::from_str("Senate"), Some(ProxyType::Senate)); + assert_eq!(ProxyType::from_str("Invalid"), None); + assert_eq!(ProxyType::from_str(""), None); + } + + #[test] + fn test_proxy_type_to_value() { + let value = ProxyType::Any.to_value(); + let debug_str = format!("{:?}", value); + assert!(debug_str.contains("Any")); + + let value = ProxyType::Staking.to_value(); + let debug_str = format!("{:?}", value); + assert!(debug_str.contains("Staking")); + } + + #[test] + fn test_proxy_info_debug() { 
+ let account_bytes = [1u8; 32]; + let info = ProxyInfo { + delegate: AccountId32::from(account_bytes), + proxy_type: ProxyType::Staking, + delay: 100, + }; + + let debug_str = format!("{:?}", info); + assert!(debug_str.contains("Staking")); + assert!(debug_str.contains("100")); + } + + #[test] + fn test_proxy_info_clone() { + let account_bytes = [2u8; 32]; + let info = ProxyInfo { + delegate: AccountId32::from(account_bytes), + proxy_type: ProxyType::Governance, + delay: 50, + }; + + let cloned = info.clone(); + assert_eq!(cloned.delegate, info.delegate); + assert_eq!(cloned.proxy_type, info.proxy_type); + assert_eq!(cloned.delay, info.delay); + } + + #[test] + fn test_proxy_type_from_u8() { + assert_eq!(proxy_type_from_u8(0), Some(ProxyType::Any)); + assert_eq!(proxy_type_from_u8(3), Some(ProxyType::Staking)); + assert_eq!(proxy_type_from_u8(11), Some(ProxyType::Senate)); + assert_eq!(proxy_type_from_u8(255), None); + } + + #[test] + fn test_extract_account_ids_from_debug() { + // Test with a known hex AccountId + let hex_account = "0x0101010101010101010101010101010101010101010101010101010101010101"; + let test_str = format!("Some data {} more data", hex_account); + let accounts = extract_account_ids_from_debug(&test_str); + assert_eq!(accounts.len(), 1); + assert_eq!(accounts[0], AccountId32::from([1u8; 32])); + } + + #[test] + fn test_extract_delays_from_debug() { + let test_str = "Composite { values: [U32(100), U32(200)] }"; + let delays = extract_delays_from_debug(test_str); + assert_eq!(delays.len(), 2); + assert_eq!(delays[0], 100); + assert_eq!(delays[1], 200); + } + + #[test] + fn test_proxy_type_equality() { + assert_eq!(ProxyType::Any, ProxyType::Any); + assert_ne!(ProxyType::Any, ProxyType::Staking); + assert_eq!(ProxyType::Senate, ProxyType::Senate); + } + + #[test] + fn test_all_proxy_types_have_from_str() { + let types = [ + ("Any", ProxyType::Any), + ("NonTransfer", ProxyType::NonTransfer), + ("Governance", ProxyType::Governance), + ("Staking", 
ProxyType::Staking), + ("Registration", ProxyType::Registration), + ("Transfer", ProxyType::Transfer), + ("Owner", ProxyType::Owner), + ("NonCritical", ProxyType::NonCritical), + ("Triumvirate", ProxyType::Triumvirate), + ("Subnet", ProxyType::Subnet), + ("Childkey", ProxyType::Childkey), + ("Senate", ProxyType::Senate), + ]; + + for (name, expected) in types { + assert_eq!( + ProxyType::from_str(name), + Some(expected), + "Failed for {}", + name + ); + } + } + + #[test] + fn test_all_proxy_types_u8_values() { + assert_eq!(ProxyType::Any as u8, 0); + assert_eq!(ProxyType::NonTransfer as u8, 1); + assert_eq!(ProxyType::Governance as u8, 2); + assert_eq!(ProxyType::Staking as u8, 3); + assert_eq!(ProxyType::Registration as u8, 4); + assert_eq!(ProxyType::Transfer as u8, 5); + assert_eq!(ProxyType::Owner as u8, 6); + assert_eq!(ProxyType::NonCritical as u8, 7); + assert_eq!(ProxyType::Triumvirate as u8, 8); + assert_eq!(ProxyType::Subnet as u8, 9); + assert_eq!(ProxyType::Childkey as u8, 10); + assert_eq!(ProxyType::Senate as u8, 11); + } +} diff --git a/src/validator/senate.rs b/src/validator/senate.rs new file mode 100644 index 0000000..f45cd47 --- /dev/null +++ b/src/validator/senate.rs @@ -0,0 +1,938 @@ +//! Senate and governance operations for Bittensor +//! 
Implements senate registration, voting, and membership management + +use crate::chain::{BittensorClient, BittensorSigner, ExtrinsicWait}; +use crate::errors::{BittensorError, BittensorResult, ChainQueryError, ExtrinsicError}; +use crate::utils::decoders::decode_vec_account_id32; + +use sp_core::crypto::AccountId32; +use subxt::dynamic::Value; + +const SUBTENSOR_MODULE: &str = "SubtensorModule"; +const SENATE_MODULE: &str = "SenateMembers"; +const TRIUMVIRATE_MODULE: &str = "Triumvirate"; + +// ============================================================================= +// Proposal Data Structures +// ============================================================================= + +/// Proposal data structure for governance proposals +#[derive(Debug, Clone)] +pub struct Proposal { + /// The hash of the proposal + pub hash: [u8; 32], + /// The index of the proposal + pub index: u32, + /// The account that proposed this (None if triumvirate prime couldn't be determined) + pub proposer: Option, + /// The encoded call data for the proposal + pub call_data: Vec, + /// The vote threshold required to pass + pub threshold: u32, + /// List of accounts that voted in favor + pub ayes: Vec, + /// List of accounts that voted against + pub nays: Vec, + /// The block number at which voting ends + pub end: u64, +} + +/// Vote data for a specific proposal +#[derive(Debug, Clone)] +pub struct VoteData { + /// The proposal index + pub index: u32, + /// The vote threshold required to pass + pub threshold: u32, + /// List of accounts that voted in favor + pub ayes: Vec, + /// List of accounts that voted against + pub nays: Vec, + /// The block number at which voting ends + pub end: u64, +} + +// ============================================================================= +// Senate Registration +// ============================================================================= + +/// Register as a senate member +/// Requires being a delegate with sufficient stake +pub async fn 
register_senate( + client: &BittensorClient, + signer: &BittensorSigner, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let args: Vec = vec![]; + + let tx_hash = client + .submit_extrinsic( + SUBTENSOR_MODULE, + "join_senate", + args, + signer, + wait_for, + ) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to register as senate member: {}", e), + SUBTENSOR_MODULE, + "join_senate", + )) + })?; + + Ok(tx_hash) +} + +/// Leave the senate +pub async fn leave_senate( + client: &BittensorClient, + signer: &BittensorSigner, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let args: Vec = vec![]; + + let tx_hash = client + .submit_extrinsic( + SUBTENSOR_MODULE, + "leave_senate", + args, + signer, + wait_for, + ) + .await + .map_err(|e| { + BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to leave senate: {}", e), + SUBTENSOR_MODULE, + "leave_senate", + )) + })?; + + Ok(tx_hash) +} + +// ============================================================================= +// Voting +// ============================================================================= + +/// Vote on a proposal +/// +/// # Arguments +/// * `client` - The Bittensor client +/// * `signer` - The signer (must be a senate member) +/// * `proposal_hash` - The 32-byte hash of the proposal +/// * `proposal_index` - The index of the proposal +/// * `approve` - Whether to vote in favor (true) or against (false) +/// * `wait_for` - How long to wait for the transaction +pub async fn vote( + client: &BittensorClient, + signer: &BittensorSigner, + proposal_hash: &[u8; 32], + proposal_index: u32, + approve: bool, + wait_for: ExtrinsicWait, +) -> BittensorResult { + let args = vec![ + Value::from_bytes(proposal_hash), + Value::u128(proposal_index as u128), + Value::bool(approve), + ]; + + let tx_hash = client + .submit_extrinsic( + SUBTENSOR_MODULE, + "vote", + args, + signer, + wait_for, + ) + .await + .map_err(|e| { + 
BittensorError::Extrinsic(ExtrinsicError::with_call( + format!("Failed to vote on proposal: {}", e), + SUBTENSOR_MODULE, + "vote", + )) + })?; + + Ok(tx_hash) +} + +// ============================================================================= +// Senate Queries +// ============================================================================= + +/// Check if an account is a senate member +pub async fn is_senate_member( + client: &BittensorClient, + hotkey: &AccountId32, +) -> BittensorResult { + // Query SenateMembers.Members storage + let members_val = client + .storage(SENATE_MODULE, "Members", None) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query senate members: {}", e), + SENATE_MODULE, + "Members", + )) + })?; + + match members_val { + Some(val) => { + let members = decode_vec_account_id32(&val).map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::new(format!( + "Failed to decode senate members: {}", + e + ))) + })?; + Ok(members.contains(hotkey)) + } + None => Ok(false), + } +} + +/// Get all senate members +pub async fn get_senate_members( + client: &BittensorClient, +) -> BittensorResult> { + // Query SenateMembers.Members storage + let members_val = client + .storage(SENATE_MODULE, "Members", None) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query senate members: {}", e), + SENATE_MODULE, + "Members", + )) + })?; + + match members_val { + Some(val) => { + let members = decode_vec_account_id32(&val).map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::new(format!( + "Failed to decode senate members: {}", + e + ))) + })?; + Ok(members) + } + None => Ok(Vec::new()), + } +} + +/// Get proposal data for a specific proposal hash +pub async fn get_proposal( + client: &BittensorClient, + proposal_hash: &[u8; 32], +) -> BittensorResult> { + // Get vote data first + let vote_data = get_vote_data(client, 
proposal_hash).await?; + + // Query proposal call data from Triumvirate.ProposalOf + let proposal_of_val = client + .storage_with_keys( + TRIUMVIRATE_MODULE, + "ProposalOf", + vec![Value::from_bytes(proposal_hash)], + ) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query proposal data: {}", e), + TRIUMVIRATE_MODULE, + "ProposalOf", + )) + })?; + + // If no vote data, the proposal doesn't exist + let vote_info = match vote_data { + Some(v) => v, + None => return Ok(None), + }; + + // Extract call data from the proposal + let call_data = match &proposal_of_val { + Some(val) => extract_call_data(val), + None => Vec::new(), + }; + + // Get the proposer - returns None if we can't determine it, rather than masking with zeroed account + let proposer = match get_triumvirate_prime(client).await { + Ok(prime) => Some(prime), + Err(e) => { + tracing::warn!("Failed to get triumvirate prime for proposal {:?}: {}", proposal_hash, e); + None + } + }; + + Ok(Some(Proposal { + hash: *proposal_hash, + index: vote_info.index, + proposer, + call_data, + threshold: vote_info.threshold, + ayes: vote_info.ayes, + nays: vote_info.nays, + end: vote_info.end, + })) +} + +/// Get all active proposals +pub async fn get_proposals( + client: &BittensorClient, +) -> BittensorResult> { + // Query Triumvirate.Proposals to get list of proposal hashes + let proposals_val = client + .storage(TRIUMVIRATE_MODULE, "Proposals", None) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query proposals list: {}", e), + TRIUMVIRATE_MODULE, + "Proposals", + )) + })?; + + let proposal_hashes = match proposals_val { + Some(val) => extract_proposal_hashes(&val), + None => return Ok(Vec::new()), + }; + + let mut proposals = Vec::with_capacity(proposal_hashes.len()); + + for hash in proposal_hashes { + if let Some(proposal) = get_proposal(client, &hash).await? 
{ + proposals.push(proposal); + } + } + + Ok(proposals) +} + +/// Get vote data for a proposal +pub async fn get_vote_data( + client: &BittensorClient, + proposal_hash: &[u8; 32], +) -> BittensorResult> { + // Query Triumvirate.Voting storage + let voting_val = client + .storage_with_keys( + TRIUMVIRATE_MODULE, + "Voting", + vec![Value::from_bytes(proposal_hash)], + ) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query vote data: {}", e), + TRIUMVIRATE_MODULE, + "Voting", + )) + })?; + + match voting_val { + Some(val) => { + let s = format!("{:?}", val); + + let index = extract_first_u64_after_key(&s, "index") + .map(|v| v as u32) + .unwrap_or(0); + let threshold = extract_first_u64_after_key(&s, "threshold") + .map(|v| v as u32) + .unwrap_or(0); + let end = extract_first_u64_after_key(&s, "end").unwrap_or(0); + let ayes = extract_accounts_array_after_key(&s, "ayes"); + let nays = extract_accounts_array_after_key(&s, "nays"); + + Ok(Some(VoteData { + index, + threshold, + ayes, + nays, + end, + })) + } + None => Ok(None), + } +} + +/// Get the number of proposals +pub async fn get_proposal_count( + client: &BittensorClient, +) -> BittensorResult { + // Query Triumvirate.ProposalCount storage + let count_val = client + .storage(TRIUMVIRATE_MODULE, "ProposalCount", None) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query proposal count: {}", e), + TRIUMVIRATE_MODULE, + "ProposalCount", + )) + })?; + + match count_val { + Some(val) => { + let s = format!("{:?}", val); + // Try to extract U32 or U64 value + if let Some(count) = extract_u32_from_value(&s) { + Ok(count) + } else { + Ok(0) + } + } + None => Ok(0), + } +} + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Get the Triumvirate Prime (lead 
proposer) +async fn get_triumvirate_prime(client: &BittensorClient) -> BittensorResult { + let prime_val = client + .storage(TRIUMVIRATE_MODULE, "Prime", None) + .await + .map_err(|e| { + BittensorError::ChainQuery(ChainQueryError::with_storage( + format!("Failed to query triumvirate prime: {}", e), + TRIUMVIRATE_MODULE, + "Prime", + )) + })?; + + match prime_val { + Some(val) => { + let s = format!("{:?}", val); + // Extract AccountId32 from the value + if let Some(hx_pos) = s.find("0x") { + let hex_str: String = s[hx_pos + 2..] + .chars() + .take_while(|c| c.is_ascii_hexdigit()) + .collect(); + if hex_str.len() >= 64 { + if let Ok(bytes) = hex::decode(&hex_str[..64]) { + if bytes.len() == 32 { + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + return Ok(AccountId32::from(arr)); + } + } + } + } + Err(BittensorError::ChainQuery(ChainQueryError::new( + "Failed to decode triumvirate prime", + ))) + } + None => Err(BittensorError::ChainQuery(ChainQueryError::new( + "Triumvirate prime not set", + ))), + } +} + +/// Extract a u32 value from a debug string. +/// +/// # Note on Debug String Parsing +/// +/// This function parses Rust's Debug format output. The Debug format is not stable +/// and may change between versions of subxt. This approach is used because the +/// Value API doesn't provide direct typed access for all storage values. 
+fn extract_u32_from_value(s: &str) -> Option { + // Try U32( first + if let Some(pos) = s.find("U32(") { + let aft = &s[pos + 4..]; + if let Some(end) = aft.find(')') { + return aft[..end].trim().parse::().ok(); + } + } + // Try U64( next + if let Some(pos) = s.find("U64(") { + let aft = &s[pos + 4..]; + if let Some(end) = aft.find(')') { + return aft[..end].trim().parse::().ok().map(|v| v as u32); + } + } + // Try U128( last + if let Some(pos) = s.find("U128(") { + let aft = &s[pos + 5..]; + if let Some(end) = aft.find(')') { + return aft[..end].trim().parse::().ok().map(|v| v as u32); + } + } + None +} + +/// Extract the first u64 value after a key in a debug string. +/// +/// # Note on Debug String Parsing +/// +/// This function parses Rust's Debug format output. The Debug format is not stable +/// and may change between versions of subxt. Supports U32, U64, and U128 patterns, +/// returning the value as u64 (with potential truncation for U128 values). +fn extract_first_u64_after_key(s: &str, key: &str) -> Option { + if let Some(kp) = s.find(key) { + let subs = &s[kp..]; + + // Find which pattern appears first after the key + let u64_pos = subs.find("U64("); + let u32_pos = subs.find("U32("); + let u128_pos = subs.find("U128("); + + // Collect all found patterns with their positions + let mut candidates: Vec<(usize, &str, usize)> = vec![]; + if let Some(p) = u64_pos { + candidates.push((p, "U64(", 4)); + } + if let Some(p) = u32_pos { + candidates.push((p, "U32(", 4)); + } + if let Some(p) = u128_pos { + candidates.push((p, "U128(", 5)); + } + + // Sort by position to find the first one + candidates.sort_by_key(|c| c.0); + + if let Some((pos, _pattern, skip)) = candidates.first() { + let aft = &subs[pos + skip..]; + if let Some(end) = aft.find(')') { + // Parse as u128 to handle all cases, then convert to u64 + return aft[..end].trim().parse::().ok().map(|v| v as u64); + } + } + } + None +} + +/// Extract AccountId32 array after a key in a debug string. 
+/// +/// # Note on Debug String Parsing +/// +/// This function parses Rust's Debug format output (via `format!("{:?}", value)`). +/// This is inherently fragile as Debug format is not guaranteed to be stable. +/// However, it's necessary because the Value API from subxt doesn't provide +/// direct access to nested composite fields. The function tries to be defensive +/// by properly tracking bracket nesting depth to correctly identify array boundaries. +/// +/// If the Debug format changes in future versions of subxt, this parser may need updates. +fn extract_accounts_array_after_key(s: &str, key: &str) -> Vec { + let mut accounts = Vec::new(); + if let Some(kp) = s.find(key) { + let subs = &s[kp..]; + // Find the opening bracket of the array + let array_start = match subs.find('[') { + Some(p) => p, + None => return accounts, + }; + let array_content = &subs[array_start..]; + + // Find the matching closing bracket by tracking nesting depth + let mut depth = 0; + let mut end_pos = 0; + for (i, c) in array_content.char_indices() { + match c { + '[' => depth += 1, + ']' => { + depth -= 1; + if depth == 0 { + end_pos = i; + break; + } + } + _ => {} + } + } + + // If we didn't find a matching bracket, use the rest of the string + let bounded = if end_pos > 0 { + &array_content[..=end_pos] + } else { + array_content + }; + + // Now extract accounts only within this bounded array + let mut rem = bounded; + while let Some(pos) = rem.find("0x") { + let hexstr: String = rem[pos + 2..] 
+ .chars() + .take_while(|c| c.is_ascii_hexdigit()) + .collect(); + if hexstr.len() >= 64 { + if let Ok(bytes) = hex::decode(&hexstr[..64]) { + if bytes.len() == 32 { + if let Ok(arr) = <[u8; 32]>::try_from(bytes.as_slice()) { + accounts.push(AccountId32::from(arr)); + } + } + } + } + // Move past this hex string + let advance = pos + 2 + hexstr.len(); + if advance >= rem.len() { + break; + } + rem = &rem[advance..]; + } + } + accounts +} + +/// Extract proposal hashes from the Proposals storage value +fn extract_proposal_hashes(val: &Value) -> Vec<[u8; 32]> { + let s = format!("{:?}", val); + let mut hashes = Vec::new(); + let mut rem = s.as_str(); + + while let Some(pos) = rem.find("0x") { + let hex_str: String = rem[pos + 2..] + .chars() + .take_while(|c| c.is_ascii_hexdigit()) + .collect(); + if hex_str.len() >= 64 { + if let Ok(bytes) = hex::decode(&hex_str[..64]) { + if bytes.len() == 32 { + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + hashes.push(arr); + } + } + } + rem = &rem[pos + 2 + hex_str.len()..]; + } + + hashes +} + +/// Extract call data from a proposal value +fn extract_call_data(val: &Value) -> Vec { + let s = format!("{:?}", val); + // Call data is typically stored as bytes, look for 0x prefixed hex + if let Some(pos) = s.find("0x") { + let hex_str: String = s[pos + 2..] 
+ .chars() + .take_while(|c| c.is_ascii_hexdigit()) + .collect(); + if !hex_str.is_empty() { + if let Ok(bytes) = hex::decode(&hex_str) { + return bytes; + } + } + } + Vec::new() +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_vote_data_default() { + let vote_data = VoteData { + index: 0, + threshold: 2, + ayes: vec![], + nays: vec![], + end: 1000, + }; + assert_eq!(vote_data.index, 0); + assert_eq!(vote_data.threshold, 2); + assert!(vote_data.ayes.is_empty()); + assert!(vote_data.nays.is_empty()); + assert_eq!(vote_data.end, 1000); + } + + #[test] + fn test_proposal_creation() { + let hash = [1u8; 32]; + let proposer = AccountId32::from([2u8; 32]); + + let proposal = Proposal { + hash, + index: 5, + proposer: Some(proposer.clone()), + call_data: vec![1, 2, 3, 4], + threshold: 3, + ayes: vec![proposer.clone()], + nays: vec![], + end: 2000, + }; + + assert_eq!(proposal.hash, hash); + assert_eq!(proposal.index, 5); + assert_eq!(proposal.threshold, 3); + assert_eq!(proposal.ayes.len(), 1); + assert!(proposal.nays.is_empty()); + assert_eq!(proposal.call_data, vec![1, 2, 3, 4]); + assert!(proposal.proposer.is_some()); + } + + #[test] + fn test_proposal_with_unknown_proposer() { + let hash = [1u8; 32]; + + let proposal = Proposal { + hash, + index: 5, + proposer: None, // Unknown proposer + call_data: vec![1, 2, 3, 4], + threshold: 3, + ayes: vec![], + nays: vec![], + end: 2000, + }; + + assert!(proposal.proposer.is_none()); + } + + #[test] + fn test_extract_u32_from_value() { + assert_eq!(extract_u32_from_value("U32(42)"), Some(42)); + assert_eq!(extract_u32_from_value("U64(100)"), Some(100)); + assert_eq!(extract_u32_from_value("U128(256)"), Some(256)); + assert_eq!(extract_u32_from_value("nothing here"), None); + } + + #[test] + fn test_extract_first_u64_after_key() { + let 
s = "{ index: U64(10), threshold: U32(5), end: U64(1000) }"; + assert_eq!(extract_first_u64_after_key(s, "index"), Some(10)); + assert_eq!(extract_first_u64_after_key(s, "threshold"), Some(5)); + assert_eq!(extract_first_u64_after_key(s, "end"), Some(1000)); + assert_eq!(extract_first_u64_after_key(s, "notfound"), None); + } + + #[test] + fn test_extract_accounts_array_after_key() { + // Create a mock debug string with account addresses + let account1 = [0x01u8; 32]; + let hex1 = hex::encode(account1); + let s = format!("{{ ayes: [0x{}], nays: [] }}", hex1); + + let ayes = extract_accounts_array_after_key(&s, "ayes"); + assert_eq!(ayes.len(), 1); + assert_eq!(ayes[0], AccountId32::from(account1)); + + let nays = extract_accounts_array_after_key(&s, "nays"); + assert!(nays.is_empty()); + } + + #[test] + fn test_extract_proposal_hashes() { + // Simulate a Proposals storage value with multiple hashes + let hash1 = [0xaau8; 32]; + let hash2 = [0xbbu8; 32]; + let hex1 = hex::encode(hash1); + let hex2 = hex::encode(hash2); + + // Create a mock Value debug representation + let mock_val_str = format!("Composite(Unnamed([0x{}, 0x{}]))", hex1, hex2); + let _mock_val = Value::from_bytes(hash1.as_slice()); // Just need a Value for testing + + // Test the extraction logic directly on the string + let mut hashes = Vec::new(); + let mut rem = mock_val_str.as_str(); + + while let Some(pos) = rem.find("0x") { + let hex_str: String = rem[pos + 2..] 
+ .chars() + .take_while(|c| c.is_ascii_hexdigit()) + .collect(); + if hex_str.len() >= 64 { + if let Ok(bytes) = hex::decode(&hex_str[..64]) { + if bytes.len() == 32 { + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + hashes.push(arr); + } + } + } + rem = &rem[pos + 2 + hex_str.len()..]; + } + + assert_eq!(hashes.len(), 2); + assert_eq!(hashes[0], hash1); + assert_eq!(hashes[1], hash2); + } + + #[test] + fn test_extract_call_data() { + // Test extraction logic directly on a debug string that mimics + // how actual chain data looks when formatted with Debug + let call_data = vec![0x01u8, 0x02, 0x03, 0x04]; + let hex_data = hex::encode(&call_data); + + // Simulate how the debug output looks with 0x prefix + let mock_debug_str = format!("Composite(Unnamed([0x{}]))", hex_data); + + // Test the extraction logic directly on the string + let extracted = { + let s = &mock_debug_str; + if let Some(pos) = s.find("0x") { + let hex_str: String = s[pos + 2..] + .chars() + .take_while(|c| c.is_ascii_hexdigit()) + .collect(); + if !hex_str.is_empty() { + hex::decode(&hex_str).unwrap_or_default() + } else { + vec![] + } + } else { + vec![] + } + }; + + // The extracted data should match our original data + assert!(!extracted.is_empty()); + assert_eq!(extracted, call_data); + } + + #[test] + fn test_vote_data_with_members() { + let voter1 = AccountId32::from([1u8; 32]); + let voter2 = AccountId32::from([2u8; 32]); + let voter3 = AccountId32::from([3u8; 32]); + + let vote_data = VoteData { + index: 1, + threshold: 2, + ayes: vec![voter1.clone(), voter2.clone()], + nays: vec![voter3.clone()], + end: 5000, + }; + + assert_eq!(vote_data.ayes.len(), 2); + assert_eq!(vote_data.nays.len(), 1); + assert!(vote_data.ayes.contains(&voter1)); + assert!(vote_data.ayes.contains(&voter2)); + assert!(vote_data.nays.contains(&voter3)); + } + + #[test] + fn test_proposal_with_empty_call_data() { + let hash = [0u8; 32]; + let proposer = AccountId32::from([1u8; 32]); + + let proposal = 
Proposal { + hash, + index: 0, + proposer: Some(proposer), + call_data: vec![], + threshold: 1, + ayes: vec![], + nays: vec![], + end: 0, + }; + + assert!(proposal.call_data.is_empty()); + assert_eq!(proposal.index, 0); + } + + #[test] + fn test_extract_accounts_array_stops_at_boundary() { + // Test that accounts extraction correctly stops at array boundary + let account1 = [0x01u8; 32]; + let account2 = [0x02u8; 32]; + let account3 = [0x03u8; 32]; // This should NOT be extracted - it's in a different array + let hex1 = hex::encode(account1); + let hex2 = hex::encode(account2); + let hex3 = hex::encode(account3); + + // Simulate debug output with two arrays: ayes and nays + let s = format!( + "{{ ayes: [0x{}, 0x{}], nays: [0x{}] }}", + hex1, hex2, hex3 + ); + + let ayes = extract_accounts_array_after_key(&s, "ayes"); + assert_eq!(ayes.len(), 2, "Should extract exactly 2 accounts from ayes array"); + assert_eq!(ayes[0], AccountId32::from(account1)); + assert_eq!(ayes[1], AccountId32::from(account2)); + + // Verify ayes doesn't contain the nays account + assert!(!ayes.contains(&AccountId32::from(account3)), "ayes should not contain nays account"); + + let nays = extract_accounts_array_after_key(&s, "nays"); + assert_eq!(nays.len(), 1, "Should extract exactly 1 account from nays array"); + assert_eq!(nays[0], AccountId32::from(account3)); + } + + #[test] + fn test_extract_accounts_array_with_nested_brackets() { + // Test with nested structures to ensure bracket tracking works + let account1 = [0xaau8; 32]; + let hex1 = hex::encode(account1); + + // Simulate complex nested structure + let s = format!( + "{{ data: [Composite([0x{}])], other: [] }}", + hex1 + ); + + let accounts = extract_accounts_array_after_key(&s, "data"); + assert_eq!(accounts.len(), 1, "Should extract account from nested structure"); + assert_eq!(accounts[0], AccountId32::from(account1)); + + let other = extract_accounts_array_after_key(&s, "other"); + assert!(other.is_empty(), "other array should 
be empty"); + } + + #[test] + fn test_extract_first_u64_after_key_various_formats() { + // Test U64 format + let s1 = "{ value: U64(12345) }"; + assert_eq!(extract_first_u64_after_key(s1, "value"), Some(12345)); + + // Test U32 format + let s2 = "{ count: U32(42) }"; + assert_eq!(extract_first_u64_after_key(s2, "count"), Some(42)); + + // Test U128 format + let s3 = "{ big: U128(999999999999) }"; + assert_eq!(extract_first_u64_after_key(s3, "big"), Some(999999999999)); + + // Test mixed - should pick the first one after the key + let s4 = "{ first: U32(10), second: U64(20) }"; + assert_eq!(extract_first_u64_after_key(s4, "first"), Some(10)); + assert_eq!(extract_first_u64_after_key(s4, "second"), Some(20)); + + // Test when key is not found + let s5 = "{ something: U64(100) }"; + assert_eq!(extract_first_u64_after_key(s5, "missing"), None); + + // Test with whitespace + let s6 = "{ spaced: U64( 555 ) }"; + assert_eq!(extract_first_u64_after_key(s6, "spaced"), Some(555)); + } + + #[test] + fn test_extract_accounts_empty_array() { + let s = "{ ayes: [], nays: [] }"; + + let ayes = extract_accounts_array_after_key(s, "ayes"); + assert!(ayes.is_empty()); + + let nays = extract_accounts_array_after_key(s, "nays"); + assert!(nays.is_empty()); + } + + #[test] + fn test_extract_accounts_no_array() { + // Test when key exists but no array follows + let s = "{ ayes: 123 }"; + + let accounts = extract_accounts_array_after_key(s, "ayes"); + assert!(accounts.is_empty()); + } +} diff --git a/src/validator/weights.rs b/src/validator/weights.rs index c421d23..53e2a24 100644 --- a/src/validator/weights.rs +++ b/src/validator/weights.rs @@ -111,7 +111,14 @@ pub async fn reveal_weights( } // Convert uids from u64 to u16 (Subtensor expects Vec) - let uid_u16: Vec = uids.iter().map(|uid| *uid as u16).collect(); + let uid_u16: Vec = uids + .iter() + .map(|uid| { + u16::try_from(*uid).map_err(|_| { + anyhow::anyhow!("UID {} exceeds u16 max value {}", uid, u16::MAX) + }) + }) + 
.collect::>>()?; let uid_values: Vec = uid_u16 .iter() diff --git a/src/wallet/keyfile.rs b/src/wallet/keyfile.rs new file mode 100644 index 0000000..8973658 --- /dev/null +++ b/src/wallet/keyfile.rs @@ -0,0 +1,736 @@ +//! Keyfile encryption and storage for Bittensor wallets. +//! +//! This module provides functionality to securely store keypairs on disk, +//! compatible with the Python Bittensor SDK keyfile format. +//! +//! ## Keyfile Format +//! +//! The keyfile format uses JSON with the following structure: +//! ```json +//! { +//! "crypto": { +//! "cipher": "secretbox", +//! "ciphertext": "", +//! "cipherparams": {"nonce": ""}, +//! "kdf": "argon2id", +//! "kdfparams": { +//! "salt": "", +//! "n": 65536, +//! "r": 1, +//! "p": 4 +//! } +//! }, +//! "version": 4 +//! } +//! ``` + +use crate::wallet::keypair::{Keypair, KeypairError}; +use argon2::{Argon2, Params, Version}; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; +use crypto_secretbox::{ + aead::{Aead, KeyInit}, + XSalsa20Poly1305, +}; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::io::{Read, Write}; +use std::path::{Path, PathBuf}; +use thiserror::Error; +use zeroize::Zeroize; + +/// Current keyfile format version +pub const KEYFILE_VERSION: u32 = 4; + +/// Default Argon2 parameters matching Python SDK +const ARGON2_TIME_COST: u32 = 1; +const ARGON2_MEMORY_COST: u32 = 65536; // 64 MiB +const ARGON2_PARALLELISM: u32 = 4; + +/// Errors that can occur during keyfile operations. 
+#[derive(Debug, Error)] +pub enum KeyfileError { + #[error("Keyfile not found: {0}")] + NotFound(PathBuf), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON parsing error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Invalid keyfile format: {0}")] + InvalidFormat(String), + + #[error("Decryption failed: wrong password or corrupted keyfile")] + DecryptionFailed, + + #[error("Encryption failed: {0}")] + EncryptionFailed(String), + + #[error("Key derivation failed: {0}")] + KeyDerivationFailed(String), + + #[error("Keyfile already exists and overwrite is not enabled")] + AlreadyExists, + + #[error("Keypair error: {0}")] + Keypair(#[from] KeypairError), + + #[error("Base64 decode error: {0}")] + Base64(#[from] base64::DecodeError), + + #[error("Unsupported keyfile version: {0}")] + UnsupportedVersion(u32), + + #[error("Keyfile is not encrypted")] + NotEncrypted, + + #[error("Password required for encrypted keyfile")] + PasswordRequired, + + #[error("Legacy format detected: {0}")] + LegacyFormat(String), +} + +/// Encryption parameters for a keyfile. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KdfParams { + pub salt: String, + #[serde(rename = "n")] + pub memory_cost: u32, + #[serde(rename = "r")] + pub time_cost: u32, + #[serde(rename = "p")] + pub parallelism: u32, +} + +/// Cipher parameters for a keyfile. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CipherParams { + pub nonce: String, +} + +/// Crypto section of the keyfile. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CryptoData { + pub cipher: String, + pub ciphertext: String, + pub cipherparams: CipherParams, + pub kdf: String, + pub kdfparams: KdfParams, +} + +/// The complete keyfile structure. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KeyfileJson { + pub crypto: CryptoData, + pub version: u32, +} + +/// Data structure for encrypted key material. 
+#[derive(Debug, Clone)] +pub struct KeyfileData { + /// Encrypted key bytes + pub encrypted_key: Vec, + /// 24-byte nonce for XSalsa20Poly1305 + pub nonce: [u8; 24], + /// 16-byte salt for Argon2 + pub salt: [u8; 16], +} + +/// A keyfile represents a keypair stored on disk. +/// +/// The keyfile can be encrypted (password-protected) or unencrypted. +pub struct Keyfile { + path: PathBuf, + keypair: Option, +} + +impl std::fmt::Debug for Keyfile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Keyfile") + .field("path", &self.path) + .field("loaded", &self.keypair.is_some()) + .finish() + } +} + +impl Keyfile { + /// Create a new keyfile handle for the given path. + /// + /// This does not load or create the keyfile on disk. + /// + /// # Arguments + /// * `path` - Path where the keyfile is or will be stored + pub fn new(path: impl AsRef) -> Self { + Self { + path: path.as_ref().to_path_buf(), + keypair: None, + } + } + + /// Get the path to this keyfile. + pub fn path(&self) -> &Path { + &self.path + } + + /// Check if the keyfile exists on disk. + pub fn exists(&self) -> bool { + self.path.exists() + } + + /// Check if the keyfile is encrypted. + /// + /// Returns `false` if the file doesn't exist or can't be read. + pub fn is_encrypted(&self) -> bool { + if !self.exists() { + return false; + } + + match self.read_raw() { + Ok(data) => { + // Try to parse as encrypted JSON format + serde_json::from_slice::(&data).is_ok() + } + Err(_) => false, + } + } + + /// Get the keypair, decrypting if necessary. + /// + /// # Arguments + /// * `password` - Password for decryption (required if encrypted) + /// + /// # Returns + /// The keypair or an error. 
+ pub fn get_keypair(&self, password: Option<&str>) -> Result { + if let Some(ref kp) = self.keypair { + return Ok(kp.clone()); + } + + if !self.exists() { + return Err(KeyfileError::NotFound(self.path.clone())); + } + + let data = self.read_raw()?; + self.decrypt_keypair(&data, password) + } + + /// Store a keypair in this keyfile. + /// + /// # Arguments + /// * `keypair` - The keypair to store + /// * `password` - Optional password for encryption (if None, stores unencrypted) + /// * `overwrite` - Whether to overwrite an existing keyfile + /// + /// # Returns + /// Ok(()) on success, or an error. + pub fn set_keypair( + &mut self, + keypair: Keypair, + password: Option<&str>, + overwrite: bool, + ) -> Result<(), KeyfileError> { + if self.exists() && !overwrite { + return Err(KeyfileError::AlreadyExists); + } + + // Ensure parent directory exists + if let Some(parent) = self.path.parent() { + fs::create_dir_all(parent)?; + } + + let raw_key = keypair.to_bytes(); + + let content = match password { + Some(pass) => { + let keyfile_data = self.encrypt(&raw_key, pass)?; + self.to_json(&keyfile_data)? + } + None => { + // SECURITY WARNING: Storing key without encryption + tracing::warn!( + "Storing keyfile without encryption at {:?}. \ + This is insecure - consider using a password.", + self.path + ); + // Store unencrypted (just the raw key bytes as hex) + // This matches legacy unencrypted format + hex::encode(&raw_key).into_bytes() + } + }; + + // Write atomically by writing to temp file first + // On Unix, set restrictive permissions (0o600) at creation time to avoid + // a race condition where the file is briefly world-readable. + let temp_path = self.path.with_extension("tmp"); + { + #[cfg(unix)] + let mut file = { + use std::os::unix::fs::OpenOptionsExt; + fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(&temp_path)? 
+ }; + #[cfg(not(unix))] + let mut file = fs::File::create(&temp_path)?; + + file.write_all(&content)?; + file.sync_all()?; + } + fs::rename(&temp_path, &self.path)?; + + // On non-Unix platforms, set permissions after rename (best effort) + #[cfg(not(unix))] + { + // No race condition mitigation available on non-Unix + // At least try to restrict permissions after creation + } + + self.keypair = Some(keypair); + Ok(()) + } + + /// Encrypt data using Argon2id + XSalsa20Poly1305. + /// + /// # Arguments + /// * `data` - The data to encrypt + /// * `password` - The encryption password + /// + /// # Returns + /// The encrypted data with salt and nonce. + pub fn encrypt(&self, data: &[u8], password: &str) -> Result { + // Generate random salt and nonce + let mut salt = [0u8; 16]; + let mut nonce = [0u8; 24]; + + use rand::RngCore; + let mut rng = rand::rng(); + rng.fill_bytes(&mut salt); + rng.fill_bytes(&mut nonce); + + // Derive key using Argon2id + let mut key = derive_key(password, &salt)?; + + // Encrypt using XSalsa20Poly1305 + let cipher = XSalsa20Poly1305::new_from_slice(&key) + .map_err(|e| KeyfileError::EncryptionFailed(e.to_string()))?; + + let encrypted_key = cipher + .encrypt(nonce.as_ref().into(), data) + .map_err(|e| KeyfileError::EncryptionFailed(e.to_string()))?; + + // Zeroize the derived key + key.zeroize(); + + Ok(KeyfileData { + encrypted_key, + nonce, + salt, + }) + } + + /// Decrypt data using Argon2id + XSalsa20Poly1305. + /// + /// # Arguments + /// * `data` - The encrypted data with salt and nonce + /// * `password` - The decryption password + /// + /// # Returns + /// The decrypted data. 
+ pub fn decrypt(&self, data: &KeyfileData, password: &str) -> Result, KeyfileError> { + // Derive key using Argon2id + let mut key = derive_key(password, &data.salt)?; + + // Decrypt using XSalsa20Poly1305 + let cipher = XSalsa20Poly1305::new_from_slice(&key) + .map_err(|e| KeyfileError::EncryptionFailed(format!("Failed to create cipher: {}", e)))?; + + let decrypted = cipher + .decrypt(data.nonce.as_ref().into(), data.encrypted_key.as_ref()) + .map_err(|_| KeyfileError::DecryptionFailed)?; + + // Zeroize the derived key + key.zeroize(); + + Ok(decrypted) + } + + /// Re-encrypt the keyfile with a new password or update encryption parameters. + /// + /// # Arguments + /// * `old_password` - Current password (or None if unencrypted) + /// * `new_password` - New password for encryption + /// + /// # Returns + /// Ok(()) on success. + pub fn check_and_update_encryption( + &mut self, + old_password: Option<&str>, + new_password: &str, + ) -> Result<(), KeyfileError> { + let keypair = self.get_keypair(old_password)?; + self.set_keypair(keypair, Some(new_password), true) + } + + /// Read raw bytes from the keyfile. + fn read_raw(&self) -> Result, KeyfileError> { + let mut file = fs::File::open(&self.path)?; + let mut data = Vec::new(); + file.read_to_end(&mut data)?; + Ok(data) + } + + /// Convert KeyfileData to JSON bytes. + fn to_json(&self, data: &KeyfileData) -> Result, KeyfileError> { + let json = KeyfileJson { + crypto: CryptoData { + cipher: "secretbox".to_string(), + ciphertext: BASE64.encode(&data.encrypted_key), + cipherparams: CipherParams { + nonce: BASE64.encode(data.nonce), + }, + kdf: "argon2id".to_string(), + kdfparams: KdfParams { + salt: BASE64.encode(data.salt), + memory_cost: ARGON2_MEMORY_COST, + time_cost: ARGON2_TIME_COST, + parallelism: ARGON2_PARALLELISM, + }, + }, + version: KEYFILE_VERSION, + }; + + serde_json::to_vec_pretty(&json).map_err(KeyfileError::Json) + } + + /// Parse JSON and decrypt to keypair. 
+ fn decrypt_keypair(&self, data: &[u8], password: Option<&str>) -> Result { + // Try to parse as JSON (encrypted format) + if let Ok(json) = serde_json::from_slice::(data) { + return self.decrypt_from_json(&json, password); + } + + // Try as unencrypted hex + if let Ok(hex_str) = std::str::from_utf8(data) { + let hex_str = hex_str.trim(); + if let Ok(key_bytes) = hex::decode(hex_str) { + return Keypair::from_bytes(&key_bytes).map_err(KeyfileError::Keypair); + } + } + + // Try as raw bytes (legacy unencrypted) + if data.len() >= 32 { + if let Ok(keypair) = Keypair::from_bytes(data) { + return Ok(keypair); + } + } + + // Check for legacy formats + if is_legacy_format(data) { + return Err(KeyfileError::LegacyFormat( + "Please migrate this keyfile using migrate_legacy_keyfile()".to_string(), + )); + } + + Err(KeyfileError::InvalidFormat( + "Could not parse keyfile data".to_string(), + )) + } + + /// Decrypt keypair from parsed JSON. + fn decrypt_from_json( + &self, + json: &KeyfileJson, + password: Option<&str>, + ) -> Result { + if json.version > KEYFILE_VERSION { + return Err(KeyfileError::UnsupportedVersion(json.version)); + } + + let password = password.ok_or(KeyfileError::PasswordRequired)?; + + // Decode base64 fields + let ciphertext = BASE64.decode(&json.crypto.ciphertext)?; + let nonce_bytes = BASE64.decode(&json.crypto.cipherparams.nonce)?; + let salt_bytes = BASE64.decode(&json.crypto.kdfparams.salt)?; + + if nonce_bytes.len() != 24 { + return Err(KeyfileError::InvalidFormat(format!( + "Invalid nonce length: expected 24, got {}", + nonce_bytes.len() + ))); + } + + if salt_bytes.len() != 16 { + return Err(KeyfileError::InvalidFormat(format!( + "Invalid salt length: expected 16, got {}", + salt_bytes.len() + ))); + } + + let mut nonce = [0u8; 24]; + let mut salt = [0u8; 16]; + nonce.copy_from_slice(&nonce_bytes); + salt.copy_from_slice(&salt_bytes); + + let keyfile_data = KeyfileData { + encrypted_key: ciphertext, + nonce, + salt, + }; + + let key_bytes = 
self.decrypt(&keyfile_data, password)?; + Keypair::from_bytes(&key_bytes).map_err(KeyfileError::Keypair) + } +} + +/// Derive an encryption key using Argon2id. +fn derive_key(password: &str, salt: &[u8; 16]) -> Result<[u8; 32], KeyfileError> { + let params = Params::new( + ARGON2_MEMORY_COST, + ARGON2_TIME_COST, + ARGON2_PARALLELISM, + Some(32), + ) + .map_err(|e| KeyfileError::KeyDerivationFailed(e.to_string()))?; + + let argon2 = Argon2::new(argon2::Algorithm::Argon2id, Version::V0x13, params); + + let mut key = [0u8; 32]; + argon2 + .hash_password_into(password.as_bytes(), salt, &mut key) + .map_err(|e| KeyfileError::KeyDerivationFailed(e.to_string()))?; + + Ok(key) +} + +/// Check if data is in a legacy (pre-v4) format. +/// +/// # Arguments +/// * `data` - The raw keyfile data +/// +/// # Returns +/// `true` if the data appears to be in a legacy format. +pub fn is_legacy_format(data: &[u8]) -> bool { + // Check for old JSON formats with different structure + if let Ok(value) = serde_json::from_slice::(data) { + // Legacy formats might have different fields + if let Some(obj) = value.as_object() { + // Check for pre-v4 format markers + if obj.contains_key("secretPhrase") { + return true; + } + if obj.contains_key("data") && !obj.contains_key("crypto") { + return true; + } + // Version check + if let Some(version) = obj.get("version") { + if let Some(v) = version.as_u64() { + if v < KEYFILE_VERSION as u64 { + return true; + } + } + } + } + } + false +} + +/// Migrate a legacy keyfile to the current format. +/// +/// # Arguments +/// * `path` - Path to the legacy keyfile +/// * `password` - Password for encryption (may be needed for old encrypted formats) +/// * `new_password` - Password for the new format +/// +/// # Returns +/// Ok(()) on success. 
+pub fn migrate_legacy_keyfile( + path: &Path, + password: Option<&str>, + new_password: &str, +) -> Result<(), KeyfileError> { + let mut data = Vec::new(); + fs::File::open(path)?.read_to_end(&mut data)?; + + if !is_legacy_format(&data) { + return Err(KeyfileError::InvalidFormat( + "Not a legacy format keyfile".to_string(), + )); + } + + // Try to extract keypair from legacy format + let keypair = parse_legacy_keyfile(&data, password)?; + + // Create new keyfile with current format + let mut keyfile = Keyfile::new(path); + keyfile.set_keypair(keypair, Some(new_password), true)?; + + Ok(()) +} + +/// Parse a legacy keyfile to extract the keypair. +fn parse_legacy_keyfile(data: &[u8], password: Option<&str>) -> Result { + if let Ok(value) = serde_json::from_slice::(data) { + if let Some(obj) = value.as_object() { + // Handle secretPhrase format + if let Some(phrase) = obj.get("secretPhrase").and_then(|v| v.as_str()) { + return Keypair::from_mnemonic(phrase, password).map_err(KeyfileError::Keypair); + } + + // Handle old encrypted format with "data" field + if let Some(data_field) = obj.get("data").and_then(|v| v.as_str()) { + let key_bytes = hex::decode(data_field) + .map_err(|e| KeyfileError::InvalidFormat(e.to_string()))?; + return Keypair::from_bytes(&key_bytes).map_err(KeyfileError::Keypair); + } + } + } + + Err(KeyfileError::InvalidFormat( + "Could not parse legacy keyfile".to_string(), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn test_encrypt_decrypt() { + let keyfile = Keyfile::new("/tmp/test"); + let data = b"secret data"; + let password = "test_password"; + + let encrypted = keyfile.encrypt(data, password).unwrap(); + let decrypted = keyfile.decrypt(&encrypted, password).unwrap(); + + assert_eq!(data.as_slice(), decrypted.as_slice()); + } + + #[test] + fn test_encrypt_decrypt_wrong_password() { + let keyfile = Keyfile::new("/tmp/test"); + let data = b"secret data"; + + let encrypted = keyfile.encrypt(data, 
"correct_password").unwrap(); + let result = keyfile.decrypt(&encrypted, "wrong_password"); + + assert!(result.is_err()); + } + + #[test] + fn test_keyfile_roundtrip_encrypted() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_key"); + + let original = Keypair::generate(); + let password = "test_password"; + + { + let mut keyfile = Keyfile::new(&path); + keyfile + .set_keypair(original.clone(), Some(password), false) + .unwrap(); + } + + { + let keyfile = Keyfile::new(&path); + assert!(keyfile.exists()); + assert!(keyfile.is_encrypted()); + + let loaded = keyfile.get_keypair(Some(password)).unwrap(); + assert_eq!(original.public_key(), loaded.public_key()); + } + } + + #[test] + fn test_keyfile_roundtrip_unencrypted() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_key_unenc"); + + let original = Keypair::generate(); + + { + let mut keyfile = Keyfile::new(&path); + keyfile.set_keypair(original.clone(), None, false).unwrap(); + } + + { + let keyfile = Keyfile::new(&path); + assert!(keyfile.exists()); + assert!(!keyfile.is_encrypted()); + + let loaded = keyfile.get_keypair(None).unwrap(); + assert_eq!(original.public_key(), loaded.public_key()); + } + } + + #[test] + fn test_keyfile_no_overwrite() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_key_no_ow"); + + let keypair = Keypair::generate(); + + let mut keyfile = Keyfile::new(&path); + keyfile.set_keypair(keypair.clone(), None, false).unwrap(); + + // Should fail without overwrite + let result = keyfile.set_keypair(keypair.clone(), None, false); + assert!(matches!(result, Err(KeyfileError::AlreadyExists))); + + // Should succeed with overwrite + keyfile.set_keypair(keypair, None, true).unwrap(); + } + + #[test] + fn test_keyfile_not_found() { + let keyfile = Keyfile::new("/nonexistent/path/key"); + let result = keyfile.get_keypair(None); + assert!(matches!(result, Err(KeyfileError::NotFound(_)))); + } + + #[test] + fn test_is_legacy_format() { + // 
Legacy secretPhrase format + let legacy1 = br#"{"secretPhrase": "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"}"#; + assert!(is_legacy_format(legacy1)); + + // Legacy data format + let legacy2 = br#"{"data": "0123456789abcdef", "version": 2}"#; + assert!(is_legacy_format(legacy2)); + + // Current format should not be detected as legacy + let current = br#"{"crypto": {"cipher": "secretbox"}, "version": 4}"#; + assert!(!is_legacy_format(current)); + } + + #[test] + fn test_keyfile_password_required() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_enc_key"); + + let keypair = Keypair::generate(); + let mut keyfile = Keyfile::new(&path); + keyfile + .set_keypair(keypair, Some("password"), false) + .unwrap(); + + // Create a fresh Keyfile instance to avoid cached keypair + // This simulates loading from disk like a real application would + let keyfile2 = Keyfile::new(&path); + + // Should require password when loading from encrypted file + let result = keyfile2.get_keypair(None); + if let Err(ref e) = result { + eprintln!("Got error: {:?}", e); + } + assert!(matches!(result, Err(KeyfileError::PasswordRequired))); + } +} diff --git a/src/wallet/keypair.rs b/src/wallet/keypair.rs new file mode 100644 index 0000000..e5b7956 --- /dev/null +++ b/src/wallet/keypair.rs @@ -0,0 +1,434 @@ +//! Keypair management for Bittensor wallets. +//! +//! This module provides SR25519 keypair functionality for signing and verification, +//! compatible with the Substrate ecosystem and the Python Bittensor SDK. 
+ +// Allow unused_assignments - the ZeroizeOnDrop derive macro generates code that clippy +// incorrectly flags as unused assignments when it reads/writes struct fields for zeroization +#![allow(unused_assignments)] + +use crate::wallet::mnemonic::{Mnemonic, MnemonicError}; +use sp_core::{ + crypto::{Ss58AddressFormat, Ss58Codec}, + sr25519, Pair, +}; +use thiserror::Error; +use zeroize::{Zeroize, ZeroizeOnDrop}; + +/// Bittensor SS58 address format (42 = "bt") +pub const BITTENSOR_SS58_FORMAT: u16 = 42; + +/// Errors that can occur during keypair operations. +#[derive(Debug, Error)] +pub enum KeypairError { + #[error("Invalid seed length: expected 32 bytes, got {0}")] + InvalidSeedLength(usize), + + #[error("Invalid URI: {0}")] + InvalidUri(String), + + #[error("Mnemonic error: {0}")] + Mnemonic(#[from] MnemonicError), + + #[error("Invalid signature length: expected 64 bytes, got {0}")] + InvalidSignatureLength(usize), + + #[error("Signature verification failed")] + VerificationFailed, + + #[error("Key derivation error: {0}")] + DerivationError(String), +} + +/// An SR25519 keypair for signing transactions and messages. +/// +/// This provides full keypair functionality including signing and verification. +/// +/// # Security Note +/// +/// The underlying `sr25519::Pair` type from sp_core does not implement `Zeroize`, +/// meaning the private key material may remain in memory after this struct is dropped. +/// For maximum security in sensitive applications, consider: +/// - Using short-lived Keypair instances +/// - Explicitly dropping Keypairs when no longer needed +/// - Using memory-safe practices at the application level +/// +/// The `public_key` field IS properly zeroized on drop. +#[derive(ZeroizeOnDrop)] +pub struct Keypair { + /// The underlying sr25519 pair. Note: This is NOT zeroized on drop as + /// sp_core::sr25519::Pair does not implement Zeroize. + #[zeroize(skip)] + pair: sr25519::Pair, + /// The 32-byte public key. This field IS zeroized on drop. 
+ public_key: [u8; 32], + /// The SS58-encoded address. Skipped from zeroization as it's derived from public key. + #[zeroize(skip)] + ss58_address: String, +} + +impl Clone for Keypair { + fn clone(&self) -> Self { + Self { + pair: self.pair.clone(), + public_key: self.public_key, + ss58_address: self.ss58_address.clone(), + } + } +} + +impl std::fmt::Debug for Keypair { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Don't expose the private key in debug output + f.debug_struct("Keypair") + .field("ss58_address", &self.ss58_address) + .finish() + } +} + +impl Keypair { + /// Create a keypair from an sr25519 pair. + fn from_pair(pair: sr25519::Pair) -> Self { + let public = pair.public(); + let public_key: [u8; 32] = public.0; + let ss58_address = + public.to_ss58check_with_version(Ss58AddressFormat::custom(BITTENSOR_SS58_FORMAT)); + + Self { + pair, + public_key, + ss58_address, + } + } + + /// Generate a new random keypair. + /// + /// # Returns + /// A new randomly generated keypair. + /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Keypair; + /// let keypair = Keypair::generate(); + /// println!("Address: {}", keypair.ss58_address()); + /// ``` + pub fn generate() -> Self { + let (pair, _) = sr25519::Pair::generate(); + Self::from_pair(pair) + } + + /// Create a keypair from a BIP39 mnemonic phrase. + /// + /// # Arguments + /// * `mnemonic` - A valid BIP39 mnemonic phrase + /// * `password` - Optional password for additional security + /// + /// # Returns + /// The derived keypair or an error. 
+ /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Keypair; + /// let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + /// let keypair = Keypair::from_mnemonic(phrase, None).unwrap(); + /// ``` + pub fn from_mnemonic(mnemonic: &str, password: Option<&str>) -> Result { + let mnemonic_obj = Mnemonic::from_phrase(mnemonic)?; + Self::from_mnemonic_obj(&mnemonic_obj, password) + } + + /// Create a keypair from a Mnemonic object. + /// + /// # Arguments + /// * `mnemonic` - A Mnemonic object + /// * `password` - Optional password for additional security + /// + /// # Returns + /// The derived keypair. + pub fn from_mnemonic_obj( + mnemonic: &Mnemonic, + password: Option<&str>, + ) -> Result { + // Use the mnemonic phrase directly with sp_core's from_phrase + // This matches the Substrate/Polkadot standard derivation + let pass = password.unwrap_or(""); + let (pair, _seed) = sr25519::Pair::from_phrase(mnemonic.phrase(), Some(pass)) + .map_err(|e| KeypairError::DerivationError(format!("{:?}", e)))?; + + Ok(Self::from_pair(pair)) + } + + /// Create a keypair from a 32-byte seed. + /// + /// # Arguments + /// * `seed` - A 32-byte seed + /// + /// # Returns + /// The derived keypair or an error if the seed is invalid. + /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Keypair; + /// let seed = [0u8; 32]; + /// let keypair = Keypair::from_seed(&seed).unwrap(); + /// ``` + pub fn from_seed(seed: &[u8]) -> Result { + if seed.len() != 32 { + return Err(KeypairError::InvalidSeedLength(seed.len())); + } + + let mut seed_arr = [0u8; 32]; + seed_arr.copy_from_slice(seed); + + let pair = sr25519::Pair::from_seed(&seed_arr); + + // Zeroize the seed copy + seed_arr.zeroize(); + + Ok(Self::from_pair(pair)) + } + + /// Create a keypair from a Substrate URI (secret phrase with optional derivation path). 
+ /// + /// # Arguments + /// * `uri` - A secret URI (e.g., "//Alice" or "word word word//derive/path") + /// + /// # Returns + /// The derived keypair or an error. + /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Keypair; + /// let keypair = Keypair::from_uri("//Alice").unwrap(); + /// ``` + pub fn from_uri(uri: &str) -> Result { + let pair = sr25519::Pair::from_string(uri, None) + .map_err(|e| KeypairError::InvalidUri(format!("{:?}", e)))?; + Ok(Self::from_pair(pair)) + } + + /// Get the public key as raw bytes. + /// + /// # Returns + /// A reference to the 32-byte public key. + pub fn public_key(&self) -> &[u8; 32] { + &self.public_key + } + + /// Get the SS58 address with Bittensor format (prefix 42). + /// + /// # Returns + /// The SS58-encoded address string. + pub fn ss58_address(&self) -> &str { + &self.ss58_address + } + + /// Get the underlying sr25519 pair. + /// + /// This can be used for advanced operations or integration with other Substrate libraries. + pub fn pair(&self) -> &sr25519::Pair { + &self.pair + } + + /// Sign a message and return the signature. + /// + /// # Arguments + /// * `message` - The message to sign + /// + /// # Returns + /// A 64-byte signature. + /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Keypair; + /// let keypair = Keypair::generate(); + /// let message = b"Hello, Bittensor!"; + /// let signature = keypair.sign(message); + /// assert!(keypair.verify(message, &signature)); + /// ``` + pub fn sign(&self, message: &[u8]) -> [u8; 64] { + let signature = self.pair.sign(message); + signature.0 + } + + /// Verify a signature against a message using this keypair's public key. + /// + /// # Arguments + /// * `message` - The original message + /// * `signature` - The signature to verify (64 bytes) + /// + /// # Returns + /// `true` if the signature is valid. 
+ pub fn verify(&self, message: &[u8], signature: &[u8]) -> bool { + if signature.len() != 64 { + return false; + } + + let mut sig_arr = [0u8; 64]; + sig_arr.copy_from_slice(signature); + + let sig = sr25519::Signature::from_raw(sig_arr); + sr25519::Pair::verify(&sig, message, &self.pair.public()) + } + + /// Verify a signature against a message using a public key. + /// + /// # Arguments + /// * `message` - The original message + /// * `signature` - The signature to verify (64 bytes) + /// * `public_key` - The public key (32 bytes) + /// + /// # Returns + /// `true` if the signature is valid. + pub fn verify_with_public(message: &[u8], signature: &[u8], public_key: &[u8; 32]) -> bool { + if signature.len() != 64 { + return false; + } + + let mut sig_arr = [0u8; 64]; + sig_arr.copy_from_slice(signature); + + let sig = sr25519::Signature::from_raw(sig_arr); + let public = sr25519::Public::from_raw(*public_key); + + sr25519::Pair::verify(&sig, message, &public) + } + + /// Export the keypair as bytes (raw seed if available, otherwise serialized). + /// + /// WARNING: This exposes the private key. Handle with care. + /// + /// # Returns + /// The raw key bytes suitable for storage. + pub fn to_bytes(&self) -> Vec { + // Export the full keypair in a format that can be restored + // We use the raw seed bytes (64 bytes for sr25519) + self.pair.to_raw_vec() + } + + /// Create a keypair from exported bytes. + /// + /// # Arguments + /// * `bytes` - The raw keypair bytes + /// + /// # Returns + /// The restored keypair or an error. 
+ pub fn from_bytes(bytes: &[u8]) -> Result { + // Try to restore from raw vec + let pair = sr25519::Pair::from_seed_slice(bytes) + .map_err(|e| KeypairError::DerivationError(format!("Failed to restore keypair: {:?}", e)))?; + Ok(Self::from_pair(pair)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate() { + let keypair = Keypair::generate(); + assert_eq!(keypair.public_key().len(), 32); + assert!(keypair.ss58_address().starts_with('5')); // SS58 prefix for substrate + } + + #[test] + fn test_from_mnemonic() { + let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + let keypair = Keypair::from_mnemonic(phrase, None).unwrap(); + + // Should be deterministic + let keypair2 = Keypair::from_mnemonic(phrase, None).unwrap(); + assert_eq!(keypair.public_key(), keypair2.public_key()); + assert_eq!(keypair.ss58_address(), keypair2.ss58_address()); + } + + #[test] + fn test_from_mnemonic_with_password() { + let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + let keypair_no_pass = Keypair::from_mnemonic(phrase, None).unwrap(); + let keypair_with_pass = Keypair::from_mnemonic(phrase, Some("password")).unwrap(); + + // Different passwords should produce different keys + assert_ne!(keypair_no_pass.public_key(), keypair_with_pass.public_key()); + } + + #[test] + fn test_from_seed() { + let seed = [42u8; 32]; + let keypair = Keypair::from_seed(&seed).unwrap(); + + // Should be deterministic + let keypair2 = Keypair::from_seed(&seed).unwrap(); + assert_eq!(keypair.public_key(), keypair2.public_key()); + } + + #[test] + fn test_from_seed_invalid_length() { + let seed = [0u8; 16]; + assert!(Keypair::from_seed(&seed).is_err()); + } + + #[test] + fn test_from_uri() { + let keypair = Keypair::from_uri("//Alice").unwrap(); + assert!(!keypair.ss58_address().is_empty()); + + // Should be deterministic + let keypair2 = 
Keypair::from_uri("//Alice").unwrap(); + assert_eq!(keypair.public_key(), keypair2.public_key()); + } + + #[test] + fn test_sign_and_verify() { + let keypair = Keypair::generate(); + let message = b"Hello, Bittensor!"; + + let signature = keypair.sign(message); + assert_eq!(signature.len(), 64); + + assert!(keypair.verify(message, &signature)); + + // Wrong message should fail + assert!(!keypair.verify(b"Wrong message", &signature)); + } + + #[test] + fn test_verify_with_public() { + let keypair = Keypair::generate(); + let message = b"Test message"; + let signature = keypair.sign(message); + + assert!(Keypair::verify_with_public( + message, + &signature, + keypair.public_key() + )); + } + + #[test] + fn test_to_and_from_bytes() { + let original = Keypair::generate(); + let bytes = original.to_bytes(); + + let restored = Keypair::from_bytes(&bytes).unwrap(); + assert_eq!(original.public_key(), restored.public_key()); + + // Verify signing still works + let message = b"Test"; + let sig = original.sign(message); + assert!(restored.verify(message, &sig)); + } + + #[test] + fn test_invalid_signature_length() { + let keypair = Keypair::generate(); + let message = b"Test"; + + // Too short + assert!(!keypair.verify(message, &[0u8; 32])); + + // Too long + assert!(!keypair.verify(message, &[0u8; 128])); + } +} diff --git a/src/wallet/mnemonic.rs b/src/wallet/mnemonic.rs new file mode 100644 index 0000000..89d9e0c --- /dev/null +++ b/src/wallet/mnemonic.rs @@ -0,0 +1,285 @@ +//! BIP39 mnemonic generation and recovery for wallet creation. +//! +//! This module provides functionality to generate and validate BIP39 mnemonics, +//! which are used to create and recover wallet keypairs. 
+ +// Allow unused_assignments - the ZeroizeOnDrop derive macro generates code that clippy +// incorrectly flags as unused assignments when it reads/writes struct fields for zeroization +#![allow(unused_assignments)] + +use bip39::Mnemonic as Bip39Mnemonic; +use thiserror::Error; +use zeroize::{Zeroize, ZeroizeOnDrop}; + +/// Errors that can occur during mnemonic operations. +#[derive(Debug, Error)] +pub enum MnemonicError { + #[error("Invalid word count: {0}. Must be 12, 15, 18, 21, or 24")] + InvalidWordCount(usize), + + #[error("Invalid mnemonic phrase: {0}")] + InvalidPhrase(String), + + #[error("Entropy generation failed: {0}")] + EntropyError(String), +} + +/// A BIP39 mnemonic phrase for wallet generation and recovery. +/// +/// The mnemonic is securely zeroed from memory when dropped. +#[derive(Clone, ZeroizeOnDrop)] +pub struct Mnemonic { + #[zeroize(skip)] + inner: Bip39Mnemonic, + phrase: String, + words: Vec, +} + +impl std::fmt::Debug for Mnemonic { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Don't expose the actual phrase in debug output + f.debug_struct("Mnemonic") + .field("word_count", &self.words.len()) + .finish() + } +} + +impl Mnemonic { + /// Generate a new 12-word mnemonic phrase. + /// + /// # Returns + /// A new randomly generated mnemonic. + /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Mnemonic; + /// let mnemonic = Mnemonic::generate(); + /// assert_eq!(mnemonic.word_count(), 12); + /// ``` + pub fn generate() -> Self { + Self::generate_with_words(12).expect("12 words is always valid") + } + + /// Generate a new mnemonic with the specified number of words. + /// + /// # Arguments + /// * `word_count` - Number of words (12, 15, 18, 21, or 24) + /// + /// # Returns + /// A new mnemonic or an error if the word count is invalid. 
+ /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Mnemonic; + /// let mnemonic = Mnemonic::generate_with_words(24).unwrap(); + /// assert_eq!(mnemonic.word_count(), 24); + /// ``` + pub fn generate_with_words(word_count: usize) -> Result { + let entropy_bits = match word_count { + 12 => 128, + 15 => 160, + 18 => 192, + 21 => 224, + 24 => 256, + _ => return Err(MnemonicError::InvalidWordCount(word_count)), + }; + + let entropy_bytes = entropy_bits / 8; + let mut entropy = vec![0u8; entropy_bytes]; + getrandom(&mut entropy).map_err(|e| MnemonicError::EntropyError(e.to_string()))?; + + let inner = Bip39Mnemonic::from_entropy(&entropy) + .map_err(|e| MnemonicError::EntropyError(e.to_string()))?; + + // Zeroize entropy after use + entropy.zeroize(); + + let phrase = inner.to_string(); + let words: Vec = phrase.split_whitespace().map(String::from).collect(); + + Ok(Self { + inner, + phrase, + words, + }) + } + + /// Create a mnemonic from an existing phrase. + /// + /// # Arguments + /// * `phrase` - A valid BIP39 mnemonic phrase + /// + /// # Returns + /// The parsed mnemonic or an error if the phrase is invalid. + /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Mnemonic; + /// let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + /// let mnemonic = Mnemonic::from_phrase(phrase).unwrap(); + /// ``` + pub fn from_phrase(phrase: &str) -> Result { + let normalized = phrase.trim().to_lowercase(); + let inner = Bip39Mnemonic::parse_normalized(&normalized) + .map_err(|e| MnemonicError::InvalidPhrase(e.to_string()))?; + + let phrase = inner.to_string(); + let words: Vec = phrase.split_whitespace().map(String::from).collect(); + + Ok(Self { + inner, + phrase, + words, + }) + } + + /// Validate a mnemonic phrase without creating a Mnemonic object. 
+ /// + /// # Arguments + /// * `phrase` - The mnemonic phrase to validate + /// + /// # Returns + /// `true` if the phrase is a valid BIP39 mnemonic. + pub fn validate(phrase: &str) -> bool { + let normalized = phrase.trim().to_lowercase(); + Bip39Mnemonic::parse_normalized(&normalized).is_ok() + } + + /// Get the mnemonic phrase as a string. + /// + /// # Returns + /// The mnemonic phrase. + pub fn phrase(&self) -> &str { + &self.phrase + } + + /// Get the individual words of the mnemonic. + /// + /// # Returns + /// A slice of the mnemonic words. + pub fn words(&self) -> &[String] { + &self.words + } + + /// Get the number of words in the mnemonic. + /// + /// # Returns + /// The word count (12, 15, 18, 21, or 24). + pub fn word_count(&self) -> usize { + self.words.len() + } + + /// Convert the mnemonic to a seed for key derivation. + /// + /// # Arguments + /// * `password` - Optional password for additional security (BIP39 passphrase) + /// + /// # Returns + /// A 64-byte seed suitable for key derivation. + /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Mnemonic; + /// let mnemonic = Mnemonic::generate(); + /// let seed = mnemonic.to_seed(None); + /// assert_eq!(seed.len(), 64); + /// ``` + pub fn to_seed(&self, password: Option<&str>) -> [u8; 64] { + let passphrase = password.unwrap_or(""); + self.inner.to_seed(passphrase) + } + + /// Convert the mnemonic to entropy bytes. + /// + /// # Returns + /// The underlying entropy as bytes. + pub fn to_entropy(&self) -> Vec { + self.inner.to_entropy() + } +} + +/// Generate random bytes using the system's secure random number generator. 
+fn getrandom(buf: &mut [u8]) -> Result<(), MnemonicError> { + use rand::RngCore; + let mut rng = rand::rng(); + rng.fill_bytes(buf); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_12_words() { + let mnemonic = Mnemonic::generate(); + assert_eq!(mnemonic.word_count(), 12); + assert!(Mnemonic::validate(mnemonic.phrase())); + } + + #[test] + fn test_generate_24_words() { + let mnemonic = Mnemonic::generate_with_words(24).unwrap(); + assert_eq!(mnemonic.word_count(), 24); + assert!(Mnemonic::validate(mnemonic.phrase())); + } + + #[test] + fn test_invalid_word_count() { + assert!(Mnemonic::generate_with_words(13).is_err()); + assert!(Mnemonic::generate_with_words(10).is_err()); + } + + #[test] + fn test_from_phrase() { + let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + let mnemonic = Mnemonic::from_phrase(phrase).unwrap(); + assert_eq!(mnemonic.word_count(), 12); + assert_eq!(mnemonic.phrase(), phrase); + } + + #[test] + fn test_from_phrase_with_extra_whitespace() { + let phrase = " abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about "; + let mnemonic = Mnemonic::from_phrase(phrase).unwrap(); + assert_eq!(mnemonic.word_count(), 12); + } + + #[test] + fn test_invalid_phrase() { + let result = Mnemonic::from_phrase("invalid mnemonic phrase"); + assert!(result.is_err()); + } + + #[test] + fn test_validate() { + let valid = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + assert!(Mnemonic::validate(valid)); + + let invalid = "invalid mnemonic phrase that is not valid"; + assert!(!Mnemonic::validate(invalid)); + } + + #[test] + fn test_to_seed() { + let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + let mnemonic = Mnemonic::from_phrase(phrase).unwrap(); + + let seed_no_pass = mnemonic.to_seed(None); + let seed_with_pass = 
mnemonic.to_seed(Some("password")); + + assert_eq!(seed_no_pass.len(), 64); + assert_eq!(seed_with_pass.len(), 64); + assert_ne!(seed_no_pass, seed_with_pass); + } + + #[test] + fn test_deterministic_seed() { + let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + let m1 = Mnemonic::from_phrase(phrase).unwrap(); + let m2 = Mnemonic::from_phrase(phrase).unwrap(); + + assert_eq!(m1.to_seed(None), m2.to_seed(None)); + assert_eq!(m1.to_seed(Some("test")), m2.to_seed(Some("test"))); + } +} diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs new file mode 100644 index 0000000..ac78161 --- /dev/null +++ b/src/wallet/mod.rs @@ -0,0 +1,200 @@ +//! Wallet management for Bittensor. +//! +//! This module provides comprehensive wallet functionality compatible with the +//! Python Bittensor SDK, including: +//! +//! - **Mnemonic generation and recovery** (BIP39) +//! - **Keypair management** (SR25519) +//! - **Keyfile encryption and storage** (Argon2id + NaCl secretbox) +//! - **Wallet creation and management** (coldkey/hotkey) +//! +//! ## Quick Start +//! +//! ### Create a new wallet +//! +//! ```no_run +//! use bittensor_rs::wallet::Wallet; +//! +//! // Create a new wallet with encrypted keys +//! let wallet = Wallet::create("my_wallet", "default", Some("password")).unwrap(); +//! +//! // Get the SS58 addresses +//! let coldkey_addr = wallet.coldkey_ss58(Some("password")).unwrap(); +//! let hotkey_addr = wallet.hotkey_ss58(Some("password")).unwrap(); +//! +//! println!("Coldkey: {}", coldkey_addr); +//! println!("Hotkey: {}", hotkey_addr); +//! ``` +//! +//! ### Recover a wallet from mnemonic +//! +//! ```no_run +//! use bittensor_rs::wallet::Wallet; +//! +//! let mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; +//! let wallet = Wallet::regenerate_coldkey("recovered", mnemonic, Some("password")).unwrap(); +//! ``` +//! +//! ### Generate a mnemonic +//! +//! 
``` +//! use bittensor_rs::wallet::Mnemonic; +//! +//! // Generate 12-word mnemonic +//! let mnemonic = Mnemonic::generate(); +//! println!("Save this: {}", mnemonic.phrase()); +//! +//! // Generate 24-word mnemonic for extra security +//! let mnemonic24 = Mnemonic::generate_with_words(24).unwrap(); +//! ``` +//! +//! ### Sign and verify messages +//! +//! ``` +//! use bittensor_rs::wallet::Keypair; +//! +//! let keypair = Keypair::generate(); +//! let message = b"Hello, Bittensor!"; +//! +//! let signature = keypair.sign(message); +//! assert!(keypair.verify(message, &signature)); +//! ``` +//! +//! ## Keyfile Format +//! +//! This module uses a keyfile format compatible with the Python SDK: +//! +//! ```json +//! { +//! "crypto": { +//! "cipher": "secretbox", +//! "ciphertext": "", +//! "cipherparams": {"nonce": ""}, +//! "kdf": "argon2id", +//! "kdfparams": {"salt": "", "n": 65536, "r": 1, "p": 4} +//! }, +//! "version": 4 +//! } +//! ``` +//! +//! ## Security Notes +//! +//! - All sensitive data (seeds, private keys, mnemonics) is securely zeroed from +//! memory when dropped using the `zeroize` crate. +//! - Keyfiles use Argon2id for key derivation (memory-hard, resistant to GPU attacks) +//! - Encryption uses XSalsa20-Poly1305 (NaCl secretbox) +//! 
- File permissions are set to 0600 on Unix systems + +pub mod keyfile; +pub mod keypair; +pub mod mnemonic; +#[allow(clippy::module_inception)] +pub mod wallet; + +// Re-export main types at module level +pub use keyfile::{ + is_legacy_format, migrate_legacy_keyfile, Keyfile, KeyfileData, KeyfileError, KeyfileJson, + KEYFILE_VERSION, +}; +pub use keypair::{Keypair, KeypairError, BITTENSOR_SS58_FORMAT}; +pub use mnemonic::{Mnemonic, MnemonicError}; +pub use wallet::{ + default_wallet_path, list_wallets, list_wallets_at, wallet_path, Wallet, WalletError, +}; + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn test_full_wallet_workflow() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + // Generate mnemonic + let coldkey_mnemonic = Mnemonic::generate(); + let hotkey_mnemonic = Mnemonic::generate(); + + // Create wallet with mnemonics + let mut wallet = Wallet::new("test_wallet", "default", Some(base_path)).unwrap(); + wallet + .create_coldkey(Some("password"), Some(coldkey_mnemonic.phrase()), false) + .unwrap(); + wallet + .create_hotkey(Some("password"), Some(hotkey_mnemonic.phrase()), false) + .unwrap(); + + // Verify wallet exists + assert!(wallet.exists()); + + // Get keypairs + let coldkey = wallet.coldkey_keypair(Some("password")).unwrap(); + let hotkey = wallet.hotkey_keypair(Some("password")).unwrap(); + + // Verify addresses + assert!(!coldkey.ss58_address().is_empty()); + assert!(!hotkey.ss58_address().is_empty()); + + // Sign and verify + let message = b"test message"; + let signature = coldkey.sign(message); + assert!(coldkey.verify(message, &signature)); + + // Recover wallet with same mnemonic + let mut recovered = Wallet::new("recovered", "default", Some(base_path)).unwrap(); + recovered + .create_coldkey(Some("password"), Some(coldkey_mnemonic.phrase()), false) + .unwrap(); + + let recovered_coldkey = recovered.coldkey_keypair(Some("password")).unwrap(); + 
assert_eq!(coldkey.ss58_address(), recovered_coldkey.ss58_address()); + } + + #[test] + fn test_keyfile_python_compatibility() { + // This test verifies the JSON format matches Python SDK expectations + let dir = tempdir().unwrap(); + let path = dir.path().join("test_keyfile"); + + let keypair = Keypair::generate(); + let mut keyfile = Keyfile::new(&path); + keyfile + .set_keypair(keypair.clone(), Some("password"), false) + .unwrap(); + + // Read and parse the JSON + let content = std::fs::read_to_string(&path).unwrap(); + let json: serde_json::Value = serde_json::from_str(&content).unwrap(); + + // Verify structure + assert_eq!(json["version"], 4); + assert_eq!(json["crypto"]["cipher"], "secretbox"); + assert_eq!(json["crypto"]["kdf"], "argon2id"); + assert!(json["crypto"]["ciphertext"].as_str().is_some()); + assert!(json["crypto"]["cipherparams"]["nonce"].as_str().is_some()); + assert!(json["crypto"]["kdfparams"]["salt"].as_str().is_some()); + } + + #[test] + fn test_keypair_from_uri() { + // Test well-known development accounts + let alice = Keypair::from_uri("//Alice").unwrap(); + let bob = Keypair::from_uri("//Bob").unwrap(); + + assert_ne!(alice.ss58_address(), bob.ss58_address()); + + // Verify deterministic + let alice2 = Keypair::from_uri("//Alice").unwrap(); + assert_eq!(alice.ss58_address(), alice2.ss58_address()); + } + + #[test] + fn test_mnemonic_validation() { + let valid = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + assert!(Mnemonic::validate(valid)); + + let invalid = "not a valid mnemonic phrase at all"; + assert!(!Mnemonic::validate(invalid)); + } +} diff --git a/src/wallet/wallet.rs b/src/wallet/wallet.rs new file mode 100644 index 0000000..ffe79ab --- /dev/null +++ b/src/wallet/wallet.rs @@ -0,0 +1,756 @@ +//! Wallet management for Bittensor. +//! +//! This module provides the main `Wallet` struct for managing coldkeys and hotkeys, +//! 
compatible with the Python Bittensor SDK wallet structure. +//! +//! ## Wallet Structure +//! +//! A Bittensor wallet consists of: +//! - **Coldkey**: The main key that holds funds and controls the hotkey +//! - **Hotkey**: The key used for network operations (mining, validation) +//! +//! Wallets are stored in the filesystem with the following structure: +//! ```text +//! ~/.bittensor/wallets/ +//! └── / +//! ├── coldkey # Encrypted coldkey +//! ├── coldkeypub.txt # Public coldkey SS58 address +//! └── hotkeys/ +//! └── # Encrypted hotkey +//! ``` + +use crate::wallet::keyfile::{Keyfile, KeyfileError}; +use crate::wallet::keypair::{Keypair, KeypairError}; +use crate::wallet::mnemonic::{Mnemonic, MnemonicError}; +use std::fs; +use std::io::Write; +use std::path::{Path, PathBuf}; +use thiserror::Error; + +/// Default wallet directory name under home +const WALLET_DIR_NAME: &str = ".bittensor/wallets"; + +/// Default coldkey filename +const COLDKEY_FILENAME: &str = "coldkey"; + +/// Coldkey public key filename +const COLDKEYPUB_FILENAME: &str = "coldkeypub.txt"; + +/// Hotkeys directory name +const HOTKEYS_DIR: &str = "hotkeys"; + +/// Default wallet name +#[allow(dead_code)] +const DEFAULT_WALLET_NAME: &str = "default"; + +/// Default hotkey name +const DEFAULT_HOTKEY_NAME: &str = "default"; + +/// Errors that can occur during wallet operations. 
+#[derive(Debug, Error)] +pub enum WalletError { + #[error("Wallet directory not found: {0}")] + DirectoryNotFound(PathBuf), + + #[error("Coldkey not found for wallet: {0}")] + ColdkeyNotFound(String), + + #[error("Hotkey not found: {0}")] + HotkeyNotFound(String), + + #[error("Keyfile error: {0}")] + Keyfile(#[from] KeyfileError), + + #[error("Keypair error: {0}")] + Keypair(#[from] KeypairError), + + #[error("Mnemonic error: {0}")] + Mnemonic(#[from] MnemonicError), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Wallet already exists: {0}")] + AlreadyExists(String), + + #[error("Invalid wallet path: {0}")] + InvalidPath(String), + + #[error("Home directory not found")] + HomeNotFound, + + #[error("Invalid name: {0}")] + InvalidName(String), +} + +/// Sanitize a name to prevent path traversal attacks. +/// +/// # Arguments +/// * `name` - The name to validate +/// +/// # Returns +/// The validated name, or an error if the name contains invalid characters. +/// +/// # Security +/// This function prevents directory traversal attacks (CWE-22) by rejecting: +/// - Path separators (`/` or `\`) +/// - Parent directory references (`..`) +/// - Empty or whitespace-only names +/// - Names starting with a dot (hidden files) +fn sanitize_name(name: &str) -> Result<&str, WalletError> { + // Reject names with path separators or traversal sequences + if name.contains('/') || name.contains('\\') || name.contains("..") { + return Err(WalletError::InvalidName(format!( + "Name '{}' contains invalid path characters", + name + ))); + } + // Reject empty or whitespace-only names + if name.trim().is_empty() { + return Err(WalletError::InvalidName( + "Name cannot be empty".to_string(), + )); + } + // Reject names starting with dots (hidden files) + if name.starts_with('.') { + return Err(WalletError::InvalidName(format!( + "Name '{}' cannot start with a dot", + name + ))); + } + Ok(name) +} + +/// A Bittensor wallet containing coldkey and hotkey. 
+/// +/// The wallet manages two keypairs: +/// - `coldkey`: Main key that holds funds +/// - `hotkey`: Key used for network operations +pub struct Wallet { + /// Wallet name + pub name: String, + /// Base path for wallet storage + pub path: PathBuf, + /// Name of the hotkey to use + pub hotkey_name: String, + coldkey: Keyfile, + hotkey: Keyfile, +} + +impl std::fmt::Debug for Wallet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Wallet") + .field("name", &self.name) + .field("path", &self.path) + .field("hotkey_name", &self.hotkey_name) + .finish() + } +} + +impl Wallet { + /// Create a new wallet handle without creating files on disk. + /// + /// # Arguments + /// * `name` - Wallet name (directory name under wallets/) + /// * `hotkey` - Hotkey name + /// * `path` - Optional custom base path (defaults to ~/.bittensor/wallets) + /// + /// # Returns + /// A new wallet handle, or an error if the name or hotkey contains invalid characters. + /// + /// # Security + /// Both `name` and `hotkey` are sanitized to prevent path traversal attacks. + /// Names containing `/`, `\`, `..`, or starting with `.` will be rejected. 
+ /// + /// # Example + /// ``` + /// use bittensor_rs::wallet::Wallet; + /// let wallet = Wallet::new("my_wallet", "default", None).unwrap(); + /// ``` + pub fn new(name: &str, hotkey: &str, path: Option<&str>) -> Result { + // Sanitize inputs to prevent path traversal attacks + let name = sanitize_name(name)?; + let hotkey = sanitize_name(hotkey)?; + + let base_path = match path { + Some(p) => PathBuf::from(p), + None => default_wallet_path(), + }; + + let wallet_path = base_path.join(name); + let coldkey_path = wallet_path.join(COLDKEY_FILENAME); + let hotkey_path = wallet_path.join(HOTKEYS_DIR).join(hotkey); + + Ok(Self { + name: name.to_string(), + path: wallet_path, + hotkey_name: hotkey.to_string(), + coldkey: Keyfile::new(coldkey_path), + hotkey: Keyfile::new(hotkey_path), + }) + } + + /// Create a new wallet with both coldkey and hotkey. + /// + /// # Arguments + /// * `name` - Wallet name + /// * `hotkey` - Hotkey name + /// * `password` - Optional password for encryption + /// + /// # Returns + /// A new wallet with generated keys, or an error if creation fails. + /// + /// # Example + /// ```no_run + /// use bittensor_rs::wallet::Wallet; + /// let wallet = Wallet::create("new_wallet", "default", Some("password")).unwrap(); + /// ``` + pub fn create(name: &str, hotkey: &str, password: Option<&str>) -> Result { + let mut wallet = Self::new(name, hotkey, None)?; + + // Create coldkey + wallet.create_coldkey(password, None, false)?; + + // Create hotkey + wallet.create_hotkey(password, None, false)?; + + Ok(wallet) + } + + /// Create a new wallet with both coldkey and hotkey at a custom path. + /// + /// # Arguments + /// * `name` - Wallet name + /// * `hotkey` - Hotkey name + /// * `path` - Custom base path for wallet storage + /// * `password` - Optional password for encryption + /// + /// # Returns + /// A new wallet with generated keys. 
+ pub fn create_at_path( + name: &str, + hotkey: &str, + path: &str, + password: Option<&str>, + ) -> Result { + let mut wallet = Self::new(name, hotkey, Some(path))?; + + wallet.create_coldkey(password, None, false)?; + wallet.create_hotkey(password, None, false)?; + + Ok(wallet) + } + + /// Create or regenerate the coldkey. + /// + /// # Arguments + /// * `password` - Optional password for encryption + /// * `mnemonic` - Optional mnemonic for recovery (generates new if None) + /// * `overwrite` - Whether to overwrite existing coldkey + /// + /// # Returns + /// The mnemonic phrase used (save this for recovery!). + pub fn create_coldkey( + &mut self, + password: Option<&str>, + mnemonic: Option<&str>, + overwrite: bool, + ) -> Result { + let mnemonic_obj = match mnemonic { + Some(phrase) => Mnemonic::from_phrase(phrase)?, + None => Mnemonic::generate(), + }; + + let keypair = Keypair::from_mnemonic_obj(&mnemonic_obj, password)?; + + // Store the mnemonic phrase before potentially moving it + let phrase = mnemonic_obj.phrase().to_string(); + + // Ensure wallet directory exists + fs::create_dir_all(&self.path)?; + + // Save coldkey + self.coldkey.set_keypair(keypair.clone(), password, overwrite)?; + + // Save public key file + self.save_coldkey_pub(&keypair)?; + + Ok(phrase) + } + + /// Create or regenerate the hotkey. + /// + /// # Arguments + /// * `password` - Optional password for encryption + /// * `mnemonic` - Optional mnemonic for recovery (generates new if None) + /// * `overwrite` - Whether to overwrite existing hotkey + /// + /// # Returns + /// The mnemonic phrase used (save this for recovery!). 
+ pub fn create_hotkey( + &mut self, + password: Option<&str>, + mnemonic: Option<&str>, + overwrite: bool, + ) -> Result { + let mnemonic_obj = match mnemonic { + Some(phrase) => Mnemonic::from_phrase(phrase)?, + None => Mnemonic::generate(), + }; + + let keypair = Keypair::from_mnemonic_obj(&mnemonic_obj, password)?; + let phrase = mnemonic_obj.phrase().to_string(); + + // Ensure hotkeys directory exists + let hotkeys_dir = self.path.join(HOTKEYS_DIR); + fs::create_dir_all(&hotkeys_dir)?; + + // Save hotkey + self.hotkey.set_keypair(keypair, password, overwrite)?; + + Ok(phrase) + } + + /// Get a reference to the coldkey keyfile. + pub fn coldkey(&self) -> &Keyfile { + &self.coldkey + } + + /// Get a reference to the hotkey keyfile. + pub fn hotkey(&self) -> &Keyfile { + &self.hotkey + } + + /// Get the coldkey keypair. + /// + /// # Arguments + /// * `password` - Password for decryption (if encrypted) + pub fn coldkey_keypair(&self, password: Option<&str>) -> Result { + self.coldkey.get_keypair(password).map_err(WalletError::Keyfile) + } + + /// Get the hotkey keypair. + /// + /// # Arguments + /// * `password` - Password for decryption (if encrypted) + pub fn hotkey_keypair(&self, password: Option<&str>) -> Result { + self.hotkey.get_keypair(password).map_err(WalletError::Keyfile) + } + + /// Get the coldkey SS58 address. + /// + /// This reads from the coldkeypub.txt file if available, otherwise + /// decrypts the coldkey to get the address. + pub fn coldkey_ss58(&self, password: Option<&str>) -> Result { + // Try to read from coldkeypub.txt first + let pub_path = self.path.join(COLDKEYPUB_FILENAME); + if pub_path.exists() { + if let Ok(content) = fs::read_to_string(&pub_path) { + let address = content.trim().to_string(); + if !address.is_empty() { + return Ok(address); + } + } + } + + // Fall back to decrypting coldkey + let keypair = self.coldkey_keypair(password)?; + Ok(keypair.ss58_address().to_string()) + } + + /// Get the hotkey SS58 address. 
+ pub fn hotkey_ss58(&self, password: Option<&str>) -> Result { + let keypair = self.hotkey_keypair(password)?; + Ok(keypair.ss58_address().to_string()) + } + + /// Check if the coldkey exists on disk. + pub fn coldkey_exists(&self) -> bool { + self.coldkey.exists() + } + + /// Check if the hotkey exists on disk. + pub fn hotkey_exists(&self) -> bool { + self.hotkey.exists() + } + + /// Check if both coldkey and hotkey exist. + pub fn exists(&self) -> bool { + self.coldkey_exists() && self.hotkey_exists() + } + + /// Regenerate a wallet from a coldkey mnemonic. + /// + /// # Arguments + /// * `name` - Wallet name + /// * `mnemonic` - The coldkey mnemonic phrase + /// * `password` - Optional password for derivation and encryption + /// + /// # Returns + /// A wallet with the regenerated coldkey (hotkey must be created separately). + pub fn regenerate_coldkey( + name: &str, + mnemonic: &str, + password: Option<&str>, + ) -> Result { + let mut wallet = Self::new(name, DEFAULT_HOTKEY_NAME, None)?; + wallet.create_coldkey(password, Some(mnemonic), true)?; + Ok(wallet) + } + + /// Regenerate a hotkey from a mnemonic. + /// + /// # Arguments + /// * `name` - Wallet name + /// * `hotkey_name` - Hotkey name + /// * `mnemonic` - The hotkey mnemonic phrase + /// * `password` - Optional password for derivation and encryption + /// + /// # Returns + /// A wallet handle with the regenerated hotkey. + pub fn regenerate_hotkey( + name: &str, + hotkey_name: &str, + mnemonic: &str, + password: Option<&str>, + ) -> Result { + let mut wallet = Self::new(name, hotkey_name, None)?; + wallet.create_hotkey(password, Some(mnemonic), true)?; + Ok(wallet) + } + + /// List all hotkeys for this wallet. + /// + /// # Returns + /// A list of hotkey names. + pub fn list_hotkeys(&self) -> Result, WalletError> { + let hotkeys_dir = self.path.join(HOTKEYS_DIR); + if !hotkeys_dir.exists() { + return Ok(Vec::new()); + } + + let mut hotkeys = Vec::new(); + for entry in fs::read_dir(&hotkeys_dir)? 
{ + let entry = entry?; + if entry.file_type()?.is_file() { + if let Some(name) = entry.file_name().to_str() { + hotkeys.push(name.to_string()); + } + } + } + + hotkeys.sort(); + Ok(hotkeys) + } + + /// Switch to a different hotkey. + /// + /// # Arguments + /// * `hotkey_name` - Name of the hotkey to switch to + /// + /// # Returns + /// Ok(()) on success, or an error if the hotkey name is invalid. + /// + /// # Security + /// The hotkey name is sanitized to prevent path traversal attacks. + pub fn use_hotkey(&mut self, hotkey_name: &str) -> Result<(), WalletError> { + let hotkey_name = sanitize_name(hotkey_name)?; + self.hotkey_name = hotkey_name.to_string(); + let hotkey_path = self.path.join(HOTKEYS_DIR).join(hotkey_name); + self.hotkey = Keyfile::new(hotkey_path); + Ok(()) + } + + /// Save the coldkey public address to coldkeypub.txt. + /// + /// # Security + /// The file is created with restrictive permissions (0o600 on Unix) + /// to prevent unauthorized access to the public key. + fn save_coldkey_pub(&self, keypair: &Keypair) -> Result<(), WalletError> { + let pub_path = self.path.join(COLDKEYPUB_FILENAME); + let mut file = fs::File::create(&pub_path)?; + writeln!(file, "{}", keypair.ss58_address())?; + + // Set restrictive permissions on Unix (readable by owner only) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let permissions = fs::Permissions::from_mode(0o600); + fs::set_permissions(&pub_path, permissions)?; + } + + Ok(()) + } +} + +/// Get the default wallet path (~/.bittensor/wallets). +/// +/// # Returns +/// The default wallet directory path. +/// +/// # Panics +/// Panics if the home directory cannot be determined. +pub fn default_wallet_path() -> PathBuf { + dirs::home_dir() + .map(|home| home.join(WALLET_DIR_NAME)) + .unwrap_or_else(|| PathBuf::from(WALLET_DIR_NAME)) +} + +/// Get the full path to a specific wallet. +/// +/// # Arguments +/// * `name` - Wallet name +/// +/// # Returns +/// The full path to the wallet directory. 
+pub fn wallet_path(name: &str) -> PathBuf { + default_wallet_path().join(name) +} + +/// List all wallets in the default wallet directory. +/// +/// # Returns +/// A list of wallet names. +pub fn list_wallets() -> Result, WalletError> { + list_wallets_at(&default_wallet_path()) +} + +/// List all wallets at a specific path. +/// +/// # Arguments +/// * `path` - The wallet directory path +/// +/// # Returns +/// A list of wallet names. +pub fn list_wallets_at(path: &Path) -> Result, WalletError> { + if !path.exists() { + return Ok(Vec::new()); + } + + let mut wallets = Vec::new(); + for entry in fs::read_dir(path)? { + let entry = entry?; + if entry.file_type()?.is_dir() { + // Check if it has a coldkey (makes it a valid wallet) + let coldkey_path = entry.path().join(COLDKEY_FILENAME); + if coldkey_path.exists() { + if let Some(name) = entry.file_name().to_str() { + wallets.push(name.to_string()); + } + } + } + } + + wallets.sort(); + Ok(wallets) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn test_wallet_new() { + let wallet = Wallet::new("test_wallet", "test_hotkey", None).unwrap(); + assert_eq!(wallet.name, "test_wallet"); + assert_eq!(wallet.hotkey_name, "test_hotkey"); + } + + #[test] + fn test_wallet_create() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + let wallet = Wallet::create_at_path("test_wallet", "default", base_path, None).unwrap(); + + assert!(wallet.coldkey_exists()); + assert!(wallet.hotkey_exists()); + assert!(wallet.exists()); + } + + #[test] + fn test_wallet_create_with_password() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + let password = "test_password"; + + // Create wallet with password + let wallet = + Wallet::create_at_path("test_wallet", "default", base_path, Some(password)).unwrap(); + + // Should be able to get keypairs with password (from cached version) + let coldkey = 
wallet.coldkey_keypair(Some(password)).unwrap(); + let hotkey = wallet.hotkey_keypair(Some(password)).unwrap(); + + assert!(!coldkey.ss58_address().is_empty()); + assert!(!hotkey.ss58_address().is_empty()); + + // Create a fresh wallet instance pointing to the same files + // This tests that reading from disk requires password + let wallet2 = Wallet::new("test_wallet", "default", Some(base_path)).unwrap(); + + // Should fail without password when reading from disk + assert!(wallet2.coldkey_keypair(None).is_err()); + assert!(wallet2.hotkey_keypair(None).is_err()); + + // Should succeed with correct password + assert!(wallet2.coldkey_keypair(Some(password)).is_ok()); + assert!(wallet2.hotkey_keypair(Some(password)).is_ok()); + } + + #[test] + fn test_wallet_regenerate_coldkey() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + let mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"; + + let mut wallet = Wallet::new("test_wallet", "default", Some(base_path)).unwrap(); + let returned_mnemonic = wallet.create_coldkey(None, Some(mnemonic), false).unwrap(); + + assert_eq!(returned_mnemonic, mnemonic); + + // Should be deterministic + let keypair1 = wallet.coldkey_keypair(None).unwrap(); + + // Create another wallet with same mnemonic + let mut wallet2 = Wallet::new("test_wallet2", "default", Some(base_path)).unwrap(); + wallet2.create_coldkey(None, Some(mnemonic), false).unwrap(); + let keypair2 = wallet2.coldkey_keypair(None).unwrap(); + + assert_eq!(keypair1.ss58_address(), keypair2.ss58_address()); + } + + #[test] + fn test_wallet_list_hotkeys() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + let mut wallet = Wallet::new("test_wallet", "hotkey1", Some(base_path)).unwrap(); + wallet.create_coldkey(None, None, false).unwrap(); + wallet.create_hotkey(None, None, false).unwrap(); + + // Create second hotkey + wallet.use_hotkey("hotkey2").unwrap(); 
+ wallet.create_hotkey(None, None, false).unwrap(); + + let hotkeys = wallet.list_hotkeys().unwrap(); + assert_eq!(hotkeys.len(), 2); + assert!(hotkeys.contains(&"hotkey1".to_string())); + assert!(hotkeys.contains(&"hotkey2".to_string())); + } + + #[test] + fn test_list_wallets() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + // Create multiple wallets + Wallet::create_at_path("wallet1", "default", base_path, None).unwrap(); + Wallet::create_at_path("wallet2", "default", base_path, None).unwrap(); + + let wallets = list_wallets_at(dir.path()).unwrap(); + assert_eq!(wallets.len(), 2); + assert!(wallets.contains(&"wallet1".to_string())); + assert!(wallets.contains(&"wallet2".to_string())); + } + + #[test] + fn test_coldkey_ss58() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + let wallet = Wallet::create_at_path("test_wallet", "default", base_path, None).unwrap(); + + let ss58 = wallet.coldkey_ss58(None).unwrap(); + assert!(!ss58.is_empty()); + assert!(ss58.starts_with('5')); // Substrate SS58 format + } + + #[test] + fn test_wallet_use_hotkey() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + let mut wallet = Wallet::create_at_path("test_wallet", "hotkey1", base_path, None).unwrap(); + + assert_eq!(wallet.hotkey_name, "hotkey1"); + + wallet.use_hotkey("hotkey2").unwrap(); + assert_eq!(wallet.hotkey_name, "hotkey2"); + assert!(!wallet.hotkey_exists()); // hotkey2 doesn't exist yet + } + + #[test] + fn test_wallet_path_functions() { + let default_path = default_wallet_path(); + assert!(default_path.ends_with(".bittensor/wallets")); + + let specific_path = wallet_path("my_wallet"); + assert!(specific_path.ends_with("my_wallet")); + } + + #[test] + fn test_coldkeypub_txt() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + let wallet = Wallet::create_at_path("test_wallet", "default", base_path, None).unwrap(); + + // Check 
coldkeypub.txt was created + let pub_path = dir.path().join("test_wallet").join("coldkeypub.txt"); + assert!(pub_path.exists()); + + // Content should match SS58 address + let content = fs::read_to_string(&pub_path).unwrap(); + let ss58 = wallet.coldkey_ss58(None).unwrap(); + assert_eq!(content.trim(), ss58); + } + + #[test] + fn test_path_traversal_prevention() { + // These should all fail due to path traversal protection + assert!(sanitize_name("../evil").is_err()); + assert!(sanitize_name("foo/../bar").is_err()); + assert!(sanitize_name("foo/bar").is_err()); + assert!(sanitize_name("foo\\bar").is_err()); + assert!(sanitize_name(".hidden").is_err()); + assert!(sanitize_name("").is_err()); + assert!(sanitize_name(" ").is_err()); + + // These should succeed + assert!(sanitize_name("valid_name").is_ok()); + assert!(sanitize_name("wallet-1").is_ok()); + assert!(sanitize_name("MyWallet").is_ok()); + } + + #[test] + fn test_wallet_new_rejects_path_traversal() { + // Wallet::new should reject path traversal attempts + assert!(Wallet::new("../evil", "default", None).is_err()); + assert!(Wallet::new("good", "../evil", None).is_err()); + assert!(Wallet::new(".hidden", "default", None).is_err()); + assert!(Wallet::new("good", ".hidden", None).is_err()); + assert!(Wallet::new("foo/bar", "default", None).is_err()); + assert!(Wallet::new("good", "foo/bar", None).is_err()); + + // Valid names should work + assert!(Wallet::new("valid_wallet", "valid_hotkey", None).is_ok()); + } + + #[test] + fn test_use_hotkey_rejects_path_traversal() { + let dir = tempdir().unwrap(); + let base_path = dir.path().to_str().unwrap(); + + let mut wallet = Wallet::create_at_path("test_wallet", "default", base_path, None).unwrap(); + + // Should reject path traversal in use_hotkey + assert!(wallet.use_hotkey("../evil").is_err()); + assert!(wallet.use_hotkey(".hidden").is_err()); + assert!(wallet.use_hotkey("foo/bar").is_err()); + + // Valid name should work + 
assert!(wallet.use_hotkey("valid_hotkey").is_ok()); + } +}