Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion fastchat/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,16 @@ async def worker_api_get_status(request: Request):


@app.get("/test_connection")
async def worker_api_get_status(request: Request):
async def test_connection(request: Request):
return "success"


@app.get("/health")
async def health_check():
"""Health check endpoint for load balancers and orchestration systems."""
return {"status": "ok"}


def create_controller():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
Expand Down
3 changes: 2 additions & 1 deletion fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def get_model_list(controller_url, register_api_endpoint_file, vision_arena):

# Add models from the API providers
if register_api_endpoint_file:
api_endpoint_info = json.load(open(register_api_endpoint_file))
with open(register_api_endpoint_file) as f:
api_endpoint_info = json.load(f)
for mdl, mdl_dict in api_endpoint_info.items():
mdl_vision = mdl_dict.get("vision-arena", False)
mdl_text = mdl_dict.get("text-arena", True)
Expand Down
2 changes: 1 addition & 1 deletion fastchat/serve/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
debug=debug,
)
self.device = device
if self.tokenizer.pad_token == None:
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.context_len = get_context_length(self.model.config)
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
Expand Down
6 changes: 5 additions & 1 deletion fastchat/serve/monitor/clean_battle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,11 @@ def clean_battle_data(
args = parser.parse_args()

log_files = get_log_files(args.max_num_files)
ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None
if args.ban_ip_file:
with open(args.ban_ip_file) as f:
ban_ip_list = json.load(f)
else:
ban_ip_list = None

battles = clean_battle_data(
log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip
Expand Down
9 changes: 6 additions & 3 deletions fastchat/serve/monitor/intersect_conv_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
parser.add_argument("--out-file", type=str, default="intersect.json")
args = parser.parse_args()

conv_id_objs = json.load(open(args.conv_id, "r"))
with open(args.conv_id, "r") as f:
conv_id_objs = json.load(f)
conv_ids = set(x["conversation_id"] for x in conv_id_objs)

objs = json.load(open(args.input, "r"))
with open(args.input, "r") as f:
objs = json.load(f)
after_objs = [x for x in objs if x["conversation_id"] in conv_ids]

print(f"#in: {len(objs)}, #out: {len(after_objs)}")
json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False)
with open(args.out_file, "w") as f:
json.dump(after_objs, f, indent=2, ensure_ascii=False)
6 changes: 5 additions & 1 deletion fastchat/serve/monitor/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def update_elo_components(

# Leaderboard
if elo_results_file is None: # Do live update
ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None
if ban_ip_file:
with open(ban_ip_file) as f:
ban_ip_list = json.load(f)
else:
ban_ip_list = None
battles = clean_battle_data(
log_files, exclude_model_names, ban_ip_list=ban_ip_list
)
Expand Down
3 changes: 2 additions & 1 deletion fastchat/serve/monitor/tag_openai_moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def tag_openai_moderation(x):
parser.add_argument("--first-n", type=int)
args = parser.parse_args()

battles = json.load(open(args.input))
with open(args.input) as f:
battles = json.load(f)

if args.first_n:
battles = battles[: args.first_n]
Expand Down
3 changes: 2 additions & 1 deletion fastchat/serve/monitor/topic_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def read_texts(input_file, min_length, max_length, english_only):
visited = set()
texts = []

lines = json.load(open(input_file, "r"))
with open(input_file, "r") as f:
lines = json.load(f)

for l in tqdm(lines):
if "text" in l:
Expand Down
8 changes: 7 additions & 1 deletion fastchat/serve/openai_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ async def get_gen_params(
if msg_role == "system":
conv.set_system_message(message["content"])
elif msg_role == "user":
if type(message["content"]) == list:
if isinstance(message["content"], list):
image_list = [
item["image_url"]["url"]
for item in message["content"]
Expand Down Expand Up @@ -394,6 +394,12 @@ async def get_conv(model_name: str, worker_addr: str):
return conv_template


@app.get("/health")
async def health_check():
"""Health check endpoint for load balancers and orchestration systems."""
return {"status": "ok"}


@app.get("/v1/models", dependencies=[Depends(check_api_key)])
async def show_available_models():
controller_address = app_settings.controller_address
Expand Down