diff --git a/README.md b/README.md index 4315f75..154cab3 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,11 @@ https://www.surveymonkey.de/r/N289B82 * Checks if a file contains a Bitcoin Address using YARA rules. * Correlate results with Virus Total, Threat Crowd, and Hybrid Analysis. +## Upgrades: +* Train 3 model, and select the best model. (based on Accuracy and F1-score) +* When you train the model, this program choose major features in 15 features. It makes AI model more efficient. +* Chrome extension based on MLRD. (expected) + ## Install: ``` git clone https://github.com/callumlock/MLRD-Machine-Learning-Ransomware-Detection.git @@ -59,4 +64,4 @@ python3 mlrd.py -h ## WARNING: * Malicious programs are contained within this directory and thus, should be handled with care. * The ransomware distributed inside the test virual machine have been gathered for research purposes and are only for use within the scope of this project. -* Using these programs with malicious intent is strictly prohibited. +* Using these programs with malicious intent is strictly prohibited. diff --git a/classifier/best_model.pkl b/classifier/best_model.pkl new file mode 100644 index 0000000..57659db Binary files /dev/null and b/classifier/best_model.pkl differ diff --git a/mlrd.py b/mlrd.py index 8d3f2dd..f645c11 100644 --- a/mlrd.py +++ b/mlrd.py @@ -1,459 +1,112 @@ -''' - File name: mlrd.py - Author: Callum Lock - Date created: 31/03/2018 - Date last modified: 31/03/2018 - Python Version: 3.6 -''' - import os -import sys -import argparse -import array -import math -import pickle import pefile import hashlib -import yara -import pandas as pd -import numpy as np -from sklearn.externals import joblib -import urllib -import urllib3 -import json -import requests -from requests.auth import HTTPBasicAuth -from termcolor import colored, cprint -import colorama -import base64 -import webbrowser - +import string -# Class to extract features from input file. class ExtractFeatures(): - - # Defining init method taking parameter file. def __init__(self, file): - self.file = file + self.file = os.path.abspath(file) # 절대 경로 변환 + if not os.path.exists(self.file): + raise FileNotFoundError(f"❌ 파일 '{self.file}'을 찾을 수 없습니다.") - # Method for extracting the MD5 hash of a file. - # It is not always possible to fit the entire file into memory so chunks of - # 4096 bytes are read and sequentially fed into the function. - def get_md5(self, file): + def get_md5(self): + """ 파일의 MD5 해시를 계산 """ md5 = hashlib.md5() - with open(file, "rb") as f: + with open(self.file, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): md5.update(chunk) - return md5.hexdigest() - - # Method for compiling the yara rule for searching files for - # signs of bitcoin addresses. - def compile_bitcoin(self): - if not os.path.isdir("rules_compiled/Bitcoin"): - os.makedirs("rules_compiled/Bitcoin") - print("success") + return md5.hexdigest() - for n in os.listdir("rules/Bitcoin"): - rule = yara.compile("rules/Bitcoin/" + n) - rule.save("rules_compiled/Bitcoin/" + n) - - # Method for checking the input file for any signs of embedded bitcoin - # addresses. If the file does contain a bitcoin address a 1 is returned. - # Otherwise a 0 is returned. - def check_bitcoin(self, file): - for n in os.listdir("rules/Bitcoin"): - rule = yara.load("rules_compiled/Bitcoin/" + n) - m = rule.match(file) - if m: - return 1 - else: - return 0 - - # Method for extracting all features from an input file. - def get_fileinfo(self, file): - # Creates a dictionary that will hold feature names as keys and - # their feature values as values. - features = {} + def get_fileinfo(self): + """ PE 파일의 특징을 추출 """ + try: + pe = pefile.PE(self.file, fast_load=True) + except pefile.PEFormatError: + print(f"❌ Error: '{self.file}'은(는) 유효한 PE 파일이 아닙니다.") + return None - # Assigns pe to the input file. fast_load loads all directory - # information. - pe = pefile.PE(file, fast_load=True) + features = {} - # CPU that the file is intended for. + # PE 헤더 기반 정보 features['Machine'] = pe.FILE_HEADER.Machine - - # DebugSize is the size of the debug directory table. Clean files - # typically have a debug directory and thus, will have a non-zero - # values. - features['DebugSize'] = pe.OPTIONAL_HEADER.DATA_DIRECTORY[6].Size - - # Debug Relative Virtual Address (RVA). - features['DebugRVA'] = pe.OPTIONAL_HEADER.DATA_DIRECTORY[6].\ - VirtualAddress - - # MajorImageVersion is the version of the file. This is user defined - # and for clean programs is often populated. Malware often has a - # value of 0 for this. - features['MajorImageVersion'] = pe.OPTIONAL_HEADER.MajorImageVersion - - # MajorOSVersion is the major operating system required to run exe. - features['MajorOSVersion'] = pe.OPTIONAL_HEADER.\ - MajorOperatingSystemVersion - - # Export Relative Virtual Address (VRA). - features['ExportRVA'] = pe.OPTIONAL_HEADER.DATA_DIRECTORY[0].\ - VirtualAddress - - # ExportSize is the size of the export table. Usually non-zero for - # clean files. - features['ExportSize'] = pe.OPTIONAL_HEADER.DATA_DIRECTORY[0].Size - - # IatRVA is the relative virtual address of import address - # table. Clean files typically have 4096 for this where as malware - # often has 0 or a very large number. - features['IatVRA'] = pe.OPTIONAL_HEADER.DATA_DIRECTORY[12].\ - VirtualAddress - - # ResourcesSize is the size of resources section of PE header. - # Malware sometimes has 0 resources. - features['MajorLinkerVersion'] = pe.OPTIONAL_HEADER.\ - MajorLinkerVersion - - # MinorLinkerVersion is the minor version linker that produced the - # file. - features['MinorLinkerVersion'] = pe.OPTIONAL_HEADER.MinorLinkerVersion - - # NumberOfSections is the number of sections in file. - features['NumberOfSections'] = pe.FILE_HEADER.NumberOfSections - - # SizeOfStackReserve denotes the amount of virtual memory to reserve - # for the initial thread's stack. + features['NumberOfSections'] = pe.FILE_HEADER.NumberOfSections features['SizeOfStackReserve'] = pe.OPTIONAL_HEADER.SizeOfStackReserve - - # DllCharacteristics is a set of flags indicating under which - # circumstances a DLL's initialization function will be called. features['DllCharacteristics'] = pe.OPTIONAL_HEADER.DllCharacteristics - - # ResourceSize denotes the size of the resources section. - # Malware may often have no resources but clean files will. features['ResourceSize'] = pe.OPTIONAL_HEADER.DATA_DIRECTORY[2].Size - # Creates an object of Extract features and passes in the input - # file. The object get_bitcoin accesses the check_bitcoin method - # for which a 1 or 0 is returned and added as a value in the - # dictionary. - get_bitcoin = ExtractFeatures(file) - bitcoin_check = get_bitcoin.check_bitcoin(file) - features['BitcoinAddresses'] = bitcoin_check + # API 호출 분석 (IAT - Import Address Table) + api_calls = [] + if hasattr(pe, 'DIRECTORY_ENTRY_IMPORT'): + for entry in pe.DIRECTORY_ENTRY_IMPORT: + for imp in entry.imports: + if imp.name: + api_name = imp.name.decode(errors='ignore') + api_calls.append(api_name) - # Returns features for the given input file. - return features - -# Class to search third party reputation checkers and malware analysis -# websites to cross check if the tool is making correct decisions. -class RepChecker(): + features['API_Calls'] = len(api_calls) - # Init method to initalise api keys and base urls. - def __init__(self): - # Virus Total api key - vtapi = base64.b64decode('M2FlNzgwMDU5MTE3ZThkYzdmNjA5YjVlOWU1Y2JmOTRkMGJkNTA3NTAyNzI3NWJiOTM3YTg0NGEwYTYzNDNlYQ==') - self.vtapi = vtapi.decode('utf-8') - # Virus Total base URL - self.vtbase = 'https://www.virustotal.com/vtapi/v2/file/report' - self.http = urllib3.PoolManager() - # Threat Crowd base URL. - self.tcbase = 'http://www.threatcrowd.org/searchApi/v2/file/report/?resource=' - # Hybrid Analysis api key. - hapi = base64.b64decode('OGtzMDhrc3NrOGNja3Nnd3dnY2NnZzRzOG8wczA0Y2tzODA4c2NjYzAwZ2s0a2trZzRnc2s4Zzg0OGc4b2NvNA==') - self.hapi = hapi.decode('utf-8') - # Hybrid Analysis secret key. - hsecret = base64.b64decode('MTFhYjc1OTMxZGYzOWFjMmVjYmI3ZGNhNmI1MzYxMmE3YmU4ZjM3MTM5YTAwY2Nm') - self.hsecret = hsecret.decode('utf-8') - # Hybrid Analysis base URL. - self.hbase = 'https://www.hybrid-analysis.com/api/scan/' - - # Method for authenticating to Virus Total API and file information - # in JSON. - def get_virus_total(self, md5): - params = {'apikey': self.vtapi, 'resource':md5} - data = urllib.parse.urlencode(params).encode("utf-8") - r = requests.get(self.vtbase, params=params) - return r.json() + # 섹션별 엔트로피 분석 + for section in pe.sections: + name = section.Name.decode(errors='ignore').strip('\x00') + features[f"{name}_Size"] = section.SizeOfRawData + features[f"{name}_Entropy"] = section.get_entropy() - # Method for returning file information in JSON from - # Threat Crowd. - def get_threatcrowd(self, md5): - r = requests.get(self.tcbase) - return r.json() + # 문자열 분석 (랜섬웨어 관련 키워드 탐지) + features['Suspicious_Strings'] = self.check_malware_strings() - # Method for authenticating to Hybrid Analysis API and - # returning file information in JSON. - def get_hybrid(self, md5): - headers = {'User-Agent': 'Falcon'} - query = self.hbase + md5 - r = requests.get(query, headers=headers, auth=HTTPBasicAuth(self.hapi, self.hsecret)) - return r.json() - -# Open up survey to evaluate program. -def survey_mail(): - print('\n[*] Opening up survey in browser.\n') - webbrowser.open('https://www.surveymonkey.de/r/N289B82', new=2) - -# Function to parse user input. Takes in input file, extracted features, -# and parsed options. -def parse(file, features, display, virustotal, threatcrowd, hybridanalysis): - # Creates an object of RepChecker to return third party about - # input file. - get_data = RepChecker() - # Creates an object of ExtractFeatures to return information about the - # input file. - md5 = ExtractFeatures(file) - md5_hash = md5.get_md5(file) - - # If display option is selected, the extracted features are printed - # to the screen. - if display: - print("[*] Printing extracted file features...") - print("\n\tMD5: ", md5_hash) - print("\tDebug Size: ", features[0]) - print("\tDebug RVA: ", features[1]) - print("\tMajor Image Version:", features[2]) - print("\tMajor OS Version:", features[3]) - print("\tExport RVA:", features[4]) - print("\tExport Size:", features[5]) - print("\tIat RVA: ", features[6]) - print("\tMajor Linker Version: ", features[7]) - print("\tMinor Linker Version", features[8]) - print("\tNumber Of Sections: ", features[9]) - print("\tSize Of Stack Reserve: ", features[10]) - print("\tDll Characteristics: ", features[11]) - if features[12] == 1: - print("\tBitcoin Addresses: Yes\n") - else: - print("\tBitcoin Addresses: No\n") - - # If Virus Total option is selected, file information from Virus - # total is returned. - if virustotal: - print("[+] Running Virus Total reputation check...\n") - # Retrieves data from virus total. Searches by passing in - # md5 hash of input file. - data = get_data.get_virus_total(md5_hash) - - # If the response code is 0, error message is returned indicating - # that the md5 hash is not in virus total. Otherwise, the number - # of AV companies that detected the file as malicious is returned - # If 0, output is in green. - # Between 0 and 25, output is yellow. - # Over 25, output is red. - if data['response_code'] == 0: - print("[-] The file %s with MD5 hash %s was not found in Virus Total" % (os.path.basename(file), md5_hash)) - else: - print("\tResults for file %s with MD5 %s:" % (os.path.basename(file), md5_hash)) - if data['positives'] == 0: - print("\n\tDetected by: ", colored(str(data['positives']), 'green'), '/', data['total'], '\n') - elif data['positives'] > 0 and data['positives'] <= 25: - print("\n\tDetected by: ", colored(str(data['positives']), 'yellow'), '/', data['total'], '\n') - else: - print("\n\tDetected by: ", colored(str(data['positives']), 'red'), '/', data['total'], '\n') - - # Creates two lists to store the AV companies who detected the file - # as malicious and to store corresponding malware names. - av_firms = [] - malware_names = [] - fmt = '%-4s%-23s%s' - - # If any AV company indicated that the file is malicious, it is - # printed to the screen. - if data['positives'] > 0: - for scan in data['scans']: - if data['scans'][scan]['detected'] == True: - av_firms.append(scan) - malware_names.append(data['scans'][scan]['result']) - - print('\t', fmt % ('', 'AV Firm', 'Malware Name')) - for i, (l1, l2) in enumerate(zip(av_firms, malware_names)): - print('\t', fmt % (i, l1, l2)) - if data['permalink']: - print("\n\tVirus Total Report: ", data['permalink'], '\n') - - # Prints if Virus Total has found the file to be malicious. - if data['positives'] == 0: - print(colored('[*] ', 'green') + "Virus Total has found the file %s " % os.path.basename(file) + colored("not malicious.", 'green')) - if data['permalink']: - print("\n\tVirus Total Report: ", data['permalink'], '\n') - elif data['positives'] > 0 and data['positives'] <= 25: - print(colored('[*] ', 'red') + "Virus Total has found the file %s " % os.path.basename(file) + colored("has malicious properties.\n", 'yellow')) - else: - print(colored('[*] ', 'red') + "Virus Total has found the file %s " % os.path.basename(file) + colored("is malicious.\n", 'red')) - - # If threat crowd option is selected, file information is returned. - if threatcrowd: - fmt = '%-4s%-23s' - print("[+] Retrieving information from Threat Crowd...\n") - data = get_data.get_threatcrowd(md5_hash) - - # If response code is 0, an error message is thrown to indicate - # the file is not in Threat Crowd. Otherwise, the SHA1 Hash, - # domain names, and malware names given by AV companies for - # the file is printed to the screen. - if data['response_code'] == "0": - print("[-] The file %s with MD5 hash %s was not found in Threat Crowd.\n" % (os.path.basename(file), md5_hash)) - else: - print("\n\tSHA1: ", data['sha1']) - if data['ips']: - print('\n\t', fmt % ('', 'IPs')) - for i, ip in enumerate((data['ips'])): - print('\t', fmt % (i+1, ip)) - - if data['domains']: - print('\n\t', fmt % ('', 'Domains')) - for i, domain in enumerate((data['domains'])): - print('\t', fmt % (i+1, domain)) - - if data['scans']: - if data['scans'][1:]: - print('\n\t', fmt % ('', 'Antivirus')) - for i, scan in enumerate(data['scans'][1:]): - print('\t', fmt % (i+1, scan)) - - print('\n\tThreat Crowd Report: ', data['permalink'], '\n') - - # If hybrid analysis option is selected, file information is returned. - if hybridanalysis: - # Searches hybrid analysis with md5 hash of file and attempts - # to return its information in JSON format. - data = get_data.get_hybrid(md5_hash) - fmt = '%-4s%-23s' - - print("[+] Retrieving information from Hybrid Analysis...\n") - - # If no response, error message is thrown to indicate that the file - # is not in Hybrid Analysis. Otherwise, SHA256, SHA1, Threat Level, - # Threat Score, Verdict (malicious / not malicious), malware family, - # and network information is returned - if not data['response']: - print("[-] The file %s with MD5 hash %s was not found in Hybrid Analysis." % (os.path.basename(file), md5_hash), '\n') - else: - try: - print('\t', data['response'][0]['submitname']) - except: - pass - - print('\tSHA256:', data['response'][0]['sha256']) - print('\tSHA1: ', data['response'][0]['sha1']) - print('\tThreat Level: ', data['response'][0]['threatlevel']) - print('\tThreat Score: ', data['response'][0]['threatscore']) - print('\tVerdict: ', data['response'][0]['verdict']) - - try: - print('\tFamily: ', data['response'][0]['vxfamily']) - except: - pass - try: - if data['response'][0]['classification_tags']: - print('\n\t', fmt % ('', 'Class Tags')) - for i, tag in enumerate(data['response'][0]['classification_tags']): - print('\t', fmt % (i+1, tag)) - else: - print("\tClass Tags: No Classification Tags.") - except: - pass - try: - if data['response'][0]['compromised_hosts']: - print('\n\t', fmt % ('', 'Compromised Hosts')) - for i, host in enumerate(data['response'][0]['compromised_hosts']): - print('\t', fmt % (i+1, host)) - else: - print('\t\nCompromised Hosts: No Compromised Hosts.') - except: - pass - try: - if data['response'][0]['domains']: - print('\n\t', fmt % ('', 'Domains')) - for i, domain in enumerate(data['response'][0]['domains']): - print('\t', fmt % (i+1, domain)) - else: - print('\tDomains: No Domains.') - except: - pass - try: - if data['response'][0]['total_network_connections']: - print('\tNetwork Connections: ', data['response'][0]['total_network_connections']) - else: - print('\n\tNetwork Connections: No Network Connections') - except: - pass - try: - if data['response'][0]['families']: - print('\tFamilies: ', data['response'][0]['families']) - except: - pass + return features - # Verdict is printed to screen. - # Malicious = red. - # Benign = green. - if data['response'][0]['verdict'] == "malicious": - print(colored('\n[*] ', 'red') + "Hybrid Analysis has found that the file %s " % os.path.basename(file) + colored("is malicious.\n", 'red')) + def check_malware_strings(self): + """ 파일 내부의 문자열에서 랜섬웨어 관련 키워드 탐지 """ + suspicious_keywords = ["ransom", "decrypt", "bitcoin", "AES", "locker", "hacker"] + strings = self.extract_strings() + + # 🔹 keyword 변수를 사용하지 않고 올바르게 수정 + return any(k in s.lower() for s in strings for k in suspicious_keywords) + + def extract_strings(self, min_length=4): + """ PE 파일에서 ASCII 문자열 추출 """ + with open(self.file, "rb") as f: + data = f.read() + + result = [] + printable_chars = set(bytes(string.printable, 'ascii')) + temp = [] + for byte in data: + if byte in printable_chars: + temp.append(chr(byte)) else: - print(colored('\n[*] ', 'green') + "Hybrid Analysis has found that the file %s " % os.path.basename(file) + colored("is not malicious.\n", 'green')) + if len(temp) >= min_length: + result.append("".join(temp)) + temp = [] + return result -def main(): - parser = argparse.ArgumentParser(epilog="MLRD uses machine learning to detect ransomware\n\ - . Supply a file to determine whether or not it is ransomware. Virus Total\ - , Threat Crowd and Hybrid Analysis can be queried for verification.", - description="Machine Learning Ransowmare Detector (MLRD)") +def analyze_pe_file(file_path): + """ PE 파일을 분석하여 랜섬웨어 가능성을 판별 """ + file_path = os.path.abspath(file_path) # 절대 경로 변환 - parser.add_argument('file', nargs='?', help="File To Parse", ) - parser.add_argument('-d', '--displayfeatures', action='store_true', dest='display', help='Display extracted file features.') - parser.add_argument('-v', "--virustotal", action='store_true', dest='virustotal', help="Run with Virus Total check.") - parser.add_argument('-t', '--threatcrowd', action='store_true', dest='threatcrowd', help="Run with Threat Crowd check.") - parser.add_argument('-z', '--hybridanalysis', action='store_true', dest='hybridanalysis', help="Run Hybrid Analysis check.") - parser.add_argument('-s', '--survey', nargs='*', help='Evaluate Program using Survey.') + if not os.path.exists(file_path): + print(f"❌ Error: 파일 '{file_path}'이 존재하지 않습니다.") + return - args = parser.parse_args() + print(f"🔍 Analyzing {file_path} ...") - colorama.init() - - if args.survey is not None: - survey_mail() - sys.exit(0) - - # Loads classifier - clf = joblib.load(os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'classifier/classifier.pkl')) - - # Loads saved features - features = pickle.loads(open(os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'classifier/features.pkl'), - 'rb').read()) - - # Creates an object of ExtractFeatures and passes in input file. - get_features = ExtractFeatures(args.file) - - # Assigns data to extracted features - data = get_features.get_fileinfo(args.file) - - feature_list = list(map(lambda x:data[x], features)) - - print("\n[+] Running analyzer...\n") + extractor = ExtractFeatures(file_path) + features = extractor.get_fileinfo() + + if features is None: + return - # Asssings result as the prediction of the input file based on its given features. - result = clf.predict([feature_list])[0] + print("\n=== 📊 PE File Analysis Result ===") + for key, value in features.items(): + print(f" {key}: {value}") - # If result is 1, the file is benign. - # Otherwise, the file is malicious. - if result == 1: - print(colored('[*] ', 'green') + "The file %s has been identified as " % os.path.basename(sys.argv[1]) + colored('benign.\n', 'green')) + if features["Suspicious_Strings"]: + print("\n⚠️ 악성 문자열이 포함되어 있음! 랜섬웨어 가능성이 높음!") else: - print(colored('[*] ', 'red') + "The file %s has been identified as " % os.path.basename(sys.argv[1]) + colored('malicious.\n', 'red')) - - # Passes command line arguments to parse function for parsing. - if args.display or args.virustotal or args.threatcrowd or args.hybridanalysis: - parse(args.file, feature_list, args.display, args.virustotal, args.threatcrowd, args.hybridanalysis) + print("\n✅ 악성 문자열이 발견되지 않음.") -if __name__ == '__main__': - main() +# 실행 예제 +file_path = r"Test Data\Benign Test Data\setup_wm.exe" +analyze_pe_file(file_path) diff --git a/mlrd_learn.py b/mlrd_learn.py index 525395b..5d7bc60 100644 --- a/mlrd_learn.py +++ b/mlrd_learn.py @@ -1,66 +1,73 @@ -''' - File name: mlrd_learn.py - Author: Callum Lock - Date created: 31/03/2018 - Date last modified: 31/03/2018 - Python Version: 3.6 -''' import pandas as pd import numpy as np -import pickle +import xgboost as xgb +import lightgbm as lgb from sklearn import model_selection -import sklearn.ensemble as ske -import sklearn.metrics -from sklearn.metrics import f1_score -from sklearn.externals import joblib +from sklearn.metrics import f1_score, accuracy_score +import joblib +from sklearn.ensemble import RandomForestClassifier +from sklearn.feature_selection import SelectFromModel -# Main code function that trains the random forest algorithm on dataset. def main(): - print('\n[+] Training MLRD using Random Forest Algorithm...') + print('\n[+] Upgrading MLRD Machine Learning Model...') - # Creates pandas dataframe and reads in csv file. - df = pd.read_csv('data_file.csv', sep=',') + # 데이터셋 로드 (예외 처리 추가) + try: + df = pd.read_csv('data_file.csv', sep=',') + except FileNotFoundError: + print("❌ Error: 'data_file.csv' 파일을 찾을 수 없습니다. 데이터셋을 준비하세요.") + return - # Drops FileName, md5Hash and Label from data. - X = df.drop(['FileName', 'md5Hash', 'Benign'], axis=1).values + # 데이터셋 분포 확인 (추가된 코드) + print("\n[+] Dataset Overview:") + print(df['Benign'].value_counts()) - # Assigns y to label + # 특징 선택 및 데이터 분할 + X = df.drop(['FileName', 'md5Hash', 'Benign'], axis=1).values y = df['Benign'].values - - # Splitting data into training and test data X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=42) - # Print the number of training and testing samples. - print("\n\t[*] Training samples: ", len(X_train)) - print("\t[*] Testing samples: ", len(X_test)) + print("\n[*] Training samples:", len(X_train)) + print("[*] Testing samples:", len(X_test)) - # Train Random forest algorithm on training dataset. - clf = ske.RandomForestClassifier(n_estimators=50) - clf.fit(X_train, y_train) + # 모델 1: 랜덤 포레스트 (데이터 불균형 해결 추가) + rf_clf = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced') + rf_clf.fit(X_train, y_train) + rf_score = accuracy_score(y_test, rf_clf.predict(X_test)) + rf_f1 = f1_score(y_test, rf_clf.predict(X_test)) - # Perform cross validation and print out accuracy. - score = model_selection.cross_val_score(clf, X_test, y_test, cv=10) - print("\n\t[*] Cross Validation Score: ", round(score.mean()*100, 2), '%') + # 모델 2: XGBoost + xgb_clf = xgb.XGBClassifier(n_estimators=100, learning_rate=0.1, max_depth=6, random_state=42) + xgb_clf.fit(X_train, y_train) + xgb_score = accuracy_score(y_test, xgb_clf.predict(X_test)) + xgb_f1 = f1_score(y_test, xgb_clf.predict(X_test)) - # Calculate f1 score. - y_train_pred = model_selection.cross_val_predict(clf, X_train, y_train, cv=3) - f = f1_score(y_train, y_train_pred) - print("\t[*] F1 Score: ", round(f*100, 2), '%') + # 모델 3: LightGBM + lgb_clf = lgb.LGBMClassifier(n_estimators=100, learning_rate=0.1, max_depth=6, random_state=42) + lgb_clf.fit(X_train, y_train) + lgb_score = accuracy_score(y_test, lgb_clf.predict(X_test)) + lgb_f1 = f1_score(y_test, lgb_clf.predict(X_test)) - # Save the configuration of the classifier and features as a pickle file. - all_features = X.shape[1] - features = [] + print("\n[*] Model Performance:") + print(f" - RandomForest Accuracy: {rf_score*100:.2f}%, F1-score: {rf_f1:.2f}") + print(f" - XGBoost Accuracy: {xgb_score*100:.2f}%, F1-score: {xgb_f1:.2f}") + print(f" - LightGBM Accuracy: {lgb_score*100:.2f}%, F1-score: {lgb_f1:.2f}") - for feature in range(all_features): - features.append(df.columns[2+feature]) + # 🔹 최적 모델 자동 선택 + best_model = max( + [(rf_f1, rf_clf), (xgb_f1, xgb_clf), (lgb_f1, lgb_clf)], key=lambda x: x[0] + )[1] - try: - print("\n[+] Saving algorithm and feature list in classifier directory...") - joblib.dump(clf, 'classifier/classifier.pkl') - open('classifier/features.pkl', 'wb').write(pickle.dumps(features)) - print("\n[*] Saved.") - except: - print('\n[-] Error: Algorithm and feature list not saved correctly.\n') + print("\n[+] Selecting best model...") + joblib.dump(best_model, 'classifier/best_model.pkl') + print("[*] Model saved successfully.") + + # 🔹 특징 선택 수행 (Feature Selection) + print("\n[+] Performing feature selection...") + feature_selector = SelectFromModel(best_model, threshold="median", prefit=True) + X_selected = feature_selector.transform(X) + + print(f"[*] Reduced features from {X.shape[1]} to {X_selected.shape[1]}.") if __name__ == '__main__': main()