Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion fastchat/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,16 @@ async def worker_api_get_status(request: Request):


@app.get("/test_connection")
async def test_connection(request: Request):
    """Connectivity probe: answer any GET on /test_connection with "success".

    The handler is named after its route so that it does not rebind
    ``worker_api_get_status``, the name of the handler that precedes it.

    Args:
        request: The incoming request; accepted but not inspected.

    Returns:
        The literal string "success".
    """
    return "success"


@app.get("/health")
async def health_check():
    """Health check endpoint for load balancers and orchestration systems.

    Returns:
        A constant ``{"status": "ok"}`` payload. Nothing is probed, so a
        response only shows that this server process is answering requests.
    """
    return {"status": "ok"}


def create_controller():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
Expand Down
2 changes: 1 addition & 1 deletion fastchat/serve/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
debug=debug,
)
self.device = device
if self.tokenizer.pad_token == None:
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.context_len = get_context_length(self.model.config)
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
Expand Down
8 changes: 7 additions & 1 deletion fastchat/serve/openai_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ async def get_gen_params(
if msg_role == "system":
conv.set_system_message(message["content"])
elif msg_role == "user":
if type(message["content"]) == list:
if isinstance(message["content"], list):
image_list = [
item["image_url"]["url"]
for item in message["content"]
Expand Down Expand Up @@ -394,6 +394,12 @@ async def get_conv(model_name: str, worker_addr: str):
return conv_template


@app.get("/health")
async def health_check():
    """Health check endpoint for load balancers and orchestration systems.

    Registered without the API-key dependency that the ``/v1/models`` route
    declares, so it can be polled without credentials.

    Returns:
        A constant ``{"status": "ok"}`` payload. Nothing is probed, so a
        response only shows that this server process is answering requests.
    """
    return {"status": "ok"}


@app.get("/v1/models", dependencies=[Depends(check_api_key)])
async def show_available_models():
controller_address = app_settings.controller_address
Expand Down