1818import functools
1919from copy import deepcopy
2020from pathlib import Path
21- from typing import Callable , Sequence , Dict , List , Any , Tuple
21+ from typing import Callable , Sequence , Dict , List , Any , Tuple , Optional
2222from jsonschema import validate
2323from yaml import load as yaml_load
2424
@@ -43,7 +43,8 @@ class EntryPoint(EntryPointABC):
4343 parse_env = True
4444
4545 argparse_check_required = False
46- argparse_noflag = None
46+ argparse_noflag : Optional [str ] = None
47+ _config_file_parser_map : Dict [str , Callable [[Path ], Dict [str , Any ]]] = {}
4748
4849 def _check_schema (self ) -> None :
4950 if self .schema is not None :
@@ -248,13 +249,26 @@ def parse_yaml_configfile_args(self, p: Path) -> Dict[str, Any]:
248249 return res
249250 return result
250251
252+ def regist_config_file_parser (self , file_name : str ) -> Callable [[Callable [[Path ], Dict [str , Any ]]], Callable [[Path ], Dict [str , Any ]]]:
253+ def decorate (func : Callable [[Path ], Dict [str , Any ]]) -> Callable [[Path ], Dict [str , Any ]]:
254+ @functools .wraps (func )
255+ def wrap (p : Path ) -> Dict [str , Any ]:
256+ return func (p )
257+ self ._config_file_parser_map [file_name ] = func
258+ return wrap
259+ return decorate
260+
251261 def parse_configfile_args (self ) -> Dict [str , Any ]:
252262 if not self .default_config_file_paths :
253263 return {}
254264 if not self .load_all_config_file :
255265 for p_str in self .default_config_file_paths :
256266 p = Path (p_str )
257267 if p .is_file ():
268+ parfunc = self ._config_file_parser_map .get (p .name )
269+ if parfunc :
270+ print ("&&&&&&" )
271+ return parfunc (p )
258272 if p .suffix == ".json" :
259273 return self .parse_json_configfile_args (p )
260274 elif p .suffix == ".yml" :
@@ -269,12 +283,17 @@ def parse_configfile_args(self) -> Dict[str, Any]:
269283 for p_str in self .default_config_file_paths :
270284 p = Path (p_str )
271285 if p .is_file ():
272- if p . suffix == ".json" :
273- result . update ( self . parse_json_configfile_args ( p ))
274- elif p . suffix == ".yml" :
275- result .update (self . parse_yaml_configfile_args (p ))
286+ parfunc = self . _config_file_parser_map . get ( p . name )
287+ if parfunc :
288+ print ( "&&&&&&@@@" )
289+ result .update (parfunc (p ))
276290 else :
277- warnings .warn (f"跳过不支持的配置格式的文件{ str (p )} " )
291+ if p .suffix == ".json" :
292+ result .update (self .parse_json_configfile_args (p ))
293+ elif p .suffix == ".yml" :
294+ result .update (self .parse_yaml_configfile_args (p ))
295+ else :
296+ warnings .warn (f"跳过不支持的配置格式的文件{ str (p )} " )
278297 return result
279298
280299 def validat_config (self ) -> bool :
0 commit comments