Full End-To-End Inference Flow and Gateway Node Implementation #487
base: main
Conversation
Force-pushed from 09b4132 to ad343a2
Force-pushed from d1f1169 to 2dbbe93
pefontana left a comment
Nice @samherring99 !
I noticed two things that may be easy to change:

- The `just` command doesn't work, because the tmux session inherits neither the `nix develop .#dev-python` shell nor the Python venv. To run it I have to run the following in two different terminals:

  Terminal 1:

  ```
  nix develop .#dev-python
  source .venv/bin/activate
  PSYCHE_GATEWAY_BOOTSTRAP_FILE=psyche-gateway-peer.json LIBTORCH_USE_PYTORCH=1 RUST_LOG=info cargo run --bin psyche-inference-node -- --model-name NousResearch/Hermes-4-14B --discovery-mode n0 --relay-kind n0
  ```

  Terminal 2:

  ```
  nix develop .#dev-python
  source .venv/bin/activate
  PSYCHE_GATEWAY_ENDPOINT_FILE=psyche-gateway-peer.json RUST_LOG=info cargo run --bin gateway-node --features gateway -- --discovery-mode n0 --relay-kind n0
  ```

- With the command

  ```
  PSYCHE_GATEWAY_BOOTSTRAP_FILE=psyche-gateway-peer.json LIBTORCH_USE_PYTORCH=1 RUST_LOG=info cargo run --bin psyche-inference-node -- --model-name NousResearch/Hermes-4-14B --discovery-mode n0 --relay-kind n0
  ```

  I am getting a NumPy error:

  ```
  ImportError: Numba needs NumPy 2.2 or less. Got NumPy 2.3
  ```

  I tried to install NumPy 2.2, but the PATH is still set to the Nix-provided version. Maybe we can update Numba to fix this and make it easier to run?
Could you share the tmux errors you're seeing? It might come down to versioning issues between our setups, but the `just` command starts up tmux with

As for the NumPy errors, I think we will resolve this with vLLM included in the nix packaging ;) 🤞 - but I will look into it regardless
Force-pushed from 2dbbe93 to 937d280
Sure. And the tmux outputs:

@samherring99
Force-pushed from 937d280 to 0af54f3
FWIW this looks like a CUDA / NCCL error; I'm guessing this is also related to venv / torch / vLLM issues. Will tag @arilotter for confirmation / final review 🙂
Force-pushed from 0af54f3 to 958793a
```rust
let nodes = state.available_nodes.read().await;
if nodes.is_empty() {
    return Err(AppError::NoNodesAvailable);
}

let node = nodes.values().next().unwrap();
```
Instead of checking that the vec is not empty and then calling unwrap, I think we can do:

```rust
let nodes = state.available_nodes.read().await;
let node = nodes.values().next().ok_or(AppError::NoNodesAvailable)?;
```

It's not really important since we're unlikely to panic, but I think this is more idiomatic.
Sweet, I'll test this out and update.
```rust
                let _ = tx.send(response).await;
            }
        }
        Err(e) => {
```
If something fails in the `send_inference_request` call, we're not cleaning up the `request_id` from `pending_requests`. Is that handled somewhere else? Not sure whether it's correct to remove it if something fails there.
Good callout, I'll ensure this is handled correctly and will update here.
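Something along these lines is probably what it will look like (a sketch only; `pending_requests`, `run_request`, and `send_inference_request` here are assumed names, not the final code):

```rust
use std::{collections::HashMap, sync::Arc};

use tokio::sync::{oneshot, RwLock};

// All names here are illustrative assumptions, not the PR's actual identifiers.
type PendingMap = Arc<RwLock<HashMap<String, oneshot::Sender<String>>>>;

async fn run_request(pending: PendingMap, request_id: String, prompt: String) -> anyhow::Result<()> {
    match send_inference_request(&prompt).await {
        Ok(response) => {
            // Hand the response to whoever registered the request.
            if let Some(tx) = pending.write().await.remove(&request_id) {
                let _ = tx.send(response);
            }
            Ok(())
        }
        Err(e) => {
            // Remove the pending entry on failure too, so failed requests
            // don't accumulate in the map forever.
            pending.write().await.remove(&request_id);
            Err(e)
        }
    }
}

// Stand-in for the real P2P call; the signature is assumed.
async fn send_inference_request(_prompt: &str) -> anyhow::Result<String> {
    Ok(String::from("…"))
}
```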
```rust
    peer_id.fmt_short()
);

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
```
Is this sleep necessary?
This was a debug measure I forgot to remove 🙃 thanks for catching it

I lied, this is necessary: we need to give the bytes time to flush through the network before the connection is dropped, and we need to wait for the receiver to actually read all the data. I'll add a comment.
Oh okay, maybe you can do `connection.closed().await;`? Not really sure though, I didn't try it. I'm just trying to avoid future problems where the receiver takes more than 100 ms to read things and we end up having the same error, but it's not as high priority 😅
I did think about this, will likely address in a later PR if it becomes an issue.
| info!("Capabilities: {:?}", capabilities); | ||
|
|
||
| // read bootstrap peers from multiple sources in priority order | ||
| let bootstrap_peers: Vec<EndpointAddr> = |
I think we have almost the same logic in the main.rs file of the crate. Can we extract it into an aux function and handle the differences between the implementations there?
Good point, was being lazy about this, will move to a new lib.rs file for shared implementation.
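Roughly the shape I'm thinking of for the shared helper (a sketch; the JSON peer-file format, the signature, and the `iroh::EndpointAddr` import path are assumptions):

```rust
use std::path::Path;

use anyhow::Result;

// `EndpointAddr` is the type the diff above uses; its exact import path and
// whether it derives serde traits are assumptions here.
use iroh::EndpointAddr;

/// Shared helper both binaries could call: prefer peers listed in a bootstrap
/// file (e.g. the one pointed at by PSYCHE_GATEWAY_BOOTSTRAP_FILE), falling
/// back to whatever was passed on the command line.
pub fn resolve_bootstrap_peers(
    bootstrap_file: Option<&Path>,
    cli_peers: Vec<EndpointAddr>,
) -> Result<Vec<EndpointAddr>> {
    if let Some(path) = bootstrap_file {
        let raw = std::fs::read_to_string(path)?;
        // Assumes the peer file is JSON; adjust to the real on-disk format.
        let peers: Vec<EndpointAddr> = serde_json::from_str(&raw)?;
        if !peers.is_empty() {
            return Ok(peers);
        }
    }
    Ok(cli_peers)
}
```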
```rust
pub request_id: String,
pub prompt: String,
pub messages: Vec<ChatMessage>,
#[serde(default = "default_max_tokens")]
```
Not strictly related to this PR, but I think both protocol and gateway-node use the same default functions. Can they be different at some point? Also, you can use the default value directly without using the default functions.
I see, to reduce scope I'll probably tackle this in a later PR, if that's okay.
Yeah, no worries
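For when that later PR happens, one option is a single shared default in the protocol crate (a sketch with an assumed value; serde's `default = "..."` attribute takes a function path rather than a literal, so the function stays but both crates point at the same one):

```rust
use serde::Deserialize;

// Assumed value; the real default would live in the shared protocol crate.
pub const DEFAULT_MAX_TOKENS: u32 = 512;

pub fn default_max_tokens() -> u32 {
    DEFAULT_MAX_TOKENS
}

// Both the protocol types and the gateway-node types can point their
// `#[serde(default = "...")]` attributes at this same function, so the two
// defaults can't silently diverge.
#[derive(Deserialize)]
struct ExampleParams {
    #[serde(default = "default_max_tokens")]
    max_tokens: u32,
}
```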
```rust
let model_name = req.model.clone().unwrap_or_else(|| node.model_name.clone());
info!(
    "Routing request to node: {} (model: {})",
    node.peer_id.fmt_short(),
    node.model_name
);
```
This is more of a question, but here we get an inference node from the list and only use its `model_name`. Then, in the `run_gateway` function, we select another node from the list as `target_node`, which is the one we actually route the request to. I might be misunderstanding something, but wouldn't it be better to select a single node and route the request to that one?
Yeah, I did this because I wasn't passing the peer ID through the channel as part of an InferenceMessage type, but I'll make that change to include it so we select one node only and route the request there.
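Roughly what I have in mind (a sketch; these aren't the final type names or fields):

```rust
use tokio::sync::oneshot;

// Placeholder request/response types; the real ones live in the protocol crate.
pub struct InferenceRequest {
    pub request_id: String,
    pub prompt: String,
    pub max_tokens: u32,
}

pub struct InferenceResponse {
    pub request_id: String,
    pub text: String,
}

/// Message from the HTTP handler to the run_gateway loop: the handler picks a
/// node once and sends its peer id along, so routing goes to that same node.
pub struct InferenceMessage {
    /// String form of the selected node's iroh endpoint id.
    pub target_peer: String,
    pub request: InferenceRequest,
    /// Hands the inference node's response back to the waiting HTTP handler.
    pub respond_to: oneshot::Sender<InferenceResponse>,
}
```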
Force-pushed from 36eea32 to 91248ac
```rust
send.finish()?;

// wait for a moment to let the connection flush all the bytes to the receiver
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
```
```rust
conn.close(0u32.into(), b"bye!");
endpoint.close().await;
```

This should flush the connection buffer before returning.
See https://github.com/n0-computer/iroh/blob/6ad5ac4238a3cc101791922167aab952d4c99c1e/iroh/examples/echo.rs#L65
Thank you, will test this out!
So, AFAICT there's no async way to wait for QUIC to flush without closing the entire endpoint... I'm fine with a time delay, but would an adaptive delay based on payload size be more reasonable / future-proof? I think we want the endpoint to stay open to accept future requests, and according to the comments in what you linked, that seems like a requirement 🙁
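Something in this direction is what I mean (a sketch only; the scaling constants are made up and would need tuning):

```rust
use std::time::Duration;

// Illustrative only: scale the post-finish() wait with payload size instead of
// a flat 100 ms, with a floor and a ceiling. The constants are placeholders.
fn flush_delay(payload_bytes: usize) -> Duration {
    const MIN_MS: u64 = 50;
    const MAX_MS: u64 = 1_000;
    // Roughly 1 ms per 10 KiB as a stand-in for link throughput.
    let estimate_ms = payload_bytes as u64 / (10 * 1024);
    Duration::from_millis(estimate_ms.clamp(MIN_MS, MAX_MS))
}

// At the send site, after send.finish():
// tokio::time::sleep(flush_delay(response_bytes.len())).await;
```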
```rust
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::post,
};
```
can group these all in one big `use` that's feature-flagged? or just.. rip out the gateway feature IMO
Yes should be no issue 😎
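Something like this (a sketch; the item list is illustrative and should mirror whatever the handler actually needs):

```rust
// One feature-gated import block instead of many scattered cfg'd uses.
#[cfg(feature = "gateway")]
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::post,
    Router,
};
```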
```rust
// Spawn task to handle P2P connection
let endpoint = network.router().endpoint().clone();
let state_clone = state.clone();
tokio::spawn(async move {
```
A task here without any tracking is a little scary - should we keep these in some task pool, add timeouts, monitor, etc.? Once we get a request we simply throw this into the tokio task pool and can't tell if something works or not.
This worked with `tokio::task::JoinSet::new()` 🙂
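For reference, the tracked version looks roughly like this (a sketch with placeholder names and a made-up 120 s cap per connection):

```rust
use std::time::Duration;

use tokio::task::JoinSet;

async fn handle_connection(conn_id: u64) -> anyhow::Result<()> {
    // Placeholder for accepting and serving the iroh connection.
    let _ = conn_id;
    Ok(())
}

#[tokio::main]
async fn main() {
    // Track connection-handling tasks instead of fire-and-forget spawning.
    let mut tasks: JoinSet<anyhow::Result<()>> = JoinSet::new();

    for conn_id in 0..3u64 {
        tasks.spawn(async move {
            // `?` surfaces both the timeout and any handler error to the JoinSet result.
            tokio::time::timeout(Duration::from_secs(120), handle_connection(conn_id)).await??;
            Ok(())
        });
    }

    // Drain results so failures and panics are logged rather than lost.
    while let Some(joined) = tasks.join_next().await {
        match joined {
            Ok(Ok(())) => {}
            Ok(Err(e)) => eprintln!("connection handler failed: {e}"),
            Err(join_err) => eprintln!("connection handler panicked: {join_err}"),
        }
    }
}
```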
Commits:
- …pes, adding initial skeleton of inference-node main loop, wiring inference-node up to iroh gossip updates, updating Cargo toml
- … param type to be optional, single protocol, and generic, adding justfile commands and test script to test inference
- …P response handling
- …rotocolHandler method and custom protocol code path
- …g to single node selection for request routing
Force-pushed from 91248ac to ef35e92
Force-pushed from ef35e92 to e4cb343
This PR provides the following things:

- A gateway node binary at `architectures/inference-only/inference-node/src/bin/gateway-node.rs`, which includes the `handle_inference` method to forward inference requests over the P2P network (using iroh's bidirectional streams). The gateway node performs discovery of available inference nodes through iroh's gossip. The binary is registered in `architectures/inference-only/inference-node/Cargo.toml`.
- `shared/inference/src/protocol_handler.rs`, which implements iroh's `ProtocolHandler` trait to accept incoming inference requests over the direct P2P connection.
- Changes to `shared/inference/src/protocol.rs` to allow for OpenAI-API-style `/v1/chat/completions` messages, plus some tests.
- Changes to `python/python/psyche/vllm/rust_bridge.py` to use OpenAI-API-style `/v1/chat/completions` messages. These changes are reflected in `shared/inference/src/node.rs`, `shared/inference/src/vllm.rs`, and `shared/inference/src/protocol.rs`, with some testing.
- `architectures/inference-only/inference-node/src/main.rs` can now read bootstrap peers from a given file and rebroadcast availability over gossip every 30 seconds.
- Changes to `shared/network/src/lib.rs` and `shared/network/src/router.rs` to use an internal `init_internal` method, plus a method called `init_with_custom_protocol` to use a custom protocol on initialization.
- `axum` and `tower` added to dependencies in `Cargo.toml` and `architectures/inference-only/inference-node/Cargo.toml`.
- `scripts/test-inference-e2e.sh` to test the end-to-end inference flow.
- `just` commands added to the `justfile`: `inference-node`, `gateway-node`, `inference-stack`, `test-inference`, `test-inference-e2e`.

Testing (requires a venv with vLLM installed as of now):
The above commands will start up 1 gateway node and 1 inference node, let the gateway node write its endpoint ID to a temp file where the inference node can read it and bootstrap from it, and then spin up an endpoint at `localhost:8000/v1/chat/completions` to receive requests to be forwarded to the inference node.

As always, any questions, comments, or concerns with how this is set up are welcome 😄 - streaming, checkpoint updating, and load balancing are all on the future roadmap for this effort, as well as discussion on how to correctly bootstrap from our gateway nodes.
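For anyone trying this out, the gateway can be exercised roughly like this (a sketch using `reqwest`; the body fields follow the OpenAI-style chat-completions shape described above and may not match the final schema exactly):

```rust
use serde_json::json;

// Smoke-test sketch against the gateway's HTTP endpoint started by the
// commands above. Field names in the body are assumptions.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let body = json!({
        "model": "NousResearch/Hermes-4-14B",
        "messages": [
            { "role": "user", "content": "Say hello in one sentence." }
        ],
        "max_tokens": 64
    });

    let response_text = reqwest::Client::new()
        .post("http://localhost:8000/v1/chat/completions")
        .json(&body)
        .send()
        .await?
        .text()
        .await?;

    println!("{response_text}");
    Ok(())
}
```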
localhost:8000/v1/chat/completionsto receive requests to be forwarded to the inference node.As always, any questions, comments, or concerns with how this is set up are welcome 😄 - streaming, checkpoint updating, and load balancing are all on the future roadmap for this effort, as well as discussion on how to correctly bootstrap from our gateway nodes.