diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py
index 42d928403..47477529c 100644
--- a/fastchat/serve/controller.py
+++ b/fastchat/serve/controller.py
@@ -346,10 +346,16 @@ async def worker_api_get_status(request: Request):
 
 
 @app.get("/test_connection")
-async def worker_api_get_status(request: Request):
+async def test_connection(request: Request):
     return "success"
 
 
+@app.get("/health")
+async def health_check():
+    """Health check endpoint for load balancers and orchestration systems."""
+    return {"status": "ok"}
+
+
 def create_controller():
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py
index 8941c6ecb..35a21549b 100644
--- a/fastchat/serve/gradio_web_server.py
+++ b/fastchat/serve/gradio_web_server.py
@@ -224,7 +224,8 @@ def get_model_list(controller_url, register_api_endpoint_file, vision_arena):
 
     # Add models from the API providers
     if register_api_endpoint_file:
-        api_endpoint_info = json.load(open(register_api_endpoint_file))
+        with open(register_api_endpoint_file) as f:
+            api_endpoint_info = json.load(f)
         for mdl, mdl_dict in api_endpoint_info.items():
             mdl_vision = mdl_dict.get("vision-arena", False)
             mdl_text = mdl_dict.get("text-arena", True)
diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py
index 683a78556..8e04f50a5 100644
--- a/fastchat/serve/model_worker.py
+++ b/fastchat/serve/model_worker.py
@@ -90,7 +90,7 @@ def __init__(
             debug=debug,
         )
         self.device = device
-        if self.tokenizer.pad_token == None:
+        if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
         self.context_len = get_context_length(self.model.config)
         self.generate_stream_func = get_generate_stream_function(self.model, model_path)
diff --git a/fastchat/serve/monitor/clean_battle_data.py b/fastchat/serve/monitor/clean_battle_data.py
index 270f981cc..3aa3fd88e 100644
--- a/fastchat/serve/monitor/clean_battle_data.py
+++ b/fastchat/serve/monitor/clean_battle_data.py
@@ -385,7 +385,11 @@ def clean_battle_data(
     args = parser.parse_args()
 
     log_files = get_log_files(args.max_num_files)
-    ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None
+    if args.ban_ip_file:
+        with open(args.ban_ip_file) as f:
+            ban_ip_list = json.load(f)
+    else:
+        ban_ip_list = None
 
     battles = clean_battle_data(
         log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip
diff --git a/fastchat/serve/monitor/intersect_conv_file.py b/fastchat/serve/monitor/intersect_conv_file.py
index 9eadd7cd5..ee2d68d2c 100644
--- a/fastchat/serve/monitor/intersect_conv_file.py
+++ b/fastchat/serve/monitor/intersect_conv_file.py
@@ -15,11 +15,14 @@
     parser.add_argument("--out-file", type=str, default="intersect.json")
     args = parser.parse_args()
 
-    conv_id_objs = json.load(open(args.conv_id, "r"))
+    with open(args.conv_id, "r") as f:
+        conv_id_objs = json.load(f)
     conv_ids = set(x["conversation_id"] for x in conv_id_objs)
 
-    objs = json.load(open(args.input, "r"))
+    with open(args.input, "r") as f:
+        objs = json.load(f)
     after_objs = [x for x in objs if x["conversation_id"] in conv_ids]
 
     print(f"#in: {len(objs)}, #out: {len(after_objs)}")
-    json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False)
+    with open(args.out_file, "w") as f:
+        json.dump(after_objs, f, indent=2, ensure_ascii=False)
diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py
index 462e38187..04727c8e8 100644
--- a/fastchat/serve/monitor/monitor.py
+++ b/fastchat/serve/monitor/monitor.py
@@ -86,7 +86,11 @@ def update_elo_components(
 
     # Leaderboard
     if elo_results_file is None:  # Do live update
-        ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None
+        if ban_ip_file:
+            with open(ban_ip_file) as f:
+                ban_ip_list = json.load(f)
+        else:
+            ban_ip_list = None
         battles = clean_battle_data(
             log_files, exclude_model_names, ban_ip_list=ban_ip_list
         )
diff --git a/fastchat/serve/monitor/tag_openai_moderation.py b/fastchat/serve/monitor/tag_openai_moderation.py
index b80703388..77bb2f831 100644
--- a/fastchat/serve/monitor/tag_openai_moderation.py
+++ b/fastchat/serve/monitor/tag_openai_moderation.py
@@ -46,7 +46,8 @@ def tag_openai_moderation(x):
     parser.add_argument("--first-n", type=int)
     args = parser.parse_args()
 
-    battles = json.load(open(args.input))
+    with open(args.input) as f:
+        battles = json.load(f)
 
     if args.first_n:
         battles = battles[: args.first_n]
diff --git a/fastchat/serve/monitor/topic_clustering.py b/fastchat/serve/monitor/topic_clustering.py
index 3d58e56bf..1c45b5b2c 100644
--- a/fastchat/serve/monitor/topic_clustering.py
+++ b/fastchat/serve/monitor/topic_clustering.py
@@ -34,7 +34,8 @@ def read_texts(input_file, min_length, max_length, english_only):
     visited = set()
     texts = []
 
-    lines = json.load(open(input_file, "r"))
+    with open(input_file, "r") as f:
+        lines = json.load(f)
 
     for l in tqdm(lines):
         if "text" in l:
diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py
index a6ffee96b..4ce49147e 100644
--- a/fastchat/serve/openai_api_server.py
+++ b/fastchat/serve/openai_api_server.py
@@ -304,7 +304,7 @@ async def get_gen_params(
             if msg_role == "system":
                 conv.set_system_message(message["content"])
             elif msg_role == "user":
-                if type(message["content"]) == list:
+                if isinstance(message["content"], list):
                     image_list = [
                         item["image_url"]["url"]
                         for item in message["content"]
@@ -394,6 +394,12 @@ async def get_conv(model_name: str, worker_addr: str):
     return conv_template
 
 
+@app.get("/health")
+async def health_check():
+    """Health check endpoint for load balancers and orchestration systems."""
+    return {"status": "ok"}
+
+
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 async def show_available_models():
     controller_address = app_settings.controller_address
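
A minimal smoke-test sketch (not part of this diff) for the /health endpoints added above, assuming the controller's default port 21001 and the OpenAI-compatible API server's default port 8000; adjust the URLs if the servers run elsewhere.

# Hypothetical check of the new /health endpoints; uses the requests library.
import requests

for url in (
    "http://localhost:21001/health",  # controller (assumed default port)
    "http://localhost:8000/health",   # openai_api_server (assumed default port)
):
    resp = requests.get(url, timeout=5)
    print(url, resp.status_code, resp.json())  # expected: 200 {'status': 'ok'}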