From 844a1355af45c18e930327937e0a2473ef1c4f10 Mon Sep 17 00:00:00 2001 From: Hyunji Moon <30194633+hyunjimoon@users.noreply.github.com> Date: Thu, 6 Jan 2022 20:21:59 -0500 Subject: [PATCH 01/23] commit for pr --- graph_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_search.py b/graph_search.py index 20a26de..3da349a 100644 --- a/graph_search.py +++ b/graph_search.py @@ -4,7 +4,7 @@ import json from queue import PriorityQueue -# DEBUG_IO = True +## DEBUG_IO = True DEBUG_IO = False def text_command(args): From 1ae4889e308cf38c2d338274c2c37054f1c61af7 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Fri, 14 Jan 2022 00:36:50 +0900 Subject: [PATCH 02/23] Find max hierarchy models for each subtree --- lib/ModuleTree.hs | 67 +++++++++++++++++++++++++++++++++++++++++++++-- mstan/CLI.hs | 12 +++++++++ mstan/Main.hs | 15 +++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) diff --git a/lib/ModuleTree.hs b/lib/ModuleTree.hs index 08cf5a4..a2b8499 100644 --- a/lib/ModuleTree.hs +++ b/lib/ModuleTree.hs @@ -5,6 +5,8 @@ module ModuleTree ( growTree , printModularTree + , implSigs + , findHighestModels , moduleTreeGraphviz , ModuleBranch (..) , Node (..) @@ -18,10 +20,10 @@ import qualified Data.Map as Map import Data.Maybe import Data.Set (Set) import qualified Data.Set as Set -import Data.Text (Text) +import Data.Text (Text, intercalate) import qualified Data.Text.IO as Text import Data.Fix (refold) - +import Control.Applicative (liftA2) import Graphviz import Types import Indent @@ -54,6 +56,7 @@ implSigs p = Map.insert Root (moduleSigs (topProgram p)) orderSigs :: Set SigName -> [SigName] orderSigs set = filter (`Set.member` set) . map snd $ signatures p + -- Pattern functor for module tree data ModuleBranch f = SigBranch SigName [f] | ImplBranch Selection [f] deriving (Eq, Ord, Functor) @@ -74,6 +77,66 @@ idToSel :: ImplID -> Selection idToSel Root = Map.empty idToSel (ImplID sig impl) = Map.singleton sig impl + + +getFullImplString :: ImplID -> Text +getFullImplString (ImplID parent name) = (unSigName parent) <> ":" <> (unImplName name) +getFullImplString Root = "" + +-- for implementations: +-- if terminal: return self +-- else: return merge(self, union(left_signature, right_signature)) +hierTraverseNode :: Node -> Map ImplID [SigName] -> Map SigName [ImplName] -> [Text] +hierTraverseNode (Impl impl) impl2SigMap sig2ImplMap = do + let signames = Map.lookup impl impl2SigMap + case signames of + Just signames -> if length signames == 1 + then do + -- If there's only 1 signature underneath the implementation, nothing more needs to be done + map (\x -> Data.Text.intercalate "," ((getFullImplString impl) : (hierTraverseNode (Sig x) impl2SigMap sig2ImplMap))) (signames) + else + ( + if length signames == 2 + then do + -- If there's 2 signatures underneath the implementation, you need to do an equivalent of a graph join + -- Where you take the combinations of each signature's top nodes + let first_sig = (hierTraverseNode (Sig $ head signames) impl2SigMap sig2ImplMap); second_sig = (hierTraverseNode (Sig $ last signames) impl2SigMap sig2ImplMap) in + map (\(x, y) -> Data.Text.intercalate "," [getFullImplString impl, x, y]) (liftA2 (,) first_sig second_sig) + else [getFullImplString impl] + ) + Nothing -> [getFullImplString impl] + +-- for signatures: find implementations that don't have the value 'no' +hierTraverseNode (Sig sig) impl2SigMap sig2ImplMap = do + let impls = Map.lookup sig sig2ImplMap + case impls of + Just impls -> concat $ map (\x -> if (unImplName 
x) /= "no" then hierTraverseNode (Impl ImplID {parent=sig, name=x}) impl2SigMap sig2ImplMap else []) impls
+        Nothing -> []
+
+
+
+-- entry point is implementation 'root'. From there, we check if the number of
+-- signatures is > 1, which is when the actual signature subtrees start
+hierTraverseInit :: Node -> Map ImplID [SigName] -> Map SigName [ImplName] -> [Text]
+hierTraverseInit (Impl x) implToSigMap sigToImplMap = do
+    let signames = Map.lookup x implToSigMap
+    case signames of
+        Just signames -> if length signames >= 2
+                         then do
+                             concat (map (\m -> hierTraverseNode (Sig m) implToSigMap sigToImplMap) signames)
+                         else hierTraverseInit (Sig $ head $ signames) implToSigMap sigToImplMap
+        Nothing -> []
+
+-- signatures cannot branch off into more than 1 implementation
+hierTraverseInit (Sig x) implToSigMap sigToImplMap = do
+    let implnames = Map.lookup x sigToImplMap
+    case implnames of
+        Just implnames -> hierTraverseInit (Impl ImplID {parent=x, name=head implnames}) implToSigMap sigToImplMap
+        Nothing -> []
+
+findHighestModels :: ModularProgram -> [Text]
+findHighestModels m = hierTraverseInit (Impl Root) (implSigs m) (sigImpls m)
+
 ------
 -- Visualizations
 ------
diff --git a/mstan/CLI.hs b/mstan/CLI.hs
index 2fb4fa4..8b2dd1c 100644
--- a/mstan/CLI.hs
+++ b/mstan/CLI.hs
@@ -26,6 +26,8 @@ data ExecCommand =
     | GetModelGraph
     | GetModuleGraph
     | GetAllModels
+    | GetHighestModels
+    | GetImplMap
     deriving Show

 parseOptions :: IO RunOptions
@@ -94,4 +96,14 @@ parserExecCommand = hsubparser
               (info (pure GetAllModels)
                     (progDesc "Return all model IDs")
               )
+       <> OptParse.command
+              "impl-map"
+              (info (pure GetImplMap)
+                    (progDesc "Print implementation -> signature map")
+              )
+       <> OptParse.command
+              "get-highest-models"
+              (info (pure GetHighestModels)
+                    (progDesc "Return models of highest hierarchy by complexity, per subtree")
+              )
     )
diff --git a/mstan/Main.hs b/mstan/Main.hs
index 319aca7..fb1423b 100644
--- a/mstan/Main.hs
+++ b/mstan/Main.hs
@@ -1,9 +1,12 @@
+{-# LANGUAGE OverloadedStrings #-}
 module Main where

 import Data.Text ( Text )
 import qualified Data.Text as Text
 import qualified Data.Text.IO as Text
 import qualified Data.Set as Set
+import Data.Map (Map)
+import qualified Data.Map as Map
 import Control.Monad

 import Parsing
@@ -22,6 +25,10 @@ main = do
     options <- parseOptions
     execOptions options

+stringfyImplID :: ImplID -> Text
+stringfyImplID Root = "(root)"
+stringfyImplID x = "(" <> unSigName (parent x) <> " " <> unImplName (name x) <> ")"
+
 execOptions :: RunOptions -> IO ()
 execOptions (RunOptions file debugParse maybeOutFile command) = do
     program <- Parsing.readModularProgram file
@@ -43,6 +50,14 @@ execCommand prog (GetNeighbors selection) = return $ map showSelection . Set.toList $ modelNeighbors prog selection
 execCommand prog (GetConcrete selection) =
     return $ linesConcreteProgram $ selectModules prog selection
+
+execCommand prog GetImplMap = do
+    return (map (\(impl, sigs) -> Text.intercalate "|" ((stringfyImplID impl) : map unSigName sigs)) (Map.toList (implSigs prog)))
+
+execCommand prog GetHighestModels = do
+    let highestModels = findHighestModels prog
+    return highestModels
+
 execCommand prog GetMinimumSelection =
     return $ [showSelection $ firstSelection prog]
 execCommand prog GetModelGraph = do

From 7cc22e8cf7a4e22936b41fdb794a45959adc318e Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Fri, 14 Jan 2022 00:40:39 +0900
Subject: [PATCH 03/23] Update n(signature)=1 case

---
 lib/ModuleTree.hs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/ModuleTree.hs b/lib/ModuleTree.hs
index a2b8499..3cc243b 100644
--- a/lib/ModuleTree.hs
+++ b/lib/ModuleTree.hs
@@ -93,7 +93,7 @@ hierTraverseNode (Impl impl) impl2SigMap sig2ImplMap = do
     Just signames -> if length signames == 1
                      then do
                          -- If there's only 1 signature underneath the implementation, nothing more needs to be done
-                         map (\x -> Data.Text.intercalate "," ((getFullImplString impl) : (hierTraverseNode (Sig x) impl2SigMap sig2ImplMap))) (signames)
+                         map (\x -> Data.Text.intercalate "," [getFullImplString impl, x]) (hierTraverseNode (Sig $ head signames) impl2SigMap sig2ImplMap)
                      else
                          (
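The "graph join" that hierTraverseNode performs when an implementation has two child signatures can be pictured in a few lines of Python. This is an editorial sketch only, with made-up names rather than mstan API: each child signature contributes a list of candidate sub-model strings, and the parent implementation is prepended to every pair in their cross product, mirroring the liftA2 (,) step above:

from itertools import product

def join_subtrees(impl_string, left_candidates, right_candidates):
    # One output per (left, right) pair, comma-joined like mstan selections.
    return [",".join([impl_string, left, right])
            for left, right in product(left_candidates, right_candidates)]

print(join_subtrees("Trend:yes", ["A:yes", "A:no"], ["B:weighted", "B:uniform"]))
# ['Trend:yes,A:yes,B:weighted', 'Trend:yes,A:yes,B:uniform',
#  'Trend:yes,A:no,B:weighted', 'Trend:yes,A:no,B:uniform']
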
From 70233bd5d7981e6a5955068391ae4285984ad75e Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Tue, 18 Jan 2022 12:02:04 +0900
Subject: [PATCH 04/23] Update highest model command and add apex search

---
 apex-predator-search.py | 46 +++++++++++++++++++++++++++++++++++++++++
 mstan/Main.hs           |  7 +++++--
 2 files changed, 51 insertions(+), 2 deletions(-)
 create mode 100644 apex-predator-search.py

diff --git a/apex-predator-search.py b/apex-predator-search.py
new file mode 100644
index 0000000..627f1f5
--- /dev/null
+++ b/apex-predator-search.py
@@ -0,0 +1,46 @@
+import subprocess, sys
+
+def text_command(args):
+    """Run a shell command, return its stdout as a String or throw an exception if it fails."""
+
+    try:
+        result = subprocess.run(args, text=True, check=True,
+                                stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
+
+        stdout = result.stdout.strip()
+        return stdout
+    except subprocess.CalledProcessError as exc:
+        sys.exit("Error in `mstan`: \"" + exc.output.strip() + "\"")
+
+
+class ModelEvaluator:
+    def __init__(self, dataFile):
+        self.dataFile = dataFile
+
+    def score(self, modelPath):
+        """Return the numerical score for the Stan program at the given filepath"""
+        stdout_result = text_command(["Rscript", "elpd.R", modelPath, self.dataFile])
+        return float(stdout_result.split('\n')[-1].strip())
+
+model_file_name = "examples/birthday/birthday.m.stan"
+
+args = ["mstan", "-f", model_file_name, "get-highest-models"]
+
+result = subprocess.run(args, text=True, check=True,
+                        stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
+
+stdout = result.stdout.strip().split("\n")
+
+results = {}
+
+for model in stdout:
+    model_code_args = ["mstan", "-f", model_file_name, "concrete-model", "-s", model,]
+    print(model_code_args)
+    model_code = subprocess.run(model_code_args, text=True, check=True, stderr=subprocess.STDOUT, stdout=subprocess.PIPE).stdout.strip()
+    with open("temp_stanmodel.stan", "w") as f:
+        f.write(model_code)
+
+    result = ModelEvaluator("examples/birthday/births_usa_1969.json").score("temp_stanmodel.stan")
+    results[model] = result
+
+print(results)
diff --git a/mstan/Main.hs b/mstan/Main.hs
index fb1423b..a7eca1b 100644
--- a/mstan/Main.hs
+++ b/mstan/Main.hs
@@ -55,8 +55,11 @@ execCommand prog GetImplMap = do
     return (map (\(impl, sigs) -> Text.intercalate "|" ((stringfyImplID impl) : map unSigName sigs)) (Map.toList (implSigs prog)))

 execCommand prog GetHighestModels = do
-    let highestModels = findHighestModels prog
-    return highestModels
+    -- let highestModels = findHighestModels prog
+    -- return highestModels
+    let filteredProg = prog { implementations = filter ((/= ImplName "no") . implName) (implementations prog) }
+    let sels = allSelections filteredProg
+    return . map showSelection . Set.toList $ sels

 execCommand prog GetMinimumSelection =
     return $ [showSelection $ firstSelection prog]
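The search script above assumes only one thing about the scoring command: `Rscript elpd.R <model.stan> <data.json>` prints the ELPD as the last line of its stdout, which ModelEvaluator.score parses with split('\n')[-1]. elpd.R itself is not part of these patches, so for a dry run of the loop a stand-in scorer honoring the same contract is enough. A minimal sketch (the file name fake_elpd.py and the random score are hypothetical):

import random
import sys

if __name__ == "__main__":
    model_path, data_path = sys.argv[1], sys.argv[2]
    print(f"scoring {model_path} against {data_path}")  # extra lines are ignored
    print(random.uniform(1000.0, 16000.0))              # last line: parsed as the ELPD

Swapping ["Rscript", "elpd.R", ...] for ["python", "fake_elpd.py", ...] exercises the whole pipeline without R or Stan installed.
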
From 73a450d03069140c4fb0d4b97583fef0d07710fc Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Fri, 21 Jan 2022 10:36:40 +0900
Subject: [PATCH 05/23] Add chain search

---
 Chain_Generation_and_Search.py | 163 +++++++++++++++++++++++++++++++++
 apex-predator-search.py        |  17 ++++
 2 files changed, 180 insertions(+)
 create mode 100644 Chain_Generation_and_Search.py

diff --git a/Chain_Generation_and_Search.py b/Chain_Generation_and_Search.py
new file mode 100644
index 0000000..728dd3e
--- /dev/null
+++ b/Chain_Generation_and_Search.py
@@ -0,0 +1,163 @@
+import numpy as np
+import random
+
+# Top level Signature Hierarchy is a list
+# The ith element of a Hierarchy is a list which contains different implementations that are sorted in a decreasing order of implementation hierarchy
+Top_level_Signature_Hierarchy = [
+    ["DayOfWeekTrend:yes,DayOfWeekWeights:weighted","DayOfWeekTrend:yes,DayOfWeekWeights:uniform","DayOfWeekTrend:no"],
+    ["DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no","DayOfYearTrend:no"],
+    ["HolidayTrend:yes","HolidayTrend:no"],
+    ["LongTermTrend:yes","LongTermTrend:no"],
+    ["SeasonalTrend:yes","SeasonalTrend:no"]
+]
+
+# Generate a chain given the Top_level_Signature_Hierarchy information
+# The following function returns a chain of models where for any i, the ith model's model complexity is strictly higher than that of the (i+1)th model
+# In other words, the models are sorted in a decreasing order of model complexity along the chain
+# (1) Chain is a list (2) Each element of the Chain is a model (type: list) (3) The ith element of a model is the 'implementation' of the model for the ith top-level signature
+def Chain_Generation(Top_level_Signature_Hierarchy):
+    Chain = []
+    # m: number of top-level signatures
+    m = len(Top_level_Signature_Hierarchy)
+    # Create the apex predator model of the chain by choosing the implementations with highest hierarchies for all the top level signatures
+    # Apex_predator = [top_level_signature_implementations[0] for top_level_signature_implementations in Top_level_Signature_Hierarchy]
+    # Add the apex predator model to the Chain
+    # Chain.append(Apex_predator)
+    # create a list that checks the indices of the implementations of the most recent module (added to the chain) in the top-level signature implementation lists
+    Cur_indices = [0 for i in range(m)]
+    Cur_Indices_sum = sum(Cur_indices)
+    # The following list represents the possible increments for the indices
+    Possible_increments_for_indices = [ len(Top_level_Signature_Hierarchy[i]) - Cur_indices[i] for i in range(m)]
+    # Candidates for increments (the indices of top-level signatures whose implementations' indices can be increased)
+    Candidates_for_increment = []
+    for i in range(m):
+        if Possible_increments_for_indices[i] > 0:
+            Candidates_for_increment.append(i)
+    # The maximum value for the sum of indices
+    Indices_sum_UB = np.sum([len(implementations)-1 for implementations in Top_level_Signature_Hierarchy])
+    # Create the next model for the chain by randomly increasing the index of exactly one implementation of the previous model wherever possible
+    # If the Chain contains the model with the lowest possible complexity (in which case the Cur_Indices_sum equals the Indices_sum_UB), terminate the Chain generation algorithm
+    while Cur_Indices_sum <= Indices_sum_UB:
+        # Get the current iteration's model based on the Current indices & Top-level signature hierarchy
+        cur_iter_model = [Top_level_Signature_Hierarchy[Cur_indices[i]] for i in range(m)]
+        Chain.append(cur_iter_model)
+        # Randomly increase the index of a particular implementation
+        increment_ind = random.choice(Candidates_for_increment)
+        # Increment the index of the implementation and update Cur_Indices_sum, Cur_indices, Possible_increments_for_indices, Candidates_for_increment
+        Cur_Indices_sum +=1
+        Cur_indices[increment_ind] = Cur_indices[increment_ind]+1
+        Possible_increments_for_indices[increment_ind] = Possible_increments_for_indices[increment_ind]-1
+        if Possible_increments_for_indices[increment_ind] == 0:
+            Candidates_for_increment.remove(increment_ind)
+    return Chain
+
+# Compute the ELPD value of a model
+# model is a list where its ith element is a string that represents the implementation for the ith top-level signature
+def ELPD(model):
+    # use the elements of 'model' (type: list) to obtain the full 'name' of the model
+    # Then use 'STAN' to compute ELPD of the model (based on the 'model name' obtained above)
+    return 0
+
+
+# Chain is a list whose elements are individual models. Each model is a list that consists of the implementations of the top-level signatures.
+# K: a parameter that represents the maximum number of models whose ELPD values can be computed (the value of the parameter is dependent on the computational resource, total number of chains to be searched, etc.)
+# alpha: parameter for controlling the range of the step size in each iteration
+# The following function conducts a dynamic search along the chain
+# and finally returns the model with the highest ELPD value and its ELPD value (among those that the ELPD values were computed)
+# Suppose that the chain is given as an input and the models are sorted in a decreasing order of model complexity.
+# i.e. the 1st model in the chain has the highest model complexity and the last model has the lowest complexity.
+def Chain_Search(Chain,K,alpha=0.5):
+    n = len(Chain)
+    # if number of models in the chain is smaller than or equal to K, we can compute ELPD values of each model and choose the one with the highest value
+    if n <=K:
+        ELPD_values = [ELPD(model) for model in Chain]
+        highest_ELPD_model_ind = np.argmax(ELPD_values)
+        best_model, best_ELPD_val = Chain[highest_ELPD_model_ind], ELPD_values[highest_ELPD_model_ind]
+        return best_model, best_ELPD_val
+    # if number of models in the chain is strictly larger than K, then there is no choice but to conduct a search
+    else:
+        cur_ind = 0
+        num_ELPD_computed = 0
+        ELPD_computed_model_indices = []
+        ELPD_values_obtained = []
+        while num_ELPD_computed < K and cur_ind < n:
+            # compute the ELPD value of the current iteration's model
+            cur_iter_ELPD = ELPD(Chain[cur_ind])
+            # update the ELPD computed model indices, ELPD values, and the number of ELPD values computed, respectively
+            ELPD_computed_model_indices.append(cur_ind)
+            ELPD_values_obtained.append(cur_iter_ELPD)
+            num_ELPD_computed+=1
+            step_size = 1 # set default step size as 1
+            # if it is neither the 1st nor the 2nd iteration, the step size should be modified.
+            if not (num_ELPD_computed ==0 or num_ELPD_computed==1):
+                step_size_Uniform = (n-1-cur_ind)/(K-num_ELPD_computed)
+                step_size_LB = step_size_Uniform*(1-alpha)
+                step_size_UB = step_size_Uniform*(1+alpha)
+                ELPD_slope_cur_iter = abs((ELPD(ELPD_computed_model_indices[-1])-ELPD(ELPD_computed_model_indices[-2])))/(ELPD_computed_model_indices[-1]-ELPD_computed_model_indices[-2])
+                ELPD_slope_previous_iter = abs((ELPD(ELPD_computed_model_indices[-2])-ELPD(ELPD_computed_model_indices[-3])))/(ELPD_computed_model_indices[-2]-ELPD_computed_model_indices[-3])
+                step_size_candidate = step_size_Uniform*(ELPD_slope_previous_iter/ELPD_slope_cur_iter)
+                step_size = min(step_size_UB,max(step_size_LB,step_size_candidate))  # clamp the candidate step into [step_size_LB, step_size_UB]
+            # update the current index (which would be the index of the model in the next iteration)
+            cur_iter_ELPD += step_size
+    # find the best model (the model with the highest ELPD value) among the models whose ELPD values were computed.
+    highest_ELPD_value_model_chain_search_ind = np.argmax(ELPD_values_obtained)
+    Final_best_ELPD_model_ind, Final_best_ELPD_val = ELPD_computed_model_indices[highest_ELPD_value_model_chain_search_ind], ELPD_values_obtained[highest_ELPD_value_model_chain_search_ind]
+    Final_best_model = Chain[Final_best_ELPD_model_ind]
+    return Final_best_model, Final_best_ELPD_val
+
+
+
+
+# Original Chain Generation Algorithm (Proposed earlier)
+
+# Simple example
+Model_collection =[
+    'Mean:normal,Stddev:lognormal,StddevInformative:no',
+    'Mean:normal,Stddev:lognormal,StddevInformative:yes',
+    'Mean:normal,Stddev:standard',
+    'Mean:standard,Stddev:lognormal,StddevInformative:no',
+    'Mean:standard,Stddev:lognormal,StddevInformative:yes',
+    'Mean:standard,Stddev:standard'
+]
+
+
+# Simple example
+Current_Node = ['normal','lognormal,StddevInformative:yes']
+Module_names = ['Mean','Stddev']
+Hierarchy = [{'normal':['standard']},
+             {'lognormal,StddevInformative:yes':['standard'],'lognormal,StddevInformative:no':['standard']}
+             ]
+
+# Birthday Case study
+Current_Node = ['yes,DayOfWeekWeights:weighted','yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:yes','yes','yes','yes']
+Module_names = ['DayOfWeekTrend','DayofYearTrend','HolidayTrend','LongTermTrend','SeasonalTrend']
+Hierarchy = [{'yes,DayOfWeekWeights:weighted':['no'],'yes,DayOfWeekWeights:uniform':['no']},
+            {'yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:yes':['no'],
+             'yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:no':['no'],
+             'yes,DayOfYearHierarchicalVariance:no,DayOfYearNormalVariance:yes':['no'],
+             'yes,DayOfYearHierarchicalVariance:no,DayOfYearNormalVariance:no':['no']},
+            {'yes':['no']},
+            {'yes': ['no']},
+            {'yes': ['no']}
+]
+
+def Generate_Maximal_Chain(Current_Node,Module_names,Hierarchy):
+    num_Modules = len(Current_Node)
+    Chain = [Current_Node]
+    Chain_length = 1
+    Longest_Chain = [Current_Node]
+    Longest_Chain_length = 1
+    for i in range(num_Modules):
+        if Current_Node[i] in Hierarchy[i]:
+            for alternative in Hierarchy[i][Current_Node[i]]:
+                Next_node = [Current_Node[j] if i !=j else alternative for j in range(num_Modules)]
+                Additional_Chain,Additional_Chain_length = Generate_Maximal_Chain(Next_node,Module_names, Hierarchy)
+                if Chain_length+Additional_Chain_length > Longest_Chain_length:
+                    Longest_Chain = Chain + Additional_Chain
+                    Longest_Chain_length = Chain_length+Additional_Chain_length
+    return Longest_Chain, Longest_Chain_length
+
+Longest_Chain, Longest_Chain_length = Generate_Maximal_Chain(Current_Node,Module_names,Hierarchy)
+print("Longest Chain: ",[[Module_names[k]+":"+node[k] for k in range(len(node))] for node in Longest_Chain])
+print("Longest Chain Length: ",Longest_Chain_length)
+
diff --git a/apex-predator-search.py b/apex-predator-search.py
index 627f1f5..5459f6c 100644
--- a/apex-predator-search.py
+++ b/apex-predator-search.py
@@ -44,3 +44,20 @@ def score(self, modelPath):
     results[model] = result

 print(results)
+
+
+
+
+hierarchy_info = [
+    ["DayofWeekTrend:Yes,DayofWeekWeights:weighted", "DayofWeekTrend:Yes,DayofWeekWeights:uniform", "DayofWeekTrend:no"],
+    ["DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:yes","DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:no","DayofYearTrend:yes,DayofHierarchicalVariance:no,DayofYearNormalVariance:yes", "DayofYearTrend:no"],
+    ["HolidayTrend:Yes", "HolidayTrend:No"],
+    #...
+] # n, n-1, ... 1
+
+current_model = ["DayofWeek:Yes", "HolidayTrend:Yes"]
+
+chain = []
+chain.append(",".join(current_model))
+current_model[0] = hierarchy_info[0][1]
+chain.append(",".join(current_model))
\ No newline at end of file
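One concrete run of Chain_Generation, reduced to a toy two-signature hierarchy, makes the invariant visible: the walk starts at the most complex model (all indices 0) and demotes exactly one signature by one rank per step, so a complete chain always contains 1 + sum(len(sig) - 1) models. A self-contained editorial sketch of the same idea, not the patched code itself:

import random

toy_hierarchy = [
    ["A:yes,W:weighted", "A:yes,W:uniform", "A:no"],  # 3 ranks
    ["B:yes", "B:no"],                                # 2 ranks
]
idx = [0, 0]
chain = [",".join(sig[i] for sig, i in zip(toy_hierarchy, idx))]
while any(i < len(sig) - 1 for sig, i in zip(toy_hierarchy, idx)):
    k = random.choice([j for j, sig in enumerate(toy_hierarchy)
                       if idx[j] < len(sig) - 1])
    idx[k] += 1
    chain.append(",".join(sig[i] for sig, i in zip(toy_hierarchy, idx)))

print(chain)  # always 4 models here: 1 + (3-1) + (2-1)

For the birthday hierarchy above (list sizes 3, 4, 2, 2, 2), every generated chain therefore has 9 models; only the order of demotions is random.
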
From 67fedf0d67a9022f0a95c75fe7951d5dbbb7ca73 Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Fri, 21 Jan 2022 11:07:25 +0900
Subject: [PATCH 06/23] Update chain search

---
 Chain_Generation_and_Search.py | 176 ++++++++++++++++++++-------------
 1 file changed, 110 insertions(+), 66 deletions(-)

diff --git a/Chain_Generation_and_Search.py b/Chain_Generation_and_Search.py
index 728dd3e..4d936a7 100644
--- a/Chain_Generation_and_Search.py
+++ b/Chain_Generation_and_Search.py
@@ -1,15 +1,7 @@
 import numpy as np
 import random
+import subprocess
-
-# Top level Signature Hierarchy is a list
-# The ith element of a Hierarchy is a list which contains different implementations that are sorted in a decreasing order of implementation hierarchy
-Top_level_Signature_Hierarchy = [
-    ["DayOfWeekTrend:yes,DayOfWeekWeights:weighted","DayOfWeekTrend:yes,DayOfWeekWeights:uniform","DayOfWeekTrend:no"],
-    ["DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no","DayOfYearTrend:no"],
-    ["HolidayTrend:yes","HolidayTrend:no"],
-    ["LongTermTrend:yes","LongTermTrend:no"],
-    ["SeasonalTrend:yes","SeasonalTrend:no"]
-]

 # Generate a chain given the Top_level_Signature_Hierarchy information
@@ -51,12 +43,43 @@ def Chain_Generation(Top_level_Signature_Hierarchy):
             Candidates_for_increment.remove(increment_ind)
     return Chain

+
+
+def text_command(args):
+    """Run a shell command, return its stdout as a String or throw an exception if it fails."""
+
+    try:
+        result = subprocess.run(args, text=True, check=True,
+                                stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
+
+        stdout = result.stdout.strip()
+        return stdout
+    except subprocess.CalledProcessError as exc:
+        sys.exit("Error in `mstan`: \"" + exc.output.strip() + "\"")
+
+
+class ModelEvaluator:
+    def __init__(self, dataFile):
+        self.dataFile = dataFile
+
+    def score(self, modelPath):
+        """Return the numerical score for the Stan program at the given filepath"""
+        stdout_result = text_command(["Rscript", "elpd.R", modelPath, self.dataFile])
+        return float(stdout_result.split('\n')[-1].strip())
+
+
 # Compute the ELPD value of a model
 # model is a list where its ith element is a string that represents the implementation for the ith top-level signature
-def ELPD(model):
+def ELPD(model, data_file):
+    print(model[0])
     # use the elements of 'model' (type: list) to obtain the full 'name' of the model
     # Then use 'STAN' to compute ELPD of the model (based on the 'model name' obtained above)
-    return 0
+    model_code_args = ["mstan", "-f", ",".join(model), "concrete-model", "-s", model,]
+    model_code = text_command(model_code_args)
+    with open("temp_stanmodel.stan", "w") as f:
+        f.write(model_code)
+    result = ModelEvaluator(data_file).score("temp_stanmodel.stan")
+    return result


 # Chain is a list whose elements are individual models. Each model is a list that consists of the implementations of the top-level signatures.
@@ -66,11 +89,11 @@ def ELPD(model):
 # and finally returns the model with the highest ELPD value and its ELPD value (among those that the ELPD values were computed)
 # Suppose that the chain is given as an input and the models are sorted in a decreasing order of model complexity.
 # i.e. the 1st model in the chain has the highest model complexity and the last model has the lowest complexity.
-def Chain_Search(Chain,K,alpha=0.5):
+def Chain_Search(Chain,K,data_file_dir, alpha=0.5):
     n = len(Chain)
     # if number of models in the chain is smaller than or equal to K, we can compute ELPD values of each model and choose the one with the highest value
     if n <=K:
-        ELPD_values = [ELPD(model) for model in Chain]
+        ELPD_values = [ELPD(model, data_file_dir) for model in Chain]
         highest_ELPD_model_ind = np.argmax(ELPD_values)
         best_model, best_ELPD_val = Chain[highest_ELPD_model_ind], ELPD_values[highest_ELPD_model_ind]
         return best_model, best_ELPD_val
@@ -82,23 +105,23 @@ def Chain_Search(Chain,K,alpha=0.5):
         ELPD_values_obtained = []
         while num_ELPD_computed < K and cur_ind < n:
             # compute the ELPD value of the current iteration's model
-            cur_iter_ELPD = ELPD(Chain[cur_ind])
+            cur_iter_ELPD = ELPD(Chain[cur_ind], data_file_dir)
             # update the ELPD computed model indices, ELPD values, and the number of ELPD values computed, respectively
             ELPD_computed_model_indices.append(cur_ind)
             ELPD_values_obtained.append(cur_iter_ELPD)
             num_ELPD_computed+=1
             step_size = 1 # set default step size as 1
             # if it is neither the 1st nor the 2nd iteration, the step size should be modified.
-            if not (num_ELPD_computed ==0 or num_ELPD_computed==1):
+            if not (num_ELPD_computed == 1 or num_ELPD_computed==2):
                 step_size_Uniform = (n-1-cur_ind)/(K-num_ELPD_computed)
                 step_size_LB = step_size_Uniform*(1-alpha)
                 step_size_UB = step_size_Uniform*(1+alpha)
-                ELPD_slope_cur_iter = abs((ELPD(ELPD_computed_model_indices[-1])-ELPD(ELPD_computed_model_indices[-2])))/(ELPD_computed_model_indices[-1]-ELPD_computed_model_indices[-2])
-                ELPD_slope_previous_iter = abs((ELPD(ELPD_computed_model_indices[-2])-ELPD(ELPD_computed_model_indices[-3])))/(ELPD_computed_model_indices[-2]-ELPD_computed_model_indices[-3])
+                ELPD_slope_cur_iter = abs((ELPD_values_obtained[-1]-ELPD_values_obtained[-2]))/(ELPD_computed_model_indices[-1]-ELPD_computed_model_indices[-2])
+                ELPD_slope_previous_iter = abs((ELPD_values_obtained[-2]-ELPD_values_obtained[-3]))/(ELPD_computed_model_indices[-2]-ELPD_computed_model_indices[-3])
                 step_size_candidate = step_size_Uniform*(ELPD_slope_previous_iter/ELPD_slope_cur_iter)
                 step_size = min(step_size_UB,max(step_size_LB,step_size_candidate))  # clamp the candidate step into [step_size_LB, step_size_UB]
             # update the current index (which would be the index of the model in the next iteration)
-            cur_iter_ELPD += step_size
+            cur_ind += step_size
     # find the best model (the model with the highest ELPD value) among the models whose ELPD values were computed.
     highest_ELPD_value_model_chain_search_ind = np.argmax(ELPD_values_obtained)
     Final_best_ELPD_model_ind, Final_best_ELPD_val = ELPD_computed_model_indices[highest_ELPD_value_model_chain_search_ind], ELPD_values_obtained[highest_ELPD_value_model_chain_search_ind]
     Final_best_model = Chain[Final_best_ELPD_model_ind]
     return Final_best_model, Final_best_ELPD_val


+# Top level Signature Hierarchy is a list
+# The ith element of a Hierarchy is a list which contains different implementations that are sorted in a decreasing order of implementation hierarchy
+Top_level_Signature_Hierarchy = [
+    ["DayOfWeekTrend:yes,DayOfWeekWeights:weighted","DayOfWeekTrend:yes,DayOfWeekWeights:uniform","DayOfWeekTrend:no"],
+    ["DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no","DayOfYearTrend:no"],
+    ["HolidayTrend:yes","HolidayTrend:no"],
+    ["LongTermTrend:yes","LongTermTrend:no"],
+    ["SeasonalTrend:yes","SeasonalTrend:no"]
+]

+chain = Chain_Generation(Top_level_Signature_Hierarchy)

-# Original Chain Generation Algorithm (Proposed earlier)
+for val in chain:
+    print("*" * 10)
+    print(val)

-# Simple example
-Model_collection =[
-    'Mean:normal,Stddev:lognormal,StddevInformative:no',
-    'Mean:normal,Stddev:lognormal,StddevInformative:yes',
-    'Mean:normal,Stddev:standard',
-    'Mean:standard,Stddev:lognormal,StddevInformative:no',
-    'Mean:standard,Stddev:lognormal,StddevInformative:yes',
-    'Mean:standard,Stddev:standard'
-]
+data_file_dir = "examples/birthday/births_usa_1969.json"

+K = 3

-# Simple example
-Current_Node = ['normal','lognormal,StddevInformative:yes']
-Module_names = ['Mean','Stddev']
-Hierarchy = [{'normal':['standard']},
-             {'lognormal,StddevInformative:yes':['standard'],'lognormal,StddevInformative:no':['standard']}
-             ]
-
-# Birthday Case study
-Current_Node = ['yes,DayOfWeekWeights:weighted','yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:yes','yes','yes','yes']
-Module_names = ['DayOfWeekTrend','DayofYearTrend','HolidayTrend','LongTermTrend','SeasonalTrend']
-Hierarchy = [{'yes,DayOfWeekWeights:weighted':['no'],'yes,DayOfWeekWeights:uniform':['no']},
-            {'yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:yes':['no'],
-             'yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:no':['no'],
-             'yes,DayOfYearHierarchicalVariance:no,DayOfYearNormalVariance:yes':['no'],
-             'yes,DayOfYearHierarchicalVariance:no,DayOfYearNormalVariance:no':['no']},
-            {'yes':['no']},
-            {'yes': ['no']},
-            {'yes': ['no']}
-]
+best_model, best_elpd = Chain_Search(Chain=chain, K=K, data_file_dir=data_file_dir)

+print(best_model, best_elpd)
+
+# Original Chain Generation Algorithm (Proposed earlier)

-def Generate_Maximal_Chain(Current_Node,Module_names,Hierarchy):
-    num_Modules = len(Current_Node)
-    Chain = [Current_Node]
-    Chain_length = 1
-    Longest_Chain = [Current_Node]
-    Longest_Chain_length = 1
-    for i in range(num_Modules):
-        if Current_Node[i] in Hierarchy[i]:
-            for alternative in Hierarchy[i][Current_Node[i]]:
-                Next_node = [Current_Node[j] if i !=j else alternative for j in range(num_Modules)]
-                Additional_Chain,Additional_Chain_length = Generate_Maximal_Chain(Next_node,Module_names, Hierarchy)
-                if Chain_length+Additional_Chain_length > Longest_Chain_length:
-                    Longest_Chain = Chain + Additional_Chain
-                    Longest_Chain_length = Chain_length+Additional_Chain_length
-    return Longest_Chain, Longest_Chain_length
-
-Longest_Chain, Longest_Chain_length = Generate_Maximal_Chain(Current_Node,Module_names,Hierarchy)
-print("Longest Chain: ",[[Module_names[k]+":"+node[k] for k in range(len(node))] for node in Longest_Chain])
-print("Longest Chain Length: ",Longest_Chain_length)
+# Simple example
+# Model_collection =[
+#     'Mean:normal,Stddev:lognormal,StddevInformative:no',
+#     'Mean:normal,Stddev:lognormal,StddevInformative:yes',
+#     'Mean:normal,Stddev:standard',
+#     'Mean:standard,Stddev:lognormal,StddevInformative:no',
+#     'Mean:standard,Stddev:lognormal,StddevInformative:yes',
+#     'Mean:standard,Stddev:standard'
+# ]
+
+
+# # Simple example
+# Current_Node = ['normal','lognormal,StddevInformative:yes']
+# Module_names = ['Mean','Stddev']
+# Hierarchy = [{'normal':['standard']},
+#              {'lognormal,StddevInformative:yes':['standard'],'lognormal,StddevInformative:no':['standard']}
+#              ]
+
+# # Birthday Case study
+# Current_Node = ['yes,DayOfWeekWeights:weighted','yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:yes','yes','yes','yes']
+# Module_names = ['DayOfWeekTrend','DayofYearTrend','HolidayTrend','LongTermTrend','SeasonalTrend']
+# Hierarchy = [{'yes,DayOfWeekWeights:weighted':['no'],'yes,DayOfWeekWeights:uniform':['no']},
+#             {'yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:yes':['no'],
+#              'yes,DayOfYearHierarchicalVariance:yes,DayOfYearNormalVariance:no':['no'],
+#              'yes,DayOfYearHierarchicalVariance:no,DayOfYearNormalVariance:yes':['no'],
+#              'yes,DayOfYearHierarchicalVariance:no,DayOfYearNormalVariance:no':['no']},
+#             {'yes':['no']},
+#             {'yes': ['no']},
+#             {'yes': ['no']}
+# ]
+
+# def Generate_Maximal_Chain(Current_Node,Module_names,Hierarchy):
+#     num_Modules = len(Current_Node)
+#     Chain = [Current_Node]
+#     Chain_length = 1
+#     Longest_Chain = [Current_Node]
+#     Longest_Chain_length = 1
+#     for i in range(num_Modules):
+#         if Current_Node[i] in Hierarchy[i]:
+#             for alternative in Hierarchy[i][Current_Node[i]]:
+#                 Next_node = [Current_Node[j] if i !=j else alternative for j in range(num_Modules)]
+#                 Additional_Chain,Additional_Chain_length = Generate_Maximal_Chain(Next_node,Module_names, Hierarchy)
+#                 if Chain_length+Additional_Chain_length > Longest_Chain_length:
+#                     Longest_Chain = Chain + Additional_Chain
+#                     Longest_Chain_length = Chain_length+Additional_Chain_length
+#     return Longest_Chain, Longest_Chain_length

+# Longest_Chain, Longest_Chain_length = Generate_Maximal_Chain(Current_Node,Module_names,Hierarchy)
+# print("Longest Chain: ",[[Module_names[k]+":"+node[k] for k in range(len(node))] for node in Longest_Chain])
+# print("Longest Chain Length: ",Longest_Chain_length)
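A worked instance of the adaptive step-size rule in Chain_Search (an editorial sketch with invented numbers): with n = 20 models in the chain, budget K = 5, alpha = 0.5, and ELPDs already computed at indices 0, 1, and 4, the uniform step is scaled by the ratio of the previous slope to the current slope and then clamped to [uniform*(1-alpha), uniform*(1+alpha)]:

n, K, alpha = 20, 5, 0.5
cur_ind, num_done = 4, 3
indices = [0, 1, 4]               # models already evaluated
elpds = [5200.0, 6900.0, 7400.0]  # their (made-up) ELPD values

uniform = (n - 1 - cur_ind) / (K - num_done)                           # 7.5
slope_cur = abs(elpds[-1] - elpds[-2]) / (indices[-1] - indices[-2])   # ~166.7
slope_prev = abs(elpds[-2] - elpds[-3]) / (indices[-2] - indices[-3])  # 1700.0
candidate = uniform * (slope_prev / slope_cur)                         # 76.5
step = min(uniform * (1 + alpha), max(uniform * (1 - alpha), candidate))
print(step)  # 11.25: the curve is flattening, so the step hits the upper bound

A flattening ELPD curve (large previous slope relative to the current one) inflates the candidate step, and the clamp keeps it within alpha of the uniform step.
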
From 80025a91a1295bf176f01e3f9298904cfdc5f13f Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Fri, 21 Jan 2022 11:23:25 +0900
Subject: [PATCH 07/23] Update chain search code

---
 Chain_Generation_and_Search.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/Chain_Generation_and_Search.py b/Chain_Generation_and_Search.py
index 4d936a7..b1cb9b7 100644
--- a/Chain_Generation_and_Search.py
+++ b/Chain_Generation_and_Search.py
@@ -1,7 +1,7 @@
 import numpy as np
 import random
 import subprocess
-
+import sys

 # Generate a chain given the Top_level_Signature_Hierarchy information
 # The following function returns a chain of models where for any i, the ith model's model complexity is strictly higher than that of the (i+1)th model
@@ -19,7 +19,7 @@ def Chain_Generation(Top_level_Signature_Hierarchy):
     Cur_indices = [0 for i in range(m)]
     Cur_Indices_sum = sum(Cur_indices)
     # The following list represents the possible increments for the indices
-    Possible_increments_for_indices = [ len(Top_level_Signature_Hierarchy[i]) - Cur_indices[i] for i in range(m)]
+    Possible_increments_for_indices = [ len(Top_level_Signature_Hierarchy[i]) - 1 - Cur_indices[i] for i in range(m)]
     # Candidates for increments (the indices of top-level signatures whose implementations' indices can be increased)
     Candidates_for_increment = []
     for i in range(m):
@@ -29,9 +29,9 @@ def Chain_Generation(Top_level_Signature_Hierarchy):
     Indices_sum_UB = np.sum([len(implementations)-1 for implementations in Top_level_Signature_Hierarchy])
     # Create the next model for the chain by randomly increasing the index of exactly one implementation of the previous model wherever possible
     # If the Chain contains the model with the lowest possible complexity (in which case the Cur_Indices_sum equals the Indices_sum_UB), terminate the Chain generation algorithm
-    while Cur_Indices_sum <= Indices_sum_UB:
+    while Cur_Indices_sum < Indices_sum_UB:
         # Get the current iteration's model based on the Current indices & Top-level signature hierarchy
-        cur_iter_model = [Top_level_Signature_Hierarchy[Cur_indices[i]] for i in range(m)]
+        cur_iter_model = [Top_level_Signature_Hierarchy[i][Cur_indices[i]] for i in range(m)]
         Chain.append(cur_iter_model)
         # Randomly increase the index of a particular implementation
         increment_ind = random.choice(Candidates_for_increment)
@@ -71,14 +71,15 @@ def score(self, modelPath):
 # Compute the ELPD value of a model
 # model is a list where its ith element is a string that represents the implementation for the ith top-level signature
 def ELPD(model, data_file):
-    print(model[0])
+
     # use the elements of 'model' (type: list) to obtain the full 'name' of the model
     # Then use 'STAN' to compute ELPD of the model (based on the 'model name' obtained above)
-    model_code_args = ["mstan", "-f", ",".join(model), "concrete-model", "-s", model,]
+    model_code_args = ["mstan", "-f", "birthday.m.stan", "concrete-model", "-s", ",".join(model) + ",Regression:glm",]
     model_code = text_command(model_code_args)
     with open("temp_stanmodel.stan", "w") as f:
         f.write(model_code)
     result = ModelEvaluator(data_file).score("temp_stanmodel.stan")
+    print(f"model: {','.join(model)} ELPD:{result}")
     return result
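Before committing to a search rule, the ordering assumption itself can be spot-checked: if chains really run from most to least complex and complexity buys fit, ELPD regressed on chain position should slope downward. A minimal check with scipy (the ELPD values are quoted from the results table recorded below, arranged in a hand-picked complexity ordering):

from scipy.stats import linregress

elpds_along_chain = [15384.86, 15299.22, 14620.97, 9539.07, 5232.03]
fit = linregress(range(len(elpds_along_chain)), elpds_along_chain)
print(fit.slope)  # negative here: fit degrades as components are removed

random_models.py, added next, runs this regression over many randomly sampled chains and counts how often the slope comes out positive.
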
From 873312430667354a969ab14c40fdde028d32adb3 Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Fri, 28 Jan 2022 09:11:24 +0900
Subject: [PATCH 08/23] update search and add model df

---
 Chain_Generation_and_Search.py |   7 +-
 apex-predator-search.py        |   6 +-
 random_models.py               | 132 +++++++++++++++++++++++++++++++++
 3 files changed, 140 insertions(+), 5 deletions(-)
 create mode 100644 random_models.py

diff --git a/Chain_Generation_and_Search.py b/Chain_Generation_and_Search.py
index b1cb9b7..15961d2 100644
--- a/Chain_Generation_and_Search.py
+++ b/Chain_Generation_and_Search.py
@@ -141,13 +141,14 @@ def Chain_Search(Chain,K,data_file_dir, alpha=0.5):
 ]

 chain = Chain_Generation(Top_level_Signature_Hierarchy)
-
+for v in chain:
+    print(v)

 data_file_dir = "examples/birthday/births_usa_1969.json"

 K = 3

-best_model, best_elpd = Chain_Search(Chain=chain, K=K, data_file_dir=data_file_dir)
-print(best_model, best_elpd)
+#best_model, best_elpd = Chain_Search(Chain=chain, K=K, data_file_dir=data_file_dir)
+#print(best_model, best_elpd)

diff --git a/apex-predator-search.py b/apex-predator-search.py
index 5459f6c..f6d92c2 100644
--- a/apex-predator-search.py
+++ b/apex-predator-search.py
@@ -49,9 +49,11 @@ def score(self, modelPath):


 hierarchy_info = [
-    ["DayofWeekTrend:Yes,DayofWeekWeights:weighted", "DayofWeekTrend:Yes,DayofWeekWeights:uniform", "DayofWeekTrend:no"],
+    ["DayofWeekTrend:yes,DayofWeekWeights:weighted", "DayofWeekTrend:yes,DayofWeekWeights:uniform", "DayofWeekTrend:no"],
     ["DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:yes","DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:no","DayofYearTrend:yes,DayofHierarchicalVariance:no,DayofYearNormalVariance:yes", "DayofYearTrend:no"],
-    ["HolidayTrend:Yes", "HolidayTrend:No"],
+    ["HolidayTrend:yes", "HolidayTrend:no"],
+    ["LongTermTrend:yes", "LongTermTrend:no"],
+    ["SeasonTrend:yes", "SeasonTrend:no"]
     #...
 ] # n, n-1, ... 1

diff --git a/random_models.py b/random_models.py
new file mode 100644
index 0000000..d125af2
--- /dev/null
+++ b/random_models.py
@@ -0,0 +1,132 @@
+from json.tool import main
+from sys import implementation
+import pandas as pd
+import numpy as np
+from scipy.stats import linregress
+
+
+elpd_results = {
+    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 5232.03,
+    "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 4578.51,
+    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 5585.64,
+    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 6717.52,
+    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 5258.05,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6910.55,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 2534.39,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6540.03,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 7457.22,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 7812.98,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6935.09,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 7522.39,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 10947.29,
+    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6818.07,
+    
"DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13949.47, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13956.11, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13979.34, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13959.7, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 13127.09, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 11204.72, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 9539.07, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 7179.98, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14076.06, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14526.21, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14620.97, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14599.08, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 4232.97, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 8271.31, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 15299.22, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14598.53, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 7228.88, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14755.47, + 
"DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 3283.42, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15071.18, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 15313.68, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 11466.83, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 5882.44, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 12213.75, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14929.74, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 1606.94, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15384.86, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 5888.01, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14608.01, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15301.54, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 11698.06, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 2161.10, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14109.15 +} + +columns = ['DayOfWeekWeights', 'HolidayTrend', 'DayOfYearNormalVariance', 'DayOfYearHeirarchicalVariance', 'Regression', 'DayOfWeekTrend', 'SeasonalTrend', 'DayOfYearTrend', 'LongTermTrend', 'elpd'] + +model_list = [] + +for model_string, elpd in elpd_results.items(): + values = [None] * len(columns) + for substr in model_string.split(","): + signature, implementation = substr.split(":") + values[columns.index(signature)] = implementation + values[-1] = elpd + model_list.append(values) + +model_df = pd.DataFrame(model_list, columns=columns) + + + +def row_to_string(row): + dict_form = row.to_dict() + del 
dict_form["elpd"] + return ",".join(list(map(lambda x: f"{x[0]}:{list(x[1].values())[0]}", dict_form.items()))) + + +def generate_chain(length): + filtered_df = model_df + model_chain = [] + yes_set = set() + selected_model_strs = [] + for _ in range(length): + temp_df = filtered_df + for col in yes_set: + temp_df = temp_df[(temp_df[col] != "no") & (temp_df[col] != None)] + selected_model = temp_df.sample() + string_repr = row_to_string(selected_model) + if string_repr in selected_model_strs: + if temp_df.shape[0] == 1: + break + continue + + selected_model_strs.append(string_repr) + filtered_df = temp_df + model_chain.append([string_repr, selected_model["elpd"].item()]) + yes_set.update(list(selected_model.columns[(selected_model == "yes").iloc[0]])) + + return model_chain + +if __name__ == "__main__": + N = 200 + n_positive = 0 + slope_sum = 0 + failed = 0 + for _ in range(N): + x = [] + y = [] + chain = generate_chain(5) + #print(len(chain)) + for indx, val in enumerate(chain): + x.append(indx) + y.append(val[1]) + + slope = linregress(x, y).slope + if np.isnan(slope): + failed += 1 + else: + slope_sum += slope + print(slope) + if slope > 0: + n_positive += 1 + + print("-" * 10) + print(slope_sum) + print(slope_sum / (N - failed)) + print(n_positive) + + + From e80399d1a7fc7a3d8bcc8f0eccfbd421c5dfa6e8 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Sun, 30 Jan 2022 21:12:57 +0900 Subject: [PATCH 09/23] Create model dataframe module and birthday model df --- birthday_df.csv | 48 ++++++++++++++++++++ birthday_df_create_and_test.py | 80 ++++++++++++++++++++++++++++++++++ elpd_df.py | 73 +++++++++++++++++++++++++++++++ random_models.py | 67 +--------------------------- 4 files changed, 203 insertions(+), 65 deletions(-) create mode 100644 birthday_df.csv create mode 100644 birthday_df_create_and_test.py create mode 100644 elpd_df.py diff --git a/birthday_df.csv b/birthday_df.csv new file mode 100644 index 0000000..04f7eb2 --- /dev/null +++ b/birthday_df.csv @@ -0,0 +1,48 @@ +,DayOfWeekTrend,DayOfYearTrend,HolidayTrend,LongTermTrend,Regression,SeasonalTrend,elpd,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfWeekWeights +0,no,no,no,no,glm,no,5232.03,,, +1,no,yes,no,no,glm,no,4578.51,no,no, +2,no,no,no,no,glm,yes,5585.64,,, +3,no,no,no,yes,glm,no,6717.52,,, +4,no,no,yes,no,glm,no,5258.05,,, +5,yes,no,no,no,glm,no,6910.55,,,uniform +6,yes,no,no,no,glm,no,2534.39,,,weighted +7,yes,yes,no,no,glm,no,6540.03,no,no,uniform +8,yes,yes,no,no,glm,no,7457.22,no,yes,uniform +9,yes,yes,no,no,glm,no,7812.98,yes,no,uniform +10,yes,yes,no,no,glm,no,6935.09,yes,yes,uniform +11,yes,no,no,no,glm,yes,7522.39,,,uniform +12,yes,no,no,yes,glm,no,10947.29,,,uniform +13,yes,no,yes,no,glm,no,6818.07,,,uniform +14,yes,yes,no,yes,glm,no,13949.47,no,no,uniform +15,yes,yes,no,yes,glm,no,13956.11,no,yes,uniform +16,yes,yes,no,yes,glm,no,13979.34,yes,no,uniform +17,yes,yes,no,yes,glm,no,13959.7,yes,yes,uniform +18,yes,no,no,yes,glm,yes,13127.09,,,uniform +19,yes,no,yes,yes,glm,no,11204.72,,,uniform +20,yes,no,no,yes,glm,no,9539.07,,,weighted +21,no,yes,no,yes,glm,no,7179.98,yes,no, +22,yes,yes,no,yes,glm,yes,14076.06,yes,no,uniform +23,yes,yes,yes,yes,glm,no,14526.21,yes,no,uniform +24,yes,yes,no,yes,glm,no,14620.97,yes,no,weighted +25,yes,yes,no,yes,glm,no,14599.08,no,no,weighted +26,yes,yes,no,no,glm,no,4232.97,yes,no,weighted +27,yes,yes,no,yes,glm,yes,8271.31,yes,no,weighted +28,yes,yes,yes,yes,glm,no,15299.22,yes,no,weighted +29,yes,yes,no,yes,glm,no,14598.53,yes,yes,weighted 
+30,no,yes,yes,yes,glm,no,7228.88,yes,no, +31,yes,yes,yes,yes,glm,no,14755.47,no,no,weighted +32,yes,yes,yes,no,glm,no,3283.42,yes,no,weighted +33,yes,yes,yes,yes,glm,yes,15071.18,yes,no,weighted +34,yes,yes,yes,yes,glm,no,15313.68,yes,yes,weighted +35,yes,no,yes,yes,glm,no,11466.83,,,weighted +36,no,yes,yes,yes,glm,no,5882.44,yes,yes, +37,yes,yes,yes,yes,glm,no,12213.75,yes,yes,uniform +38,yes,yes,yes,yes,glm,no,14929.74,no,yes,weighted +39,yes,yes,yes,no,glm,no,1606.94,yes,yes,weighted +40,yes,yes,yes,yes,glm,yes,15384.86,yes,yes,weighted +41,no,yes,yes,yes,glm,yes,5888.01,yes,yes, +42,yes,yes,yes,yes,glm,yes,14608.01,yes,yes,uniform +43,yes,yes,yes,yes,glm,yes,15301.54,no,yes,weighted +44,yes,yes,no,yes,glm,yes,11698.06,yes,yes,weighted +45,yes,yes,yes,no,glm,yes,2161.1,yes,yes,weighted +46,yes,no,yes,yes,glm,yes,14109.15,,,weighted diff --git a/birthday_df_create_and_test.py b/birthday_df_create_and_test.py new file mode 100644 index 0000000..61cacd1 --- /dev/null +++ b/birthday_df_create_and_test.py @@ -0,0 +1,80 @@ +from elpd_df import * +import pandas as pd + +elpd_results = { + "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 5232.03, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 4578.51, + "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 5585.64, + "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 6717.52, + "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 5258.05, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6910.55, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 2534.39, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6540.03, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 7457.22, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 7812.98, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6935.09, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 7522.39, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 10947.29, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6818.07, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13949.47, + 
"DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13956.11, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13979.34, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13959.7, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 13127.09, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 11204.72, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 9539.07, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 7179.98, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14076.06, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14526.21, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14620.97, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14599.08, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 4232.97, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 8271.31, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 15299.22, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14598.53, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 7228.88, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14755.47, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 3283.42, + 
"DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15071.18, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 15313.68, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 11466.83, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 5882.44, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 12213.75, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14929.74, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 1606.94, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15384.86, + "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 5888.01, + "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14608.01, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15301.54, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 11698.06, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 2161.10, + "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14109.15 +} + +columns = ['DayOfWeekWeights', 'HolidayTrend', 'DayOfYearNormalVariance', 'DayOfYearHeirarchicalVariance', 'Regression', 'DayOfWeekTrend', 'SeasonalTrend', 'DayOfYearTrend', 'LongTermTrend', 'elpd'] + +model_list = [] + +for model_string, elpd in elpd_results.items(): + model_dict = model_string_to_dict(model_string) + model_dict["elpd"] = elpd + model_list.append(model_dict) + +df = pd.DataFrame(model_list) + +save_csv(df, "birthday_df.csv") + +# 15301.54 +test_str = "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes" + +print(model_string_to_dict(test_str)) + +# lookup a model +result = search_df(df, model_string_to_dict(test_str)) +print(result) + +print("-" * 10) +# 
update/add a model and elpd value
+df = upsert_model(df, model_string=test_str, elpd=9999999)
+print(df)
+
diff --git a/elpd_df.py b/elpd_df.py
new file mode 100644
index 0000000..24e8bd4
--- /dev/null
+++ b/elpd_df.py
@@ -0,0 +1,73 @@
+import pandas as pd
+from collections import defaultdict
+from typing import Dict, List, Tuple, Union
+
+
+
+def model_string_to_dict(model_string: str) -> Dict[str, str]:
+    """
+    Convert a mstan model string name to a dictionary
+    """
+    sig_df = {}
+    for substr in model_string.split(","):
+        signature, implementation = substr.split(":")
+        sig_df[signature] = implementation
+
+    return sig_df
+
+
+def row_to_string(row: pd.DataFrame) -> Dict:
+    """
+    Convert a dataframe row into a mstan model string name
+    """
+    if len(row) != 1:
+        raise Exception("Only a single row should be supplied")
+    dict_form = row.to_dict()
+    del dict_form["elpd"]
+    return ",".join(list(map(lambda x: f"{x[0]}:{list(x[1].values())[0]}", dict_form.items())))
+
+
+def upsert_model(df: pd.DataFrame, model_dict: Dict[str, Union[str, float]] = None, model_string: str = None, elpd: float = None) -> pd.DataFrame:
+    """
+    If the model exists in the dataframe, update the dataframe.
+    If it does not exist, create the entry.
+    Returns the updated dataframe
+    """
+    if not model_dict:
+        model_dict = model_string_to_dict(model_string)
+
+    if elpd is None:
+        raise Exception("Must provide ELPD value")
+
+    result = search_df(df, model_dict)
+    if result.empty:
+        model_dict["elpd"] = elpd
+        return df.append(model_dict, ignore_index=True)
+
+    df.loc[result.index, "elpd"] = elpd
+
+    return df
+
+
+def search_df(df: pd.DataFrame, model_dict: Dict[str, Union[str, float]] = None, model_string: str = None):
+
+    """
+    Check whether a given model dict is present in the dataframe; return the matching row if it exists, or an empty dataframe if not
+    """
+    if not model_dict:
+        model_dict = model_string_to_dict(model_string)
+
+    result = df
+    for key, val in model_dict.items():
+        if key == "elpd":
+            continue
+
+        result = result.loc[result[key] == val]
+
+    return result if not result.empty else pd.DataFrame()
+
+def save_csv(df: pd.DataFrame, filename):
+    df.to_csv(filename)
+
+def read_csv(filename: str):
+    return pd.read_csv(filename)
diff --git a/random_models.py b/random_models.py
index d125af2..8912b49 100644
--- a/random_models.py
+++ b/random_models.py
@@ -3,72 +3,9 @@
 import pandas as pd
 import numpy as np
 from scipy.stats import linregress
+from elpd_df import *
-
-
-elpd_results = {
-    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 5232.03,
-    "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 4578.51,
-    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 5585.64,
-    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 6717.52,
-    "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 5258.05,
-    "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6910.55,
-    "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 2534.39,
-    
"DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6540.03, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 7457.22, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 7812.98, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6935.09, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 7522.39, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 10947.29, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 6818.07, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13949.47, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13956.11, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13979.34, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 13959.7, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 13127.09, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 11204.72, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 9539.07, - "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 7179.98, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14076.06, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14526.21, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14620.97, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14599.08, - 
"DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 4232.97, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 8271.31, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 15299.22, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14598.53, - "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 7228.88, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14755.47, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 3283.42, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15071.18, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 15313.68, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 11466.83, - "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 5882.44, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 12213.75, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:no": 14929.74, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 1606.94, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15384.86, - "DayOfWeekTrend:no,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 5888.01, - "DayOfWeekTrend:yes,DayOfWeekWeights:uniform,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14608.01, - 
"DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 15301.54, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:no,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 11698.06, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:no,Regression:glm,SeasonalTrend:yes": 2161.10, - "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14109.15 -} - -columns = ['DayOfWeekWeights', 'HolidayTrend', 'DayOfYearNormalVariance', 'DayOfYearHeirarchicalVariance', 'Regression', 'DayOfWeekTrend', 'SeasonalTrend', 'DayOfYearTrend', 'LongTermTrend', 'elpd'] - -model_list = [] - -for model_string, elpd in elpd_results.items(): - values = [None] * len(columns) - for substr in model_string.split(","): - signature, implementation = substr.split(":") - values[columns.index(signature)] = implementation - values[-1] = elpd - model_list.append(values) - -model_df = pd.DataFrame(model_list, columns=columns) - +model_df = read_csv("birthday_df.csv") def row_to_string(row): From d3650777be26f276850bddb304bee5d1969dc8ce Mon Sep 17 00:00:00 2001 From: Dashadower Date: Sun, 30 Jan 2022 21:18:05 +0900 Subject: [PATCH 10/23] update typing --- elpd_df.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elpd_df.py b/elpd_df.py index 24e8bd4..43a1115 100644 --- a/elpd_df.py +++ b/elpd_df.py @@ -16,7 +16,7 @@ def model_string_to_dict(model_string: str) -> Dict[str, str]: return sig_df -def row_to_string(row: pd.DataFrame) -> Dict: +def row_to_string(row: pd.DataFrame) -> str: """ Convert a dataframe row into a mstan model string name """ From f65f1bcd692384f376e0c8703d89750b2f07d18c Mon Sep 17 00:00:00 2001 From: Dashadower Date: Tue, 1 Feb 2022 01:28:39 +0900 Subject: [PATCH 11/23] Move files to directory, create mstan inferface and bayesian probability update algorithm --- birthday_df.csv | 48 ------- .../Chain_Generation_and_Search.py | 0 .../__pycache__/elpd_df.cpython-38.pyc | Bin 0 -> 2476 bytes .../mstan_interface.cpython-38.pyc | Bin 0 -> 1898 bytes .../apex-predator-search.py | 0 .../bayesian_probabilistic_search.py | 94 ++++++++++++++ search_algorithms/birthday_df.csv | 121 ++++++++++++++++++ .../birthday_df_create_and_test.py | 26 ++-- search_algorithms/birthday_df_prob.csv | 121 ++++++++++++++++++ elpd.R => search_algorithms/elpd.R | 4 +- elpd_df.py => search_algorithms/elpd_df.py | 10 +- .../graph_search.py | 0 search_algorithms/mstan_interface.py | 49 +++++++ search_algorithms/prob_1.png | Bin 0 -> 13370 bytes .../random_models.py | 0 15 files changed, 403 insertions(+), 70 deletions(-) delete mode 100644 birthday_df.csv rename Chain_Generation_and_Search.py => search_algorithms/Chain_Generation_and_Search.py (100%) create mode 100644 search_algorithms/__pycache__/elpd_df.cpython-38.pyc create mode 100644 search_algorithms/__pycache__/mstan_interface.cpython-38.pyc rename apex-predator-search.py => search_algorithms/apex-predator-search.py (100%) create mode 100644 search_algorithms/bayesian_probabilistic_search.py create mode 100644 search_algorithms/birthday_df.csv rename birthday_df_create_and_test.py => search_algorithms/birthday_df_create_and_test.py (92%) create mode 100644 
search_algorithms/birthday_df_prob.csv rename elpd.R => search_algorithms/elpd.R (86%) rename elpd_df.py => search_algorithms/elpd_df.py (91%) rename graph_search.py => search_algorithms/graph_search.py (100%) create mode 100644 search_algorithms/mstan_interface.py create mode 100644 search_algorithms/prob_1.png rename random_models.py => search_algorithms/random_models.py (100%) diff --git a/birthday_df.csv b/birthday_df.csv deleted file mode 100644 index 04f7eb2..0000000 --- a/birthday_df.csv +++ /dev/null @@ -1,48 +0,0 @@ -,DayOfWeekTrend,DayOfYearTrend,HolidayTrend,LongTermTrend,Regression,SeasonalTrend,elpd,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfWeekWeights -0,no,no,no,no,glm,no,5232.03,,, -1,no,yes,no,no,glm,no,4578.51,no,no, -2,no,no,no,no,glm,yes,5585.64,,, -3,no,no,no,yes,glm,no,6717.52,,, -4,no,no,yes,no,glm,no,5258.05,,, -5,yes,no,no,no,glm,no,6910.55,,,uniform -6,yes,no,no,no,glm,no,2534.39,,,weighted -7,yes,yes,no,no,glm,no,6540.03,no,no,uniform -8,yes,yes,no,no,glm,no,7457.22,no,yes,uniform -9,yes,yes,no,no,glm,no,7812.98,yes,no,uniform -10,yes,yes,no,no,glm,no,6935.09,yes,yes,uniform -11,yes,no,no,no,glm,yes,7522.39,,,uniform -12,yes,no,no,yes,glm,no,10947.29,,,uniform -13,yes,no,yes,no,glm,no,6818.07,,,uniform -14,yes,yes,no,yes,glm,no,13949.47,no,no,uniform -15,yes,yes,no,yes,glm,no,13956.11,no,yes,uniform -16,yes,yes,no,yes,glm,no,13979.34,yes,no,uniform -17,yes,yes,no,yes,glm,no,13959.7,yes,yes,uniform -18,yes,no,no,yes,glm,yes,13127.09,,,uniform -19,yes,no,yes,yes,glm,no,11204.72,,,uniform -20,yes,no,no,yes,glm,no,9539.07,,,weighted -21,no,yes,no,yes,glm,no,7179.98,yes,no, -22,yes,yes,no,yes,glm,yes,14076.06,yes,no,uniform -23,yes,yes,yes,yes,glm,no,14526.21,yes,no,uniform -24,yes,yes,no,yes,glm,no,14620.97,yes,no,weighted -25,yes,yes,no,yes,glm,no,14599.08,no,no,weighted -26,yes,yes,no,no,glm,no,4232.97,yes,no,weighted -27,yes,yes,no,yes,glm,yes,8271.31,yes,no,weighted -28,yes,yes,yes,yes,glm,no,15299.22,yes,no,weighted -29,yes,yes,no,yes,glm,no,14598.53,yes,yes,weighted -30,no,yes,yes,yes,glm,no,7228.88,yes,no, -31,yes,yes,yes,yes,glm,no,14755.47,no,no,weighted -32,yes,yes,yes,no,glm,no,3283.42,yes,no,weighted -33,yes,yes,yes,yes,glm,yes,15071.18,yes,no,weighted -34,yes,yes,yes,yes,glm,no,15313.68,yes,yes,weighted -35,yes,no,yes,yes,glm,no,11466.83,,,weighted -36,no,yes,yes,yes,glm,no,5882.44,yes,yes, -37,yes,yes,yes,yes,glm,no,12213.75,yes,yes,uniform -38,yes,yes,yes,yes,glm,no,14929.74,no,yes,weighted -39,yes,yes,yes,no,glm,no,1606.94,yes,yes,weighted -40,yes,yes,yes,yes,glm,yes,15384.86,yes,yes,weighted -41,no,yes,yes,yes,glm,yes,5888.01,yes,yes, -42,yes,yes,yes,yes,glm,yes,14608.01,yes,yes,uniform -43,yes,yes,yes,yes,glm,yes,15301.54,no,yes,weighted -44,yes,yes,no,yes,glm,yes,11698.06,yes,yes,weighted -45,yes,yes,yes,no,glm,yes,2161.1,yes,yes,weighted -46,yes,no,yes,yes,glm,yes,14109.15,,,weighted diff --git a/Chain_Generation_and_Search.py b/search_algorithms/Chain_Generation_and_Search.py similarity index 100% rename from Chain_Generation_and_Search.py rename to search_algorithms/Chain_Generation_and_Search.py diff --git a/search_algorithms/__pycache__/elpd_df.cpython-38.pyc b/search_algorithms/__pycache__/elpd_df.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2764fb550bb2a4f6df626fbdf8ba5d54852c574 GIT binary patch literal 2476 zcmZ`)OK%)S5bmDWKD>U$BsKzNfGDuRvf;u394k_s@K7*O9+y?nYP{X+9eZBso^`y| z_=0moxp9k;>|_24KVc4BIOWU*<-}J#wuu*IM%^=AU0q%M)mPPzCnuW>&!5ge!LN13 
[remaining base85 binary patch data for the committed __pycache__/elpd_df.cpython-38.pyc and __pycache__/mstan_interface.cpython-38.pyc files omitted]
diff --git a/apex-predator-search.py b/search_algorithms/apex-predator-search.py
similarity index 100%
rename from apex-predator-search.py
rename to search_algorithms/apex-predator-search.py
diff --git a/search_algorithms/bayesian_probabilistic_search.py b/search_algorithms/bayesian_probabilistic_search.py
new file mode 100644
index 0000000..44ac36b
--- /dev/null
+++ b/search_algorithms/bayesian_probabilistic_search.py
@@ -0,0 +1,94 @@
+from operator import mod
+import elpd_df
+import numpy as np
+import random
+import pandas as pd
+import pathlib
+from mstan_interface import calculate_elpd, get_all_model_strings
+import matplotlib.pyplot as plt
+
+def plot_probabilities(probabilities, filename):
+    x = list(range(len(probabilities)))
+    plt.figure()
+    #plt.scatter(x, probabilities, linewidths=1)
+    plt.plot(x, probabilities)
+    plt.ylim(bottom=0.0)
+    plt.savefig(filename)
+
+
+def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_iterations=10):
+    # model df must contain all the models
+    model_df = elpd_df.read_csv(model_df_path)
+
+    model_count = model_df.shape[0]
+
+    model_df["probability"] = 1.0 / model_count
+
+    previous_iteration_elpd = None
+    previons_iteration_model_dict = None
+
+    for iter in range(1, num_iterations + 1):
+        print("-" * 20)
+        print(f"iteration {iter}")
+        draw = model_df.sample(weights=model_df.probability)
+        draw_string = elpd_df.row_to_string(draw.drop(columns="probability"))
+        print(f"chose model {draw_string}, with probability", draw.probability.values[0])
+
+        model_dict = elpd_df.model_string_to_dict(draw_string)
+        if not np.isnan(elpd_df.search_df(model_df, model_dict).elpd.values[0]):
+            elpd = elpd_df.search_df(model_df, model_dict).elpd.values[0]
+            print(f"using saved ELPD value {elpd}")
+        else:
+            elpd = calculate_elpd(model_path, draw_string, data_path)
+            #elpd = random.randint(500, 12000)
+            print(f"calculated ELPD value {elpd}, saving to df")
+            model_df = elpd_df.upsert_model(model_df, model_dict, elpd=elpd)
+
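+        # The update below is a heuristic reweighting rather than an exact
+        # Bayesian posterior update: when the new draw's ELPD beats the
+        # previous draw's, each implementation choice on which the two models
+        # differ is treated as evidence against the losing implementation.
+        # Probability mass is deducted from every model that shares the losing
+        # implementation and redistributed evenly over the models that share
+        # the winning one. Illustrative example: if the previous draw had
+        # SeasonalTrend:no with ELPD 5000 and the current draw has
+        # SeasonalTrend:yes with ELPD 12000, mass moves from all
+        # SeasonalTrend:no rows to the SeasonalTrend:yes rows.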
+        if iter > 1:
+            update_arr = np.zeros(model_count)
+            if elpd > previous_iteration_elpd:
+                deduction_dict = previons_iteration_model_dict
+                increment_dict = model_dict
+            else:  # <= covers exact ties, so deduction_dict is always defined
+                deduction_dict = model_dict
+                increment_dict = previons_iteration_model_dict
+
+            for key, value in model_dict.items():
+                if key not in previons_iteration_model_dict:
+                    continue
+
+                if value != previons_iteration_model_dict[key]:
+                    bad_models = elpd_df.search_df(model_df, {key: deduction_dict[key]})
+                    num_signatures = len(bad_models) - bad_models.isnull().sum(axis=1) - 1 # remove elpd column
+                    reducted_probs = 0
+                    for index, n_sigs in zip(num_signatures.index, num_signatures):
+                        deduction_amount = bad_models.loc[[index], "probability"].values[0] / n_sigs
+                        update_arr[index] -= deduction_amount
+                        reducted_probs += deduction_amount
+
+                    good_models = elpd_df.search_df(model_df, {key: increment_dict[key]})
+                    n_good_models = good_models.shape[0]
+                    print(f"deducted {reducted_probs} amount of probability, redistributing to {n_good_models} models")
+                    for index in good_models.index:
+                        update_arr[index] += reducted_probs / n_good_models
+
+            model_df["probability"] += update_arr
+
+
+
+
+
+        print(model_df)
+        plot_probabilities(model_df.probability, f"prob_{iter}.png")
+        previous_iteration_elpd = elpd
+        previons_iteration_model_dict = model_dict
+
+    elpd_df.save_csv(model_df.drop(columns="probability"), "birthday_df_prob.csv")
+
+if __name__ == "__main__":
+    example_dir = pathlib.Path(__file__).resolve().parents[1].absolute().joinpath("examples")
+    birthday_model_path = example_dir.joinpath("birthday/birthday.m.stan")
+    birthday_data_path = example_dir.joinpath("birthday/births_usa_1969.json")
+    birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df.csv")
+    bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=10)
diff --git a/search_algorithms/birthday_df.csv b/search_algorithms/birthday_df.csv
new file mode 100644
index 0000000..659d7cf
--- /dev/null
+++ b/search_algorithms/birthday_df.csv
@@ -0,0 +1,121 @@
+DayOfWeekTrend,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfYearTrend,HolidayTrend,LongTermTrend,Regression,SeasonalTrend,elpd,DayOfWeekWeights
+no,no,no,yes,no,no,glm,no,4578.51,
+no,no,no,yes,no,no,glm,yes,,
+no,no,no,yes,no,yes,glm,no,,
+no,no,no,yes,no,yes,glm,yes,,
+no,no,no,yes,yes,no,glm,no,,
+no,no,no,yes,yes,no,glm,yes,,
+no,no,no,yes,yes,yes,glm,no,,
+no,no,no,yes,yes,yes,glm,yes,,
+no,no,yes,yes,no,no,glm,no,,
+no,no,yes,yes,no,no,glm,yes,,
+no,no,yes,yes,no,yes,glm,no,,
+no,no,yes,yes,no,yes,glm,yes,,
+no,no,yes,yes,yes,no,glm,no,,
+no,no,yes,yes,yes,no,glm,yes,,
+no,no,yes,yes,yes,yes,glm,no,,
+no,no,yes,yes,yes,yes,glm,yes,,
+no,yes,no,yes,no,no,glm,no,,
+no,yes,no,yes,no,no,glm,yes,,
+no,yes,no,yes,no,yes,glm,no,7179.98,
+no,yes,no,yes,no,yes,glm,yes,,
+no,yes,no,yes,yes,no,glm,no,,
+no,yes,no,yes,yes,no,glm,yes,,
+no,yes,no,yes,yes,yes,glm,no,7228.88,
+no,yes,no,yes,yes,yes,glm,yes,,
+no,yes,yes,yes,no,no,glm,no,,
+no,yes,yes,yes,no,no,glm,yes,,
+no,yes,yes,yes,no,yes,glm,no,,
+no,yes,yes,yes,no,yes,glm,yes,,
+no,yes,yes,yes,yes,no,glm,no,,
+no,yes,yes,yes,yes,no,glm,yes,,
+no,yes,yes,yes,yes,yes,glm,no,5882.44,
+no,yes,yes,yes,yes,yes,glm,yes,5888.01,
+no,,,no,no,no,glm,no,5232.03,
+no,,,no,no,no,glm,yes,5585.64,
+no,,,no,no,yes,glm,no,6717.52, +no,,,no,no,yes,glm,yes,, +no,,,no,yes,no,glm,no,5258.05, +no,,,no,yes,no,glm,yes,, +no,,,no,yes,yes,glm,no,, +no,,,no,yes,yes,glm,yes,, +yes,no,no,yes,no,no,glm,no,6540.03,uniform +yes,no,no,yes,no,no,glm,yes,,uniform +yes,no,no,yes,no,yes,glm,no,13949.47,uniform +yes,no,no,yes,no,yes,glm,yes,,uniform +yes,no,no,yes,yes,no,glm,no,,uniform +yes,no,no,yes,yes,no,glm,yes,,uniform +yes,no,no,yes,yes,yes,glm,no,,uniform +yes,no,no,yes,yes,yes,glm,yes,,uniform +yes,no,yes,yes,no,no,glm,no,7457.22,uniform +yes,no,yes,yes,no,no,glm,yes,,uniform +yes,no,yes,yes,no,yes,glm,no,13956.11,uniform +yes,no,yes,yes,no,yes,glm,yes,,uniform +yes,no,yes,yes,yes,no,glm,no,,uniform +yes,no,yes,yes,yes,no,glm,yes,,uniform +yes,no,yes,yes,yes,yes,glm,no,,uniform +yes,no,yes,yes,yes,yes,glm,yes,,uniform +yes,yes,no,yes,no,no,glm,no,7812.98,uniform +yes,yes,no,yes,no,no,glm,yes,,uniform +yes,yes,no,yes,no,yes,glm,no,13979.34,uniform +yes,yes,no,yes,no,yes,glm,yes,14076.06,uniform +yes,yes,no,yes,yes,no,glm,no,,uniform +yes,yes,no,yes,yes,no,glm,yes,,uniform +yes,yes,no,yes,yes,yes,glm,no,14526.21,uniform +yes,yes,no,yes,yes,yes,glm,yes,,uniform +yes,yes,yes,yes,no,no,glm,no,6935.09,uniform +yes,yes,yes,yes,no,no,glm,yes,,uniform +yes,yes,yes,yes,no,yes,glm,no,13959.7,uniform +yes,yes,yes,yes,no,yes,glm,yes,,uniform +yes,yes,yes,yes,yes,no,glm,no,,uniform +yes,yes,yes,yes,yes,no,glm,yes,,uniform +yes,yes,yes,yes,yes,yes,glm,no,12213.75,uniform +yes,yes,yes,yes,yes,yes,glm,yes,14608.01,uniform +yes,,,no,no,no,glm,no,6910.55,uniform +yes,,,no,no,no,glm,yes,7522.39,uniform +yes,,,no,no,yes,glm,no,10947.29,uniform +yes,,,no,no,yes,glm,yes,13127.09,uniform +yes,,,no,yes,no,glm,no,6818.07,uniform +yes,,,no,yes,no,glm,yes,,uniform +yes,,,no,yes,yes,glm,no,11204.72,uniform +yes,,,no,yes,yes,glm,yes,,uniform +yes,no,no,yes,no,no,glm,no,,weighted +yes,no,no,yes,no,no,glm,yes,,weighted +yes,no,no,yes,no,yes,glm,no,14599.08,weighted +yes,no,no,yes,no,yes,glm,yes,,weighted +yes,no,no,yes,yes,no,glm,no,,weighted +yes,no,no,yes,yes,no,glm,yes,,weighted +yes,no,no,yes,yes,yes,glm,no,14755.47,weighted +yes,no,no,yes,yes,yes,glm,yes,,weighted +yes,no,yes,yes,no,no,glm,no,,weighted +yes,no,yes,yes,no,no,glm,yes,,weighted +yes,no,yes,yes,no,yes,glm,no,,weighted +yes,no,yes,yes,no,yes,glm,yes,,weighted +yes,no,yes,yes,yes,no,glm,no,,weighted +yes,no,yes,yes,yes,no,glm,yes,,weighted +yes,no,yes,yes,yes,yes,glm,no,14929.74,weighted +yes,no,yes,yes,yes,yes,glm,yes,15301.54,weighted +yes,yes,no,yes,no,no,glm,no,4232.97,weighted +yes,yes,no,yes,no,no,glm,yes,,weighted +yes,yes,no,yes,no,yes,glm,no,14620.97,weighted +yes,yes,no,yes,no,yes,glm,yes,8271.31,weighted +yes,yes,no,yes,yes,no,glm,no,3283.42,weighted +yes,yes,no,yes,yes,no,glm,yes,,weighted +yes,yes,no,yes,yes,yes,glm,no,15299.22,weighted +yes,yes,no,yes,yes,yes,glm,yes,15071.18,weighted +yes,yes,yes,yes,no,no,glm,no,,weighted +yes,yes,yes,yes,no,no,glm,yes,,weighted +yes,yes,yes,yes,no,yes,glm,no,14598.53,weighted +yes,yes,yes,yes,no,yes,glm,yes,11698.06,weighted +yes,yes,yes,yes,yes,no,glm,no,1606.94,weighted +yes,yes,yes,yes,yes,no,glm,yes,2161.1,weighted +yes,yes,yes,yes,yes,yes,glm,no,15313.68,weighted +yes,yes,yes,yes,yes,yes,glm,yes,15384.86,weighted +yes,,,no,no,no,glm,no,2534.39,weighted +yes,,,no,no,no,glm,yes,,weighted +yes,,,no,no,yes,glm,no,9539.07,weighted +yes,,,no,no,yes,glm,yes,,weighted +yes,,,no,yes,no,glm,no,,weighted +yes,,,no,yes,no,glm,yes,,weighted +yes,,,no,yes,yes,glm,no,11466.83,weighted 
+yes,,,no,yes,yes,glm,yes,14109.15,weighted diff --git a/birthday_df_create_and_test.py b/search_algorithms/birthday_df_create_and_test.py similarity index 92% rename from birthday_df_create_and_test.py rename to search_algorithms/birthday_df_create_and_test.py index 61cacd1..8258594 100644 --- a/birthday_df_create_and_test.py +++ b/search_algorithms/birthday_df_create_and_test.py @@ -1,5 +1,8 @@ from elpd_df import * +from mstan_interface import get_all_model_strings import pandas as pd +import numpy as np +import pathlib elpd_results = { "DayOfWeekTrend:no,DayOfYearTrend:no,HolidayTrend:no,LongTermTrend:no,Regression:glm,SeasonalTrend:no": 5232.03, @@ -51,30 +54,21 @@ "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearTrend:no,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes": 14109.15 } -columns = ['DayOfWeekWeights', 'HolidayTrend', 'DayOfYearNormalVariance', 'DayOfYearHeirarchicalVariance', 'Regression', 'DayOfWeekTrend', 'SeasonalTrend', 'DayOfYearTrend', 'LongTermTrend', 'elpd'] +all_model_strings = get_all_model_strings(pathlib.Path(__file__).resolve().parents[1].absolute().joinpath("examples/birthday/birthday.m.stan")) + model_list = [] -for model_string, elpd in elpd_results.items(): +for model_string in all_model_strings: model_dict = model_string_to_dict(model_string) - model_dict["elpd"] = elpd + model_dict["elpd"] = np.nan model_list.append(model_dict) df = pd.DataFrame(model_list) -save_csv(df, "birthday_df.csv") - -# 15301.54 -test_str = "DayOfWeekTrend:yes,DayOfWeekWeights:weighted,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:yes,DayOfYearTrend:yes,HolidayTrend:yes,LongTermTrend:yes,Regression:glm,SeasonalTrend:yes" - -print(model_string_to_dict(test_str)) +for model_string, elpd in elpd_results.items(): + df = upsert_model(df, model_string=model_string, elpd=elpd) -# lookup a model -result = search_df(df, model_string_to_dict(test_str)) -print(result) +save_csv(df, "birthday_df.csv") -print("-" * 10) -# update/add a model and elpd value -df = upsert_model(df, model_string=test_str, elpd=9999999) -print(df) diff --git a/search_algorithms/birthday_df_prob.csv b/search_algorithms/birthday_df_prob.csv new file mode 100644 index 0000000..659d7cf --- /dev/null +++ b/search_algorithms/birthday_df_prob.csv @@ -0,0 +1,121 @@ +DayOfWeekTrend,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfYearTrend,HolidayTrend,LongTermTrend,Regression,SeasonalTrend,elpd,DayOfWeekWeights +no,no,no,yes,no,no,glm,no,4578.51, +no,no,no,yes,no,no,glm,yes,, +no,no,no,yes,no,yes,glm,no,, +no,no,no,yes,no,yes,glm,yes,, +no,no,no,yes,yes,no,glm,no,, +no,no,no,yes,yes,no,glm,yes,, +no,no,no,yes,yes,yes,glm,no,, +no,no,no,yes,yes,yes,glm,yes,, +no,no,yes,yes,no,no,glm,no,, +no,no,yes,yes,no,no,glm,yes,, +no,no,yes,yes,no,yes,glm,no,, +no,no,yes,yes,no,yes,glm,yes,, +no,no,yes,yes,yes,no,glm,no,, +no,no,yes,yes,yes,no,glm,yes,, +no,no,yes,yes,yes,yes,glm,no,, +no,no,yes,yes,yes,yes,glm,yes,, +no,yes,no,yes,no,no,glm,no,, +no,yes,no,yes,no,no,glm,yes,, +no,yes,no,yes,no,yes,glm,no,7179.98, +no,yes,no,yes,no,yes,glm,yes,, +no,yes,no,yes,yes,no,glm,no,, +no,yes,no,yes,yes,no,glm,yes,, +no,yes,no,yes,yes,yes,glm,no,7228.88, +no,yes,no,yes,yes,yes,glm,yes,, +no,yes,yes,yes,no,no,glm,no,, +no,yes,yes,yes,no,no,glm,yes,, +no,yes,yes,yes,no,yes,glm,no,, +no,yes,yes,yes,no,yes,glm,yes,, +no,yes,yes,yes,yes,no,glm,no,, +no,yes,yes,yes,yes,no,glm,yes,, +no,yes,yes,yes,yes,yes,glm,no,5882.44, +no,yes,yes,yes,yes,yes,glm,yes,5888.01, +no,,,no,no,no,glm,no,5232.03, 
+no,,,no,no,no,glm,yes,5585.64, +no,,,no,no,yes,glm,no,6717.52, +no,,,no,no,yes,glm,yes,, +no,,,no,yes,no,glm,no,5258.05, +no,,,no,yes,no,glm,yes,, +no,,,no,yes,yes,glm,no,, +no,,,no,yes,yes,glm,yes,, +yes,no,no,yes,no,no,glm,no,6540.03,uniform +yes,no,no,yes,no,no,glm,yes,,uniform +yes,no,no,yes,no,yes,glm,no,13949.47,uniform +yes,no,no,yes,no,yes,glm,yes,,uniform +yes,no,no,yes,yes,no,glm,no,,uniform +yes,no,no,yes,yes,no,glm,yes,,uniform +yes,no,no,yes,yes,yes,glm,no,,uniform +yes,no,no,yes,yes,yes,glm,yes,,uniform +yes,no,yes,yes,no,no,glm,no,7457.22,uniform +yes,no,yes,yes,no,no,glm,yes,,uniform +yes,no,yes,yes,no,yes,glm,no,13956.11,uniform +yes,no,yes,yes,no,yes,glm,yes,,uniform +yes,no,yes,yes,yes,no,glm,no,,uniform +yes,no,yes,yes,yes,no,glm,yes,,uniform +yes,no,yes,yes,yes,yes,glm,no,,uniform +yes,no,yes,yes,yes,yes,glm,yes,,uniform +yes,yes,no,yes,no,no,glm,no,7812.98,uniform +yes,yes,no,yes,no,no,glm,yes,,uniform +yes,yes,no,yes,no,yes,glm,no,13979.34,uniform +yes,yes,no,yes,no,yes,glm,yes,14076.06,uniform +yes,yes,no,yes,yes,no,glm,no,,uniform +yes,yes,no,yes,yes,no,glm,yes,,uniform +yes,yes,no,yes,yes,yes,glm,no,14526.21,uniform +yes,yes,no,yes,yes,yes,glm,yes,,uniform +yes,yes,yes,yes,no,no,glm,no,6935.09,uniform +yes,yes,yes,yes,no,no,glm,yes,,uniform +yes,yes,yes,yes,no,yes,glm,no,13959.7,uniform +yes,yes,yes,yes,no,yes,glm,yes,,uniform +yes,yes,yes,yes,yes,no,glm,no,,uniform +yes,yes,yes,yes,yes,no,glm,yes,,uniform +yes,yes,yes,yes,yes,yes,glm,no,12213.75,uniform +yes,yes,yes,yes,yes,yes,glm,yes,14608.01,uniform +yes,,,no,no,no,glm,no,6910.55,uniform +yes,,,no,no,no,glm,yes,7522.39,uniform +yes,,,no,no,yes,glm,no,10947.29,uniform +yes,,,no,no,yes,glm,yes,13127.09,uniform +yes,,,no,yes,no,glm,no,6818.07,uniform +yes,,,no,yes,no,glm,yes,,uniform +yes,,,no,yes,yes,glm,no,11204.72,uniform +yes,,,no,yes,yes,glm,yes,,uniform +yes,no,no,yes,no,no,glm,no,,weighted +yes,no,no,yes,no,no,glm,yes,,weighted +yes,no,no,yes,no,yes,glm,no,14599.08,weighted +yes,no,no,yes,no,yes,glm,yes,,weighted +yes,no,no,yes,yes,no,glm,no,,weighted +yes,no,no,yes,yes,no,glm,yes,,weighted +yes,no,no,yes,yes,yes,glm,no,14755.47,weighted +yes,no,no,yes,yes,yes,glm,yes,,weighted +yes,no,yes,yes,no,no,glm,no,,weighted +yes,no,yes,yes,no,no,glm,yes,,weighted +yes,no,yes,yes,no,yes,glm,no,,weighted +yes,no,yes,yes,no,yes,glm,yes,,weighted +yes,no,yes,yes,yes,no,glm,no,,weighted +yes,no,yes,yes,yes,no,glm,yes,,weighted +yes,no,yes,yes,yes,yes,glm,no,14929.74,weighted +yes,no,yes,yes,yes,yes,glm,yes,15301.54,weighted +yes,yes,no,yes,no,no,glm,no,4232.97,weighted +yes,yes,no,yes,no,no,glm,yes,,weighted +yes,yes,no,yes,no,yes,glm,no,14620.97,weighted +yes,yes,no,yes,no,yes,glm,yes,8271.31,weighted +yes,yes,no,yes,yes,no,glm,no,3283.42,weighted +yes,yes,no,yes,yes,no,glm,yes,,weighted +yes,yes,no,yes,yes,yes,glm,no,15299.22,weighted +yes,yes,no,yes,yes,yes,glm,yes,15071.18,weighted +yes,yes,yes,yes,no,no,glm,no,,weighted +yes,yes,yes,yes,no,no,glm,yes,,weighted +yes,yes,yes,yes,no,yes,glm,no,14598.53,weighted +yes,yes,yes,yes,no,yes,glm,yes,11698.06,weighted +yes,yes,yes,yes,yes,no,glm,no,1606.94,weighted +yes,yes,yes,yes,yes,no,glm,yes,2161.1,weighted +yes,yes,yes,yes,yes,yes,glm,no,15313.68,weighted +yes,yes,yes,yes,yes,yes,glm,yes,15384.86,weighted +yes,,,no,no,no,glm,no,2534.39,weighted +yes,,,no,no,no,glm,yes,,weighted +yes,,,no,no,yes,glm,no,9539.07,weighted +yes,,,no,no,yes,glm,yes,,weighted +yes,,,no,yes,no,glm,no,,weighted +yes,,,no,yes,no,glm,yes,,weighted +yes,,,no,yes,yes,glm,no,11466.83,weighted 
+yes,,,no,yes,yes,glm,yes,14109.15,weighted
diff --git a/elpd.R b/search_algorithms/elpd.R
similarity index 86%
rename from elpd.R
rename to search_algorithms/elpd.R
index 68d76ce..763644c 100755
--- a/elpd.R
+++ b/search_algorithms/elpd.R
@@ -27,8 +27,8 @@ print(cmdstan_version())
 
 #' Model and fit
 model <- cmdstan_model(stan_file = stanfile, quiet=TRUE)
-fit <- model$sample(data = standata, iter_warmup=200, iter_sampling=200,
-                    chains=10, parallel_chains=10, seed=1)
+fit <- model$sample(data = standata, iter_warmup=100, iter_sampling=100,
+                    chains=4, parallel_chains=4, seed=1)
 
 loo <- fit$loo()
 elpd_estimate <- loo$estimates['elpd_loo', 'Estimate']
diff --git a/elpd_df.py b/search_algorithms/elpd_df.py
similarity index 91%
rename from elpd_df.py
rename to search_algorithms/elpd_df.py
index 43a1115..10a911c 100644
--- a/elpd_df.py
+++ b/search_algorithms/elpd_df.py
@@ -20,10 +20,12 @@ def row_to_string(row: pd.DataFrame) -> str:
     """
     Convert a dataframe row into a mstan model string name
     """
-    if len(row) != 1:
+    if row.shape[0] != 1:
         raise Exception("Only a single row should be supplied")
-    dict_form = row.to_dict()
-    del dict_form["elpd"]
+
+    dict_form = row[row.columns[~row.isnull().any()]].to_dict()
+    if "elpd" in dict_form:
+        del dict_form["elpd"]
     return ",".join(list(map(lambda x: f"{x[0]}:{list(x[1].values())[0]}", dict_form.items())))
 
 
@@ -67,7 +69,7 @@ def search_df(df: pd.DataFrame, model_dict: Dict[str, Union[str, float]] = None,
     return result if not result.empty else pd.DataFrame()
 
 def save_csv(df: pd.DataFrame, filename):
-    df.to_csv(filename)
+    df.to_csv(filename, index=False)
 
 def read_csv(filename: str):
     return pd.read_csv(filename)
diff --git a/graph_search.py b/search_algorithms/graph_search.py
similarity index 100%
rename from graph_search.py
rename to search_algorithms/graph_search.py
diff --git a/search_algorithms/mstan_interface.py b/search_algorithms/mstan_interface.py
new file mode 100644
index 0000000..cfcc647
--- /dev/null
+++ b/search_algorithms/mstan_interface.py
@@ -0,0 +1,49 @@
+import subprocess
+import sys
+import tempfile
+
+
+def text_command(args):
+    """Run a shell command, return its stdout as a String or throw an exception if it fails."""
+
+    try:
+        result = subprocess.run(args, text=True, check=True,
+                                stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
+
+        stdout = result.stdout.strip()
+        return stdout
+    except subprocess.CalledProcessError as exc:
+        sys.exit("Error running shell command: \"" + exc.output.strip() + "\"")
+
+
+class ModelEvaluator:
+    def __init__(self, dataFile):
+        self.dataFile = dataFile
+
+    def score(self, modelPath):
+        """Return the numerical score for the Stan program at the given filepath"""
+        stdout_result = text_command(["Rscript", "elpd.R", modelPath, self.dataFile])
+        return float(stdout_result.split('\n')[-1].strip())
+
+
+# Compute the ELPD value of a model
+# model_string is a mstan model string: comma-separated "Signature:implementation" pairs, one for each top-level signature
+def calculate_elpd(model_file, model_string, data_file):
+
+    # Render the concrete Stan program for this model string with mstan,
+    # then score it with the R-based ELPD evaluator (elpd.R)
+    model_code_args = ["mstan", "-f", model_file, "concrete-model", "-s", model_string,]
+    model_code = text_command(model_code_args)
+    with tempfile.NamedTemporaryFile(suffix=".stan", mode="w") as f:
+        f.write(model_code)
+        f.flush()
+        result = ModelEvaluator(data_file).score(f.name)
+    return result
+    #print(f"model: {model_string} ELPD: {result}")
+
+
+def get_all_model_strings(model_file):
+    # return a list of all model strings given a mstan model file
+    model_code_args = ["mstan", "-f", model_file, "list-all-models"]
+    models = text_command(model_code_args)
+    return models.split("\n")
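+
+
+# Minimal usage sketch (illustrative only; assumes the `mstan` executable is on
+# PATH and that elpd.R's R dependencies are installed; the paths are those of
+# the birthday example used elsewhere in this repository):
+#
+#   strings = get_all_model_strings("examples/birthday/birthday.m.stan")
+#   elpd = calculate_elpd("examples/birthday/birthday.m.stan", strings[0],
+#                         "examples/birthday/births_usa_1969.json")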
diff --git a/search_algorithms/prob_1.png b/search_algorithms/prob_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e67ff1ee48ef996a984d4086fbc18428c2f25903
GIT binary patch
literal 13370
z%4ota!(rw@N=DA@<4>5Q&tRJ;hJ$%X+3xUsZ+SuAy~~Sm=JfQ-AS=Zw2a$lxq(j3v zAf&H`jS;iJ!#JB&R%V1P^kX=P7#FIUlsI>Os0r=t>PoK;=5QQuO2noDtph${zcAJa zL?fnoTvtWqeuX&!L`B}o7!zmoAE3a9+wcsrmr4$7 ze;^5XNy8Io*JpmkU?Q?Mff!f&n;4956YyYKRz{MJo?a(k^$@7IALe`-fb5;=1!*Kn z&htu2N(#t90wc~PCarcJUu2k41}C@Jcmjvw;*)f(#B3lyb~kRE9|Cz$wAb;!1zod2 zUO*;sH3U^KV-m^NK$Zu&o2>#I#291@Q3n3=F9w}`10pcB#Y^*1ckTb5B^enRnanqD z-bBTLz!92621XO~IXEcOyJF-RGFos1eO7_V!F;j8*Vh*yi1zX_Su~6<2!MG7TNDq?O;TLkO>+#Y<%z)WJ16m> zo0Zoqu=S})J|DKRfCF?XuT?7}(l9X7geS6y_N)HKAP+JH z8-9y$L;fh6)_viFrZm-~o>g64{jhFC04`Q#6}g1+`Lz>uDW&ync*x}xL@%kpy0bxi zX}1hlOxF(hurGGH2%0!ZNrT|f)+Td-WVZE+KkD6PG1BLkvm~Z7)t_F8aHhceau|qC z9}Mm7pBt$=^-3L@DemOVhFAZ~JVcQlR~M^HT?RcVJ3G4#ERebxsb9H2|2%mS7!R1M z;bbH&SE!DTj{1SJ9G#fBGzz?-bSyoUxZz(ilZpxN0S^GywBn^*^kH5JP)Y_h7IZmi z9lVv?Z0AabJ`1GqN+2ip%kEi1;c#P@E&sqkYWhtE68_MaooWtTlAs1GGzze+suJ(O z?eNbYm2$kds=@pKOgI4XG@DU?&;Qn82aVdzRiNYAElu|VRRiO7YIc^pB#=H>DJUQi zQ9ax;Q086~w9bxV#aL$na05|kXl^z*{!L4sWu1-#ci`unu*P3dxX`{GzZ3=QB(8IL za~1H~OcKc)M6;=x89jUZERZ0y5@gQWuwLK)s>owf2OtRSN5qnak7_t~kz?`g-j^{k z2aAe}1^0&7O?6~}YAWZylw&3#F8&fugF*6YM+Sa{3+c<3{s73px{+WgLW)pDMa7`Q zTY)>4GKxqs?UNO6jDV+jyl2R7!lS|-Jwgokd*`dTDNRlK$G^EM!^OSu)Kw+{tk}P` zvSSs-M-Opx>w~uWUwQ<1V)h;$*PTn`tp{ejgh15w)qshog#BhDI$^<0^1rQ~qe`3r z4duXRwD^ynyM0yoU_r@9U3dl&@-}ycVwO;bSFb)#NZ{JFdw2RHj@`Q(M&=&*J_2#% zFW36`D|2Sk4QFROPz#t2ti$;t(V)j><>$M-vqW^LDZZ?Nf_D7!N&xss0#ab)N2qBMbrY6cm`uN|qb?j;$9>AuKQeqvua*Q0Donv@&%`1&V4~n#bMO)-EcMD|m$xp-8 zeP!+$yaGZ(>Tt`x3f}@49YAF0XlXePv+32-R?2jkF|?AEa_`;S0B-&ij7Zqj5;AM1 z9A*`M574nu{8yl`rluSc6?}ETiU2nbd1~-|S$tveSJt-Vz`XqI>?=fXdRMw;eBqsu z%Lh0)9fwGmI0@S@vjS(T90Un+P0Ms(gN;jFUw}yjlm2uZOgk6`&Y~!YG!iDyi`G_V z6%H;WEt`(OuA@>_|Nj$N`7JlT2meOWf9bX?*X|Wopr^NYxCz;tZXO*K1=6$K>_Ax` zIVh+hkdlws_ix@@fP{k3H(yWmFd|DH*~8r3HEA0=A>I^@TFC zK)D?>5=Y<4!S#|~+CaWiF^EWMSrMp?#IxSM{mIYIuMo^ue_`-;ARRR9>hsR>&G4%! 
z8cnwra1g|}Q|PHLc^AFo-@IBTnu@{}JgB~jI8MW^hA>#bRG5tNV=D%`8ReEV{M@sA}W9?1xiOA_c6sDJvNYl=tLAt zH(jF@=!;eN8VOVs3~+U_T}5W*gibqD`{Ai~&m`Q+OedXBihx*hDiTp+O?@@^Bjqrw zg*NY=zR1bR!L6DCI9$4PNiJ}$G%q$ojMLH6+K?br?Pqqli#DI9Z!qS7kF_=eDdwFu zW`LF}_`mZN@RlG;mEHj+Os^Dqoi7rkXxidrdu+g9MTG*;jkHJLn>ej)BcZi2A}w~t zh7^^H=X6Kkti=sUoT)66s(;ea-a|~!&qoje!by3Kh`$0(8F?tE z^9voeZwuUV6$?A3r>mRt@Zt8xWMzx~VkbmH1tXFdO^Frj&M_+J7#I*n8&~=j9TYK6 z1=Y@B?n4a{@)@oUE^;5&LhI=3$AIiy=(}J8kVF|@0FC53$m&7w!!Cw-SEwLUixPV- z90$x8B%Jn59173*)ME}|&=3KtU|9bVqHi_ccPcAmv1${?1We{SAa*OwM}%jlrCERA zR|hF9hWKQhV#aypTP#{(HUj)1ZjIQ}5(-wnBhLOR7?r9CBD z{u*a8h-r^n1_p9K!$cB{9Z+vT@)qCh5KHYyQ1sVd9tt8Ob#~Ob^gEuFx2mG$_N5`* zz9086m()04>1UY-%w58Oh}H=Su_*9-zzLdy{<8G(lUnwCy2sAS$y~Dy)*3o@+(O$%)7y@j%a1?c&oN@!1{X<&s zh54I&Xr8kJvWkl>DoC0J1O!3&OK5(yv{>^ivcYJYXVEt(39z63Mg;!!l5BkIrEQY$ zW#k0|KUec)7Zm6sTb<0FOcsQEgcOI;=jV`AqT2!e@*8_*i^xHf)jVw1BSbASV9{jA zsiYR-AOjzDr)$b{ok?ED$J zapO!TnEdH^3yYs&Aj+)c2{P?3ap~3$oZp9;3$^F~oL%g(4DXfVCm28|5Gb41bkEr_ zHpM93ni4Zvr_9cilb7!t9K_ouWAud#@Clg$8kuulv}gnKYY+_rzK*0CT>?gq`VBwe z^k`OGPVIPBXm0TM!6R{qm(kcvdPxxXv13J`M0O4gB-G-?+aV1gjP87_&&h+mEh?(#LI$YZ46@2YZ=lfKQU=iF|)JI>R6x&tcHByf^J&l z`|UTj9{xM7TS9G+C5MQCNChAbM^iK8!(BN{mKVA(x3C1hs{0)3jyP-bQEPO2)eMnKw> zX97d_Pv-4%RERA&_owcm3&9(+eP3^m0wYk3SA1B}HqKBu9SR&!RC6zlaoPH3i;b@x z|D04}iaj@d??pr!Mn)Fy9TJ*=@C{T}JA20&P-*~Am&6|I7;r9u_`;EL4MK78Iq){7 zT3)Dd&RFK^AoF$0NVqZ{yfLSN(rdxHkF~)Q+K-NnWv8dtnn&ibo#IN(!JsezYadc2 zNIX(bE<2bf&=9IL5+0ok9v$ukNlT4hW$T9h$ja3ch9EP>Xi?#?hh(gqTq4$xApH>a z6M7N~sk=2rm2sbE?Km~pDi!#!BV-hTAZV;C1RpX1#}yWUJxgb6VLj@Oc~3yH1PRc> zFcL?kC|EH2x)(ciwUKDy7sKPEYFH+)ese-B?;n zr|dhWH9@5ZmLN91fNuK4GutG@>w*6~QPSUN%nAc)MGa)m2w2j6>f|NQoH>qOiu>fU zAeFrG&K7cSeh4Q(a>tluYQ|SQ0aqIv8(|^lOq=?^Jqaa!_GYj7CMse2dKkoS*nv=< z2!|)Fp2{KqE6?vhp-U8cbFfH0l_ zlR}|*Qa!uVa)x$@RXhhu4e>$f3~lRc&Fd3D$d~lN3IQ`#SGS^$~f92n1x09F!;+(xfBLhJ=xB28~c z`{gWcOapM1u4VE%hBg{&QCwnR?puRj0Ul>hb)$%M-P`n44wxe*aBF=K8jcW~xX4VA z8fYR4yzEc`ejz;qA%rFN}=D$aOJV2Mxjr1Bb2< zUY0S1BOrL1nU$3X0;+XUNr?#@*)Bwn#65prWekppndQ76B&2JBK@_a6t^&Zsk{0=J z1VS^rYicy7k?K9jYOdldI$YC~w9Bo%yntHobnxjJn1F9(xaEY5X7VlCV#wClL zCkFZW^eFwM)3kt|WP`k-KFoDe2+9U4uiY#+Y4{pjJAkD0h+~}w7Dena7j`Y8%zCmI z-6p^$gL7gn0%b9pwoisQ?&j!a(*wrkyC4pi16?1Nn))+LYqkw+D{*^pi4C1g`e9@& zm@TJ?(!f1X(=C}i5MGyyU~tKRTPOrHa!M*^oVp-$+Z)Ac71>Rj%t+>5@P-5T*xD$z z_eVYCa>2|pv~5a2SKr20&UN9)kf$UwlLJ@nnOyw2-JMu#BeFZq s3)n@)=-iiJ?!TwU{HOo_<=y&*=U8b|L;cuBcq0`0lGa81g&W`hA8bI}-T(jq literal 0 HcmV?d00001 diff --git a/random_models.py b/search_algorithms/random_models.py similarity index 100% rename from random_models.py rename to search_algorithms/random_models.py From 7b0577cb81d18f2c06abd1ede2c295abe7124118 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Fri, 4 Feb 2022 10:46:38 +0900 Subject: [PATCH 12/23] Update probability search and add score based algorithm --- .../Bandit_approach_Score_based_Algorithm.py | 60 +++++++++++++++++++ .../bayesian_probabilistic_search.py | 18 ++++-- 2 files changed, 72 insertions(+), 6 deletions(-) create mode 100644 search_algorithms/Bandit_approach_Score_based_Algorithm.py diff --git a/search_algorithms/Bandit_approach_Score_based_Algorithm.py b/search_algorithms/Bandit_approach_Score_based_Algorithm.py new file mode 100644 index 0000000..49fc938 --- /dev/null +++ b/search_algorithms/Bandit_approach_Score_based_Algorithm.py @@ -0,0 +1,60 @@ +# Bandit like approach +# In every iteration, the score for each terminal implementation of top-level signature is caculated +# Score of a terminal implementation := the mean of the ELPD values of the models that contain the implementation +# For each 
model in the model space, the score of the model is the sum of the scores of its terminal implementations
+
+import numpy as np
+import itertools
+import random
+
+# Top_level_signature_implementations: 2-dimensional list
+# the ith element is a list that contains all possible terminal implementations of the ith top-level signature
+
+# Each model can be represented as an index vector
+# e.g. model m = [2,3,0,1,3]: the terminal implementations of model m are the 3rd, 4th, 1st, 2nd, 4th implementations of the top-level signatures
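+
+# A small worked example (illustrative numbers): with two top-level signatures
+# whose implementations are [[A, B], [C, D]], suppose the models evaluated so
+# far are (A, C) with ELPD 10 and (B, C) with ELPD 6. Then score(A) = 10,
+# score(B) = 6, score(C) = (10 + 6) / 2 = 8, and D, which appears in no
+# evaluated model yet, falls back to the running average ELPD, also 8 here.
+# Model (A, D) is therefore scored 10 + 8 = 18, and each model's selection
+# probability is its score divided by the sum of scores over the model space.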
+
+
+def ELPD_compute(model):
+    return 0
+
+def Bandit_approach_Score_based_search(Top_level_signature_implementations,k):
+    # n: the number of top-level signatures
+    n = len(Top_level_signature_implementations)
+    # Create model list
+    Model_list = list(itertools.product(*[range(len(ith_signature)) for ith_signature in Top_level_signature_implementations]))
+    # create lists to save the models whose ELPD values are evaluated and their respective ELPD values
+    Models_evaluated = []
+    ELPD_vals = []
+    sum_ELPDs = 0
+    # create lists to save the following two values for each terminal implementation
+    # (1) the sum of ELPD values of the models that contain the terminal implementation (2) the number of models (whose ELPD values are evaluated) that contain the terminal implementation
+    sum_ELPD_values_for_each_implementation = [[0 for terminal_implementation in ith_signature] for ith_signature in Top_level_signature_implementations]
+    total_num_of_occurences_for_each_implementation = [[0 for terminal_implementation in ith_signature] for ith_signature in Top_level_signature_implementations]
+    for iteration_ind in range(k):
+        # average ELPD over the models evaluated so far (fallback score for implementations not tried yet)
+        avg_ELPD = sum_ELPDs/(iteration_ind) if iteration_ind > 0 else 1
+        # compute the score of each terminal implementation
+        Score_cur_iteration = [[
+            sum_ELPD_values_for_each_implementation[i][j]/total_num_of_occurences_for_each_implementation[i][j] if total_num_of_occurences_for_each_implementation[i][j] > 0 else avg_ELPD
+            for j in range(len(Top_level_signature_implementations[i]))]
+            for i in range(len(Top_level_signature_implementations))
+        ]
+        # Compute the score of each model by adding up the scores of its terminal implementations
+        Score_each_model = [ sum([Score_cur_iteration[i][model[i]] for i in range(n)]) for model in Model_list]
+        Model_score_sum = sum(Score_each_model)
+        # Compute the selection probability of each model by normalizing the scores
+        Selection_prob_each_model = [model_score/Model_score_sum for model_score in Score_each_model]
+        # Randomly draw one model based on the selection probabilities
+        cur_iteration_model_ind = random.choices(range(len(Model_list)), weights=Selection_prob_each_model, cum_weights=None, k=1)[0]
+        # update Models_evaluated, ELPD_vals, sum_ELPDs
+        Models_evaluated.append(cur_iteration_model_ind)
+        cur_iter_model_ELPD = ELPD_compute(Model_list[cur_iteration_model_ind])
+        ELPD_vals.append(cur_iter_model_ELPD)
+        sum_ELPDs += cur_iter_model_ELPD
+        # update sum_ELPD_values_for_each_implementation, total_num_of_occurences_for_each_implementation
+        for i in range(n):
+            implementation_ind = Model_list[cur_iteration_model_ind][i]
+            sum_ELPD_values_for_each_implementation[i][implementation_ind] = sum_ELPD_values_for_each_implementation[i][implementation_ind] + cur_iter_model_ELPD
+            total_num_of_occurences_for_each_implementation[i][implementation_ind] = total_num_of_occurences_for_each_implementation[i][implementation_ind]+1
+    # Return the model selected at the final iteration
+    return Model_list[Models_evaluated[-1]]
\ No newline at end of file
diff --git a/search_algorithms/bayesian_probabilistic_search.py b/search_algorithms/bayesian_probabilistic_search.py
index 44ac36b..381efaa 100644
--- a/search_algorithms/bayesian_probabilistic_search.py
+++ b/search_algorithms/bayesian_probabilistic_search.py
@@ -7,14 +7,18 @@ from mstan_interface import calculate_elpd, get_all_model_strings
 import matplotlib.pyplot as plt
 
-def plot_probabilities(probabilities, filename):
-    x = list(range(len(probabilities)))
+def plot_probabilities(df, filename):
+    x = list(range(len(df.probability)))
     plt.figure()
     #plt.scatter(x, probabilities, linewidths=1)
-    plt.plot(x, probabilities)
+    plt.plot(x, df.probability)
     plt.ylim(bottom=0.0)
     plt.savefig(filename)
 
+    plt.figure()
+    plt.scatter(df.probability, df.elpd)
+    plt.savefig("elpd_" + filename)
+
 
def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera
         print(model_df)
-        plot_probabilities(model_df.probability, f"prob_{iter}.png")
+        plot_probabilities(model_df, f"prob_{iter}.png")
         previous_iteration_elpd = elpd
         previons_iteration_model_dict = model_dict
 
     elpd_df.save_csv(model_df.drop(columns="probability"), "birthday_df_prob.csv")
+    elpd_df.save_csv(model_df, "bayesian_update_results.csv")
+
 
 if __name__ == "__main__":
     example_dir = pathlib.Path(__file__).resolve().parents[1].absolute().joinpath("examples")
     birthday_model_path = example_dir.joinpath("birthday/birthday.m.stan")
     birthday_data_path = example_dir.joinpath("birthday/births_usa_1969.json")
-    birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df.csv")
-    bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=10)
+    birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df_prob.csv")
+    bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=20)

From 883178b99a05cab2ad618391a1f7cb00e6a31a4c Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Sat, 5 Feb 2022 17:03:11 +0900
Subject: [PATCH 13/23] Update bayesian probabilistic search to include plots

---
 .../bayesian_probabilistic_search.py   |  77 +++++++++++--
 search_algorithms/birthday_df_prob.csv | 104 +++++++++---------
 search_algorithms/prob_1.png           | Bin 13370 -> 0 bytes
 3 files changed, 121 insertions(+), 60 deletions(-)
 delete mode 100644 search_algorithms/prob_1.png

diff --git a/search_algorithms/bayesian_probabilistic_search.py b/search_algorithms/bayesian_probabilistic_search.py
index 381efaa..3b39dc9 100644
--- a/search_algorithms/bayesian_probabilistic_search.py
+++ b/search_algorithms/bayesian_probabilistic_search.py
@@ -6,18 +6,68 @@ import pathlib
 from mstan_interface import calculate_elpd, get_all_model_strings
 import matplotlib.pyplot as plt
+from sklearn.linear_model import LinearRegression
+
 
-def plot_probabilities(df, filename):
+def plot_probabilities(df, iteration):
+    plt.figure()
     x = list(range(len(df.probability)))
     #plt.scatter(x, probabilities, linewidths=1)
     plt.plot(x, df.probability)
     plt.ylim(bottom=0.0)
-    plt.savefig(filename)
+    plt.xlabel('model index')
+    plt.ylabel('probability')
+    
plt.savefig(f"model_pmf_{iteration}.png") + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,10)) + fig.suptitle(f"Iteration {iteration}: prob-ELPD plot") + ax1.set_title("only selected models") + ax1.set_xlabel("probability") + ax1.set_ylabel("ELPD") + filtered = df[~df.elpd.isna() & df.selected] + linear_regressor = LinearRegression() + linear_regressor.fit(filtered.probability.values.reshape(-1, 1), filtered.elpd.values.reshape(-1, 1)) + ax1.scatter(filtered.probability, filtered.elpd) + ax1.plot(filtered.probability, linear_regressor.predict(filtered.probability.values.reshape(-1, 1)), color="red") + + ax2.set_title("all models present in cached results") + ax2.set_xlabel("probability") + ax2.set_ylabel("ELPD") + filtered = df[~df.elpd.isna()] + linear_regressor = LinearRegression() + linear_regressor.fit(filtered.probability.values.reshape(-1, 1), filtered.elpd.values.reshape(-1, 1)) + ax2.scatter(filtered.probability, filtered.elpd) + ax2.plot(filtered.probability, linear_regressor.predict(filtered.probability.values.reshape(-1, 1)), color="red") + fig.savefig(f"prob-epld_plot_{iteration}.png") + + +def plot_signatures(df): + # plot elpd, prob for each signature + df = df[~df.elpd.isna() & df.selected] + + signatures = list(df.drop(columns=["elpd", "probability", "selected"]).columns) + + for signature in signatures: + fig, (elpd_ax, ax2) = plt.subplots(1, 2, figsize=(15, 10)) + fig.suptitle(f"elpd-prob plot for signature {signature}, using only selected models") + filtered = df.loc[:, ["elpd", "probability", signature]] + res = filtered.groupby(signature) + #prob_ax = ax1.twinx() + res.mean().plot.bar(ax=elpd_ax, secondary_y="probability") + elpd_ax.set_title("mean of elpd and prob") + + #filtered.set_index("probability").groupby(signature).elpd.plot(ax=ax2, legend=True, style=".", ms=20) + for name, group in res: + ax2.scatter(x=group.probability, y=group.elpd, label=name) + + ax2.legend() + ax2.set_xlabel("probability") + ax2.set_ylabel("ELPD") + #res.plot(, x="probability", y="elpd", ax=ax2, legend=True) + + fig.tight_layout() + fig.savefig(f"sigplot_{signature}.png") - plt.figure() - plt.scatter(df.probability, df.elpd) - plt.savefig("elpd_" + filename) def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_iterations=10): @@ -27,6 +79,7 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera model_count = model_df.shape[0] model_df["probability"] = 1.0 / model_count + model_df["selected"] = False previous_iteration_elpd = None previons_iteration_model_dict = None @@ -35,14 +88,20 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera print("-" * 20) print(f"iteration {iter}") draw = model_df.sample(weights=model_df.probability) + model_df.loc[draw.index, "selected"] = True + + draw_string = elpd_df.row_to_string(draw.drop(columns="probability")) print(f"chose model {draw_string}, with probability", draw.probability.values[0]) model_dict = elpd_df.model_string_to_dict(draw_string) + del model_dict["selected"] + draw_string = ",".join([f"{key}:{val}" for key, val in model_dict.items()]) if not np.isnan(elpd_df.search_df(model_df, model_dict).elpd.values[0]): elpd = elpd_df.search_df(model_df, model_dict).elpd.values[0] print(f"using saved ELPD value {elpd}") else: + print("calculating elpd value...") elpd = calculate_elpd(model_path, draw_string, data_path) #elpd = random.randint(500, 12000) print(f"calculated ELPD value {elpd}, saving to df") @@ -84,12 +143,14 @@ def 
bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera print(model_df) - plot_probabilities(model_df, f"prob_{iter}.png") + plot_probabilities(model_df, iter) previous_iteration_elpd = elpd previons_iteration_model_dict = model_dict - elpd_df.save_csv(model_df.drop(columns="probability"), "birthday_df_prob.csv") + elpd_df.save_csv(model_df.drop(columns=["probability", "selected"]), "birthday_df_prob.csv") elpd_df.save_csv(model_df, "bayesian_update_results.csv") + + plot_signatures(model_df) if __name__ == "__main__": diff --git a/search_algorithms/birthday_df_prob.csv b/search_algorithms/birthday_df_prob.csv index 659d7cf..b243dda 100644 --- a/search_algorithms/birthday_df_prob.csv +++ b/search_algorithms/birthday_df_prob.csv @@ -1,74 +1,74 @@ DayOfWeekTrend,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfYearTrend,HolidayTrend,LongTermTrend,Regression,SeasonalTrend,elpd,DayOfWeekWeights no,no,no,yes,no,no,glm,no,4578.51, -no,no,no,yes,no,no,glm,yes,, -no,no,no,yes,no,yes,glm,no,, +no,no,no,yes,no,no,glm,yes,2387.24057250632, +no,no,no,yes,no,yes,glm,no,7125.9942028813, no,no,no,yes,no,yes,glm,yes,, no,no,no,yes,yes,no,glm,no,, -no,no,no,yes,yes,no,glm,yes,, +no,no,no,yes,yes,no,glm,yes,2920.39632324059, no,no,no,yes,yes,yes,glm,no,, -no,no,no,yes,yes,yes,glm,yes,, -no,no,yes,yes,no,no,glm,no,, -no,no,yes,yes,no,no,glm,yes,, +no,no,no,yes,yes,yes,glm,yes,5311.27329352511, +no,no,yes,yes,no,no,glm,no,5442.89912071062, +no,no,yes,yes,no,no,glm,yes,4399.63544787709, no,no,yes,yes,no,yes,glm,no,, no,no,yes,yes,no,yes,glm,yes,, no,no,yes,yes,yes,no,glm,no,, -no,no,yes,yes,yes,no,glm,yes,, -no,no,yes,yes,yes,yes,glm,no,, -no,no,yes,yes,yes,yes,glm,yes,, +no,no,yes,yes,yes,no,glm,yes,5604.70609557286, +no,no,yes,yes,yes,yes,glm,no,7197.64047500906, +no,no,yes,yes,yes,yes,glm,yes,7162.06930181051, no,yes,no,yes,no,no,glm,no,, no,yes,no,yes,no,no,glm,yes,, no,yes,no,yes,no,yes,glm,no,7179.98, -no,yes,no,yes,no,yes,glm,yes,, -no,yes,no,yes,yes,no,glm,no,, -no,yes,no,yes,yes,no,glm,yes,, +no,yes,no,yes,no,yes,glm,yes,7379.32679859518, +no,yes,no,yes,yes,no,glm,no,5455.05228337242, +no,yes,no,yes,yes,no,glm,yes,5674.10498513126, no,yes,no,yes,yes,yes,glm,no,7228.88, no,yes,no,yes,yes,yes,glm,yes,, -no,yes,yes,yes,no,no,glm,no,, -no,yes,yes,yes,no,no,glm,yes,, -no,yes,yes,yes,no,yes,glm,no,, -no,yes,yes,yes,no,yes,glm,yes,, -no,yes,yes,yes,yes,no,glm,no,, +no,yes,yes,yes,no,no,glm,no,5232.3273276717, +no,yes,yes,yes,no,no,glm,yes,4406.46562859129, +no,yes,yes,yes,no,yes,glm,no,6703.35639574843, +no,yes,yes,yes,no,yes,glm,yes,6859.312500912, +no,yes,yes,yes,yes,no,glm,no,5257.83852126472, no,yes,yes,yes,yes,no,glm,yes,, no,yes,yes,yes,yes,yes,glm,no,5882.44, no,yes,yes,yes,yes,yes,glm,yes,5888.01, no,,,no,no,no,glm,no,5232.03, no,,,no,no,no,glm,yes,5585.64, no,,,no,no,yes,glm,no,6717.52, -no,,,no,no,yes,glm,yes,, +no,,,no,no,yes,glm,yes,7249.78141003667, no,,,no,yes,no,glm,no,5258.05, -no,,,no,yes,no,glm,yes,, -no,,,no,yes,yes,glm,no,, -no,,,no,yes,yes,glm,yes,, +no,,,no,yes,no,glm,yes,5522.14783298078, +no,,,no,yes,yes,glm,no,6760.55418752675, +no,,,no,yes,yes,glm,yes,7283.7324312413, yes,no,no,yes,no,no,glm,no,6540.03,uniform -yes,no,no,yes,no,no,glm,yes,,uniform +yes,no,no,yes,no,no,glm,yes,3344.3654692436,uniform yes,no,no,yes,no,yes,glm,no,13949.47,uniform -yes,no,no,yes,no,yes,glm,yes,,uniform -yes,no,no,yes,yes,no,glm,no,,uniform -yes,no,no,yes,yes,no,glm,yes,,uniform +yes,no,no,yes,no,yes,glm,yes,13950.2470267613,uniform 
+yes,no,no,yes,yes,no,glm,no,7489.57698442313,uniform +yes,no,no,yes,yes,no,glm,yes,4003.44946683213,uniform yes,no,no,yes,yes,yes,glm,no,,uniform yes,no,no,yes,yes,yes,glm,yes,,uniform yes,no,yes,yes,no,no,glm,no,7457.22,uniform yes,no,yes,yes,no,no,glm,yes,,uniform yes,no,yes,yes,no,yes,glm,no,13956.11,uniform yes,no,yes,yes,no,yes,glm,yes,,uniform -yes,no,yes,yes,yes,no,glm,no,,uniform -yes,no,yes,yes,yes,no,glm,yes,,uniform -yes,no,yes,yes,yes,yes,glm,no,,uniform +yes,no,yes,yes,yes,no,glm,no,7250.42854438999,uniform +yes,no,yes,yes,yes,no,glm,yes,7323.31146498067,uniform +yes,no,yes,yes,yes,yes,glm,no,8838.90163837759,uniform yes,no,yes,yes,yes,yes,glm,yes,,uniform yes,yes,no,yes,no,no,glm,no,7812.98,uniform -yes,yes,no,yes,no,no,glm,yes,,uniform +yes,yes,no,yes,no,no,glm,yes,7588.18829799233,uniform yes,yes,no,yes,no,yes,glm,no,13979.34,uniform yes,yes,no,yes,no,yes,glm,yes,14076.06,uniform yes,yes,no,yes,yes,no,glm,no,,uniform -yes,yes,no,yes,yes,no,glm,yes,,uniform +yes,yes,no,yes,yes,no,glm,yes,7337.73019675242,uniform yes,yes,no,yes,yes,yes,glm,no,14526.21,uniform -yes,yes,no,yes,yes,yes,glm,yes,,uniform +yes,yes,no,yes,yes,yes,glm,yes,9900.42594930185,uniform yes,yes,yes,yes,no,no,glm,no,6935.09,uniform -yes,yes,yes,yes,no,no,glm,yes,,uniform +yes,yes,yes,yes,no,no,glm,yes,7187.80978134713,uniform yes,yes,yes,yes,no,yes,glm,no,13959.7,uniform -yes,yes,yes,yes,no,yes,glm,yes,,uniform -yes,yes,yes,yes,yes,no,glm,no,,uniform -yes,yes,yes,yes,yes,no,glm,yes,,uniform +yes,yes,yes,yes,no,yes,glm,yes,14029.7491348079,uniform +yes,yes,yes,yes,yes,no,glm,no,6851.6486517034,uniform +yes,yes,yes,yes,yes,no,glm,yes,7360.56737317448,uniform yes,yes,yes,yes,yes,yes,glm,no,12213.75,uniform yes,yes,yes,yes,yes,yes,glm,yes,14608.01,uniform yes,,,no,no,no,glm,no,6910.55,uniform @@ -76,23 +76,23 @@ yes,,,no,no,no,glm,yes,7522.39,uniform yes,,,no,no,yes,glm,no,10947.29,uniform yes,,,no,no,yes,glm,yes,13127.09,uniform yes,,,no,yes,no,glm,no,6818.07,uniform -yes,,,no,yes,no,glm,yes,,uniform +yes,,,no,yes,no,glm,yes,7332.93411844658,uniform yes,,,no,yes,yes,glm,no,11204.72,uniform -yes,,,no,yes,yes,glm,yes,,uniform -yes,no,no,yes,no,no,glm,no,,weighted +yes,,,no,yes,yes,glm,yes,13576.7239549785,uniform +yes,no,no,yes,no,no,glm,no,2975.31032486582,weighted yes,no,no,yes,no,no,glm,yes,,weighted yes,no,no,yes,no,yes,glm,no,14599.08,weighted -yes,no,no,yes,no,yes,glm,yes,,weighted -yes,no,no,yes,yes,no,glm,no,,weighted -yes,no,no,yes,yes,no,glm,yes,,weighted +yes,no,no,yes,no,yes,glm,yes,14603.4590747121,weighted +yes,no,no,yes,yes,no,glm,no,2464.57690507132,weighted +yes,no,no,yes,yes,no,glm,yes,3250.81680876078,weighted yes,no,no,yes,yes,yes,glm,no,14755.47,weighted yes,no,no,yes,yes,yes,glm,yes,,weighted yes,no,yes,yes,no,no,glm,no,,weighted -yes,no,yes,yes,no,no,glm,yes,,weighted -yes,no,yes,yes,no,yes,glm,no,,weighted -yes,no,yes,yes,no,yes,glm,yes,,weighted -yes,no,yes,yes,yes,no,glm,no,,weighted -yes,no,yes,yes,yes,no,glm,yes,,weighted +yes,no,yes,yes,no,no,glm,yes,2590.62592762889,weighted +yes,no,yes,yes,no,yes,glm,no,9675.48221267936,weighted +yes,no,yes,yes,no,yes,glm,yes,8293.60265791466,weighted +yes,no,yes,yes,yes,no,glm,no,1391.68730130601,weighted +yes,no,yes,yes,yes,no,glm,yes,1581.56156941945,weighted yes,no,yes,yes,yes,yes,glm,no,14929.74,weighted yes,no,yes,yes,yes,yes,glm,yes,15301.54,weighted yes,yes,no,yes,no,no,glm,no,4232.97,weighted @@ -100,11 +100,11 @@ yes,yes,no,yes,no,no,glm,yes,,weighted yes,yes,no,yes,no,yes,glm,no,14620.97,weighted 
yes,yes,no,yes,no,yes,glm,yes,8271.31,weighted yes,yes,no,yes,yes,no,glm,no,3283.42,weighted -yes,yes,no,yes,yes,no,glm,yes,,weighted +yes,yes,no,yes,yes,no,glm,yes,1391.03697537675,weighted yes,yes,no,yes,yes,yes,glm,no,15299.22,weighted yes,yes,no,yes,yes,yes,glm,yes,15071.18,weighted -yes,yes,yes,yes,no,no,glm,no,,weighted -yes,yes,yes,yes,no,no,glm,yes,,weighted +yes,yes,yes,yes,no,no,glm,no,3345.78672775113,weighted +yes,yes,yes,yes,no,no,glm,yes,3139.79714994284,weighted yes,yes,yes,yes,no,yes,glm,no,14598.53,weighted yes,yes,yes,yes,no,yes,glm,yes,11698.06,weighted yes,yes,yes,yes,yes,no,glm,no,1606.94,weighted @@ -114,8 +114,8 @@ yes,yes,yes,yes,yes,yes,glm,yes,15384.86,weighted yes,,,no,no,no,glm,no,2534.39,weighted yes,,,no,no,no,glm,yes,,weighted yes,,,no,no,yes,glm,no,9539.07,weighted -yes,,,no,no,yes,glm,yes,,weighted -yes,,,no,yes,no,glm,no,,weighted -yes,,,no,yes,no,glm,yes,,weighted +yes,,,no,no,yes,glm,yes,13585.3897712745,weighted +yes,,,no,yes,no,glm,no,2900.40294535745,weighted +yes,,,no,yes,no,glm,yes,1438.26047055841,weighted yes,,,no,yes,yes,glm,no,11466.83,weighted yes,,,no,yes,yes,glm,yes,14109.15,weighted diff --git a/search_algorithms/prob_1.png b/search_algorithms/prob_1.png deleted file mode 100644 index e67ff1ee48ef996a984d4086fbc18428c2f25903..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 13370 zcmeHu2UL`2+V)rzE3slhFbXJLDbm4+3L{NHdQ<5jMS2_17();R8Jbd6Is-G(%fJju z0AZx}G9V&#fRQpF3^V-q>+XJEvb*26|7Mf(pZ(9CGYKb@w>;1NT=#Wf*Oic~TB^IY zAK8vVp?0CwF6y9A8}Fh}8{U7t6}}_YI`9#`$a-8d^w4#&@$kOsW{uLk>EU|Y#pAa9 zt>a$SZtnIj&Jv=MqT<5G?L0hO-Q~o@oc`+47vO&TbypDo0|0bb;LWiSslyXSv+~UJ+;u& zG4A5l2H|UbJvS7YC+i}!4{_{$^>J;GKpfw{o-zfo<^8bMv9I%m6J^~mn-dIrdu_~joJ5hfiB&FQA3l88Mul4R*%Zd3jE{FA$x@wC zw`};NSvpsHl&^A0`APevhPyNh_1^90kvj(t9OxyC@SlJ`<{Vt|luJZ5jsI|@GRb=E z=n(AR)|04>u+KlDzJ^^swDC9!b@%4ici{&=e)}7LU~{GXXjP3cr42WJxgB{@``93} zu>Ab|HfpWO?hR))Rr)z!FOg>Zm-nMk$7hn57p``X?=`@m8{7!Ld}e2a+@?{B*pzY7 zHyYkYQK(ciwx{f*^RP$=>gheWO6dR;B>wQ9fB7F&;^BQL@9CZwy#=~>_v6~ybhh08E>Y&rt#_;U zzv6J1Y)i||&DFE8NWSpm34Fr16&<>7_vT*`5fO24a?<17+uD((dtm2Q&!uT2;uwn_ zL~cq}CT#j@cVg7pXd}4Bf=U#O<3muDcA)`Iw5vs^=MM zXk{!o2ya25N^5;sO?QRvo;-PSYH||2&9t|-7k+hnaJ41(zTLrM+wET8Xu;FzbOXsF znptf>uxQ?cJ8B9AvJoF!zH^Q$p7K{5(!k6Kd~K#1dKB zYpO%waadTxj911ZBGp-OeQg0Q{P=QR4*65p%!9f)2dIO^N;*dS#Z@)wL7$ZTf&wY` zF%3@-54+XHNlPXAr%$hb3=W=29+~@m;N5;GGl^8z_8aAfF=`;Uk?~GOp)mW@a259orMV2eZ7%OzLV~(3uXLeVZ?`{%o?en3%R!hJ^Jd6zb{iHumB3=xhSPKvVOv zprD{+5IexBzgRQm;sqi7oF}{@79`UGn?@aYLtcBvRHt#FT|}`nHM=I16PKH-8O*`c z)s!d?__0DNFPC!~tT3vnshM91aImv`(oNGly&t=1{8NGwrrg27p`)iKb8RgEvqG`Z z4+##wIsXwYb?%(tk;mCNIU2`~9b1Z(*}8SB4!%QAPtSgN zmfBZnA9>DsAoIhATQmJ7+U{fT@*KNzO)@T?co8jR0Jl`z!6|xORW+DvcW`R=px^Ao zM2oF5(}-1{ycPnFnc+=ZNL=kIA*HpowS}0$t?zD&+OxbNA*VDPT2mwkVd(Iu!{_1q z8p6K^r2UTQ_|IIb{wb=|vTk={lA;M&9 zUl`Uq`_-$fD=RB-66L(6pH>2PtdroN)DG-ECg-6!1_1wLoA?RTY;S=ogTa7()Nyek zTv1m?0O-TdA!lhE+zOd#>ap3y#rhZwX8VH+@R9Rd%rq;orW`CPEX)KfiBP+!o1vw5 z=1>QKDjym))Nf@pR<}4DGi?dw0%-K~i7ymnstSOWetG;SCM!|rb} zp8n?R9?M_jlkOCsz@eRe3>J&^@N$ZP6o6YGc^x~xwp2pS9aS7x#P(}f$uC|woW`{wAytH2Hl+2ks{NDELI+V|N6NzX=O}( zNlrl_XL$G~bKFNEu-+J|QzLpw=DmTo5p{6?{?n&UsouYTUtk3OHSeQf$vraXQ=3cv>LOI3sO-$l<^C;yc%6nhILK$8+ zFvw_(m&TQr8Y(4g@7lEsb{XnRpc(VNY6D|=nP$nu?D?VTFyNyZRH?p!0qtXbq{Hpo z*CmhW0V>ML$xVyh*|~G)X?Z3iG#U!8uhPHdoKs)w5gOnZW?`*fTqtqY#c zZnPo#$dMy$(6l>STfYsij+h^?SZOf0uBVrV#bQIA3B&%zS^tbdof4WoxfxaUv6|++ zHDD{?2WnsF??;FJnFIe@Dc7&=ANgMCo4Lc_9)#}Th8DeV!)&P>k0o-4B(bAc45SFU8 z{wGwaZ8LXgcQ-C83td@RDJUX>;6TW}Q?)tDKA)b)h?-^h^%+B{2<+bupUH{V6tBh@ 
z5tEL#wwGO9`j<}Z(M*)ndwt`oZF5qQBZP(^)&JtgYyM!GC~xCrUcKd>Im znUa3y!uyi24N$TpVUuQ>0=F68C7RH;qi=FzT)OZVT)oe$%&5u%~9J_=8 zJ=L-L1oPZWp$F}0UAe{;K6!n8eIB1b-JBUH)5S(bN2mQACR0RT826B4)N2RoXliOA z?I^coFk6)ea~0`Di`1YB+ ze{CfD6+ZmGxzy7~+wRJHO}D_=8+DcWxXkRNQWc$9gO&bzIyxtUt6!A4k2|chSprNE zijJja3Y?zk#; zmAS@6Q*liPpcPT`0fgG9KMc{kD|{CUF-$MCj`*|d*RDnA=NhLmXx$dqEzHbvfw(R$ zdHLUd{rwg}At7ByN1QMP+fD0?lX5KdnlVP}2p^2>p8-NcEVQcMizb+D$id-s&CIAr zr{uZAfdwE#*`x>Q1qecSLu#NAnQq`ZeE6ESb~xaUVoYLUF06#BjKbL2VONR7aVvxM6q6f=wRw41!e$~&ezqkgTqzq)jzxa=FLC49{~evdF2dw%Ob=Wp zdbVpyPriSBzwgz}`{scw)|g3kiz5rZ@$vBpg0dU<5mtEgq%y(?*E4ULVR zrgIaI5o2 zPEAeeSX+-itpoKWy+d8c);7b8zOeFf&PT`L!(8c!h_`&r0%+ARu4@3~FmJBdq=Sz2 zq*eTb$ooGEXugb0eVphM5kCzw2bZT1ZAXTy>pq84W_aLY3iMV#<1RScH%?exMCIV| zUG55dV)0&&4i|5r+mW)8C9r+>?=q&%7c zqsi*$I%PZ}qjWL4E!Nk+3_U&yN%>II8s}gR$+*HUVWEm_-+uC&%j;saD+?G{UcH{R&{R^2vfj4+M@(?7h%LR!vcK20LkvJ1(e?) z%4ota!(rw@N=DA@<4>5Q&tRJ;hJ$%X+3xUsZ+SuAy~~Sm=JfQ-AS=Zw2a$lxq(j3v zAf&H`jS;iJ!#JB&R%V1P^kX=P7#FIUlsI>Os0r=t>PoK;=5QQuO2noDtph${zcAJa zL?fnoTvtWqeuX&!L`B}o7!zmoAE3a9+wcsrmr4$7 ze;^5XNy8Io*JpmkU?Q?Mff!f&n;4956YyYKRz{MJo?a(k^$@7IALe`-fb5;=1!*Kn z&htu2N(#t90wc~PCarcJUu2k41}C@Jcmjvw;*)f(#B3lyb~kRE9|Cz$wAb;!1zod2 zUO*;sH3U^KV-m^NK$Zu&o2>#I#291@Q3n3=F9w}`10pcB#Y^*1ckTb5B^enRnanqD z-bBTLz!92621XO~IXEcOyJF-RGFos1eO7_V!F;j8*Vh*yi1zX_Su~6<2!MG7TNDq?O;TLkO>+#Y<%z)WJ16m> zo0Zoqu=S})J|DKRfCF?XuT?7}(l9X7geS6y_N)HKAP+JH z8-9y$L;fh6)_viFrZm-~o>g64{jhFC04`Q#6}g1+`Lz>uDW&ync*x}xL@%kpy0bxi zX}1hlOxF(hurGGH2%0!ZNrT|f)+Td-WVZE+KkD6PG1BLkvm~Z7)t_F8aHhceau|qC z9}Mm7pBt$=^-3L@DemOVhFAZ~JVcQlR~M^HT?RcVJ3G4#ERebxsb9H2|2%mS7!R1M z;bbH&SE!DTj{1SJ9G#fBGzz?-bSyoUxZz(ilZpxN0S^GywBn^*^kH5JP)Y_h7IZmi z9lVv?Z0AabJ`1GqN+2ip%kEi1;c#P@E&sqkYWhtE68_MaooWtTlAs1GGzze+suJ(O z?eNbYm2$kds=@pKOgI4XG@DU?&;Qn82aVdzRiNYAElu|VRRiO7YIc^pB#=H>DJUQi zQ9ax;Q086~w9bxV#aL$na05|kXl^z*{!L4sWu1-#ci`unu*P3dxX`{GzZ3=QB(8IL za~1H~OcKc)M6;=x89jUZERZ0y5@gQWuwLK)s>owf2OtRSN5qnak7_t~kz?`g-j^{k z2aAe}1^0&7O?6~}YAWZylw&3#F8&fugF*6YM+Sa{3+c<3{s73px{+WgLW)pDMa7`Q zTY)>4GKxqs?UNO6jDV+jyl2R7!lS|-Jwgokd*`dTDNRlK$G^EM!^OSu)Kw+{tk}P` zvSSs-M-Opx>w~uWUwQ<1V)h;$*PTn`tp{ejgh15w)qshog#BhDI$^<0^1rQ~qe`3r z4duXRwD^ynyM0yoU_r@9U3dl&@-}ycVwO;bSFb)#NZ{JFdw2RHj@`Q(M&=&*J_2#% zFW36`D|2Sk4QFROPz#t2ti$;t(V)j><>$M-vqW^LDZZ?Nf_D7!N&xss0#ab)N2qBMbrY6cm`uN|qb?j;$9>AuKQeqvua*Q0Donv@&%`1&V4~n#bMO)-EcMD|m$xp-8 zeP!+$yaGZ(>Tt`x3f}@49YAF0XlXePv+32-R?2jkF|?AEa_`;S0B-&ij7Zqj5;AM1 z9A*`M574nu{8yl`rluSc6?}ETiU2nbd1~-|S$tveSJt-Vz`XqI>?=fXdRMw;eBqsu z%Lh0)9fwGmI0@S@vjS(T90Un+P0Ms(gN;jFUw}yjlm2uZOgk6`&Y~!YG!iDyi`G_V z6%H;WEt`(OuA@>_|Nj$N`7JlT2meOWf9bX?*X|Wopr^NYxCz;tZXO*K1=6$K>_Ax` zIVh+hkdlws_ix@@fP{k3H(yWmFd|DH*~8r3HEA0=A>I^@TFC zK)D?>5=Y<4!S#|~+CaWiF^EWMSrMp?#IxSM{mIYIuMo^ue_`-;ARRR9>hsR>&G4%! 
z8cnwra1g|}Q|PHLc^AFo-@IBTnu@{}JgB~jI8MW^hA>#bRG5tNV=D%`8ReEV{M@sA}W9?1xiOA_c6sDJvNYl=tLAt zH(jF@=!;eN8VOVs3~+U_T}5W*gibqD`{Ai~&m`Q+OedXBihx*hDiTp+O?@@^Bjqrw zg*NY=zR1bR!L6DCI9$4PNiJ}$G%q$ojMLH6+K?br?Pqqli#DI9Z!qS7kF_=eDdwFu zW`LF}_`mZN@RlG;mEHj+Os^Dqoi7rkXxidrdu+g9MTG*;jkHJLn>ej)BcZi2A}w~t zh7^^H=X6Kkti=sUoT)66s(;ea-a|~!&qoje!by3Kh`$0(8F?tE z^9voeZwuUV6$?A3r>mRt@Zt8xWMzx~VkbmH1tXFdO^Frj&M_+J7#I*n8&~=j9TYK6 z1=Y@B?n4a{@)@oUE^;5&LhI=3$AIiy=(}J8kVF|@0FC53$m&7w!!Cw-SEwLUixPV- z90$x8B%Jn59173*)ME}|&=3KtU|9bVqHi_ccPcAmv1${?1We{SAa*OwM}%jlrCERA zR|hF9hWKQhV#aypTP#{(HUj)1ZjIQ}5(-wnBhLOR7?r9CBD z{u*a8h-r^n1_p9K!$cB{9Z+vT@)qCh5KHYyQ1sVd9tt8Ob#~Ob^gEuFx2mG$_N5`* zz9086m()04>1UY-%w58Oh}H=Su_*9-zzLdy{<8G(lUnwCy2sAS$y~Dy)*3o@+(O$%)7y@j%a1?c&oN@!1{X<&s zh54I&Xr8kJvWkl>DoC0J1O!3&OK5(yv{>^ivcYJYXVEt(39z63Mg;!!l5BkIrEQY$ zW#k0|KUec)7Zm6sTb<0FOcsQEgcOI;=jV`AqT2!e@*8_*i^xHf)jVw1BSbASV9{jA zsiYR-AOjzDr)$b{ok?ED$J zapO!TnEdH^3yYs&Aj+)c2{P?3ap~3$oZp9;3$^F~oL%g(4DXfVCm28|5Gb41bkEr_ zHpM93ni4Zvr_9cilb7!t9K_ouWAud#@Clg$8kuulv}gnKYY+_rzK*0CT>?gq`VBwe z^k`OGPVIPBXm0TM!6R{qm(kcvdPxxXv13J`M0O4gB-G-?+aV1gjP87_&&h+mEh?(#LI$YZ46@2YZ=lfKQU=iF|)JI>R6x&tcHByf^J&l z`|UTj9{xM7TS9G+C5MQCNChAbM^iK8!(BN{mKVA(x3C1hs{0)3jyP-bQEPO2)eMnKw> zX97d_Pv-4%RERA&_owcm3&9(+eP3^m0wYk3SA1B}HqKBu9SR&!RC6zlaoPH3i;b@x z|D04}iaj@d??pr!Mn)Fy9TJ*=@C{T}JA20&P-*~Am&6|I7;r9u_`;EL4MK78Iq){7 zT3)Dd&RFK^AoF$0NVqZ{yfLSN(rdxHkF~)Q+K-NnWv8dtnn&ibo#IN(!JsezYadc2 zNIX(bE<2bf&=9IL5+0ok9v$ukNlT4hW$T9h$ja3ch9EP>Xi?#?hh(gqTq4$xApH>a z6M7N~sk=2rm2sbE?Km~pDi!#!BV-hTAZV;C1RpX1#}yWUJxgb6VLj@Oc~3yH1PRc> zFcL?kC|EH2x)(ciwUKDy7sKPEYFH+)ese-B?;n zr|dhWH9@5ZmLN91fNuK4GutG@>w*6~QPSUN%nAc)MGa)m2w2j6>f|NQoH>qOiu>fU zAeFrG&K7cSeh4Q(a>tluYQ|SQ0aqIv8(|^lOq=?^Jqaa!_GYj7CMse2dKkoS*nv=< z2!|)Fp2{KqE6?vhp-U8cbFfH0l_ zlR}|*Qa!uVa)x$@RXhhu4e>$f3~lRc&Fd3D$d~lN3IQ`#SGS^$~f92n1x09F!;+(xfBLhJ=xB28~c z`{gWcOapM1u4VE%hBg{&QCwnR?puRj0Ul>hb)$%M-P`n44wxe*aBF=K8jcW~xX4VA z8fYR4yzEc`ejz;qA%rFN}=D$aOJV2Mxjr1Bb2< zUY0S1BOrL1nU$3X0;+XUNr?#@*)Bwn#65prWekppndQ76B&2JBK@_a6t^&Zsk{0=J z1VS^rYicy7k?K9jYOdldI$YC~w9Bo%yntHobnxjJn1F9(xaEY5X7VlCV#wClL zCkFZW^eFwM)3kt|WP`k-KFoDe2+9U4uiY#+Y4{pjJAkD0h+~}w7Dena7j`Y8%zCmI z-6p^$gL7gn0%b9pwoisQ?&j!a(*wrkyC4pi16?1Nn))+LYqkw+D{*^pi4C1g`e9@& zm@TJ?(!f1X(=C}i5MGyyU~tKRTPOrHa!M*^oVp-$+Z)Ac71>Rj%t+>5@P-5T*xD$z z_eVYCa>2|pv~5a2SKr20&UN9)kf$UwlLJ@nnOyw2-JMu#BeFZq s3)n@)=-iiJ?!TwU{HOo_<=y&*=U8b|L;cuBcq0`0lGa81g&W`hA8bI}-T(jq From 3796258a1687b593d9b8d88effb87dd52a475cd6 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Sat, 5 Feb 2022 17:03:59 +0900 Subject: [PATCH 14/23] remove pyc --- .../__pycache__/elpd_df.cpython-38.pyc | Bin 2476 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 search_algorithms/__pycache__/elpd_df.cpython-38.pyc diff --git a/search_algorithms/__pycache__/elpd_df.cpython-38.pyc b/search_algorithms/__pycache__/elpd_df.cpython-38.pyc deleted file mode 100644 index f2764fb550bb2a4f6df626fbdf8ba5d54852c574..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2476 zcmZ`)OK%)S5bmDWKD>U$BsKzNfGDuRvf;u394k_s@K7*O9+y?nYP{X+9eZBso^`y| z_=0moxp9k;>|_24KVc4BIOWU*<-}J#wuu*IM%^=AU0q%M)mPPzCnuW>&!5ge!LN13 z{-VL#PYr`xX!;Lyf(ah64#(Gutk~|@oMd+7#BRrpy^a^xI<>gosdFY9(nV|ZtwkoB zBTKl#d+v0al5bDQhNwYO7Y#^S;-d(= zy-??+`)HU&vUDFOVVbn9vK6O7MqQn&Fxe_SCG(+5dYmksJ}$lv+HEwwj4oscykN%w z?u4(ipWi>Q3JW9agg>zk>;tE83SQW&?CDfNZrS_aG<`g`(_~kw+z^f^qB(t*aa*t2xyVP@(F$lwhbIX zCu_%6mvb)cL;efUh1>$h_6a+7PS`YC#CzYq;l!i9!4w;V?@TG2u0J``3O-q7ee5L6oR6paZYJm)>oq--W0*whZy~h#`dq`uoSY2Ys;%q+6yAR 
zhM!8a-2}S^K_<(_oxPsSh&E-dmqx=l(WMvaWEe%I9VDYp?KJYGvz>;Cx(&Rn&ONStjN5K$iGJCvXh-G@@zosW zDqc}eo;??G^Z}#=*DU}GCLmXBy#Qu$Vpa}T75vbmEZw%%T494n)++H0yb`u>URVzq zq-@)Th9|7PnMF9<{of2FzVG|_K%SzFH))f zB+Y$eVae|)d8UudRijGxfhkAYm&DbAXTuzGJByPZtUs1;j3#MOL^wxOon2BeV5)1#&BQsF)zA*8ZkJ zu{V&t9mKJZ1m6mG_;ZNu;}^+F!sh>Ed3OdV33Eb@9AA+H+fgq`(Y&M z#u{^{tm|M`lIFKi(`HMqR`CCdE4n%VR}v}Mw@H_3T#gc0FjKkt2KKrESy@M32r{X@ zz;J9s^(ACKpq=eSW4OM7 Date: Sun, 13 Feb 2022 11:49:59 +0900 Subject: [PATCH 15/23] Upload full birthday DF --- .../bayesian_probabilistic_search.py | 10 +- search_algorithms/birthday_df.csv | 104 +++++++++--------- 2 files changed, 59 insertions(+), 55 deletions(-) diff --git a/search_algorithms/bayesian_probabilistic_search.py b/search_algorithms/bayesian_probabilistic_search.py index 3b39dc9..74b5372 100644 --- a/search_algorithms/bayesian_probabilistic_search.py +++ b/search_algorithms/bayesian_probabilistic_search.py @@ -72,6 +72,7 @@ def plot_signatures(df): + def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_iterations=10): # model df must contain all the models model_df = elpd_df.read_csv(model_df_path) @@ -87,7 +88,10 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera for iter in range(1, num_iterations + 1): print("-" * 20) print(f"iteration {iter}") - draw = model_df.sample(weights=model_df.probability) + model_df = model_df.loc[model_df["elpd"] == np.nan] + #draw = model_df.sample(weights=model_df.probability) + draw = model_df.loc[model_df["elpd"] == np.nan] + model_df.loc[draw.index, "selected"] = True @@ -147,7 +151,7 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera previous_iteration_elpd = elpd previons_iteration_model_dict = model_dict - elpd_df.save_csv(model_df.drop(columns=["probability", "selected"]), "birthday_df_prob.csv") + elpd_df.save_csv(model_df.drop(columns=["probability", "selected"]), "birthday_df.csv") elpd_df.save_csv(model_df, "bayesian_update_results.csv") plot_signatures(model_df) @@ -157,5 +161,5 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera example_dir = pathlib.Path(__file__).resolve().parents[1].absolute().joinpath("examples") birthday_model_path = example_dir.joinpath("birthday/birthday.m.stan") birthday_data_path = example_dir.joinpath("birthday/births_usa_1969.json") - birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df_prob.csv") + birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df.csv") bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=20) diff --git a/search_algorithms/birthday_df.csv b/search_algorithms/birthday_df.csv index 659d7cf..b243dda 100644 --- a/search_algorithms/birthday_df.csv +++ b/search_algorithms/birthday_df.csv @@ -1,74 +1,74 @@ DayOfWeekTrend,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfYearTrend,HolidayTrend,LongTermTrend,Regression,SeasonalTrend,elpd,DayOfWeekWeights no,no,no,yes,no,no,glm,no,4578.51, -no,no,no,yes,no,no,glm,yes,, -no,no,no,yes,no,yes,glm,no,, +no,no,no,yes,no,no,glm,yes,2387.24057250632, +no,no,no,yes,no,yes,glm,no,7125.9942028813, no,no,no,yes,no,yes,glm,yes,, no,no,no,yes,yes,no,glm,no,, -no,no,no,yes,yes,no,glm,yes,, +no,no,no,yes,yes,no,glm,yes,2920.39632324059, no,no,no,yes,yes,yes,glm,no,, -no,no,no,yes,yes,yes,glm,yes,, -no,no,yes,yes,no,no,glm,no,, -no,no,yes,yes,no,no,glm,yes,, +no,no,no,yes,yes,yes,glm,yes,5311.27329352511, 
+no,no,yes,yes,no,no,glm,no,5442.89912071062, +no,no,yes,yes,no,no,glm,yes,4399.63544787709, no,no,yes,yes,no,yes,glm,no,, no,no,yes,yes,no,yes,glm,yes,, no,no,yes,yes,yes,no,glm,no,, -no,no,yes,yes,yes,no,glm,yes,, -no,no,yes,yes,yes,yes,glm,no,, -no,no,yes,yes,yes,yes,glm,yes,, +no,no,yes,yes,yes,no,glm,yes,5604.70609557286, +no,no,yes,yes,yes,yes,glm,no,7197.64047500906, +no,no,yes,yes,yes,yes,glm,yes,7162.06930181051, no,yes,no,yes,no,no,glm,no,, no,yes,no,yes,no,no,glm,yes,, no,yes,no,yes,no,yes,glm,no,7179.98, -no,yes,no,yes,no,yes,glm,yes,, -no,yes,no,yes,yes,no,glm,no,, -no,yes,no,yes,yes,no,glm,yes,, +no,yes,no,yes,no,yes,glm,yes,7379.32679859518, +no,yes,no,yes,yes,no,glm,no,5455.05228337242, +no,yes,no,yes,yes,no,glm,yes,5674.10498513126, no,yes,no,yes,yes,yes,glm,no,7228.88, no,yes,no,yes,yes,yes,glm,yes,, -no,yes,yes,yes,no,no,glm,no,, -no,yes,yes,yes,no,no,glm,yes,, -no,yes,yes,yes,no,yes,glm,no,, -no,yes,yes,yes,no,yes,glm,yes,, -no,yes,yes,yes,yes,no,glm,no,, +no,yes,yes,yes,no,no,glm,no,5232.3273276717, +no,yes,yes,yes,no,no,glm,yes,4406.46562859129, +no,yes,yes,yes,no,yes,glm,no,6703.35639574843, +no,yes,yes,yes,no,yes,glm,yes,6859.312500912, +no,yes,yes,yes,yes,no,glm,no,5257.83852126472, no,yes,yes,yes,yes,no,glm,yes,, no,yes,yes,yes,yes,yes,glm,no,5882.44, no,yes,yes,yes,yes,yes,glm,yes,5888.01, no,,,no,no,no,glm,no,5232.03, no,,,no,no,no,glm,yes,5585.64, no,,,no,no,yes,glm,no,6717.52, -no,,,no,no,yes,glm,yes,, +no,,,no,no,yes,glm,yes,7249.78141003667, no,,,no,yes,no,glm,no,5258.05, -no,,,no,yes,no,glm,yes,, -no,,,no,yes,yes,glm,no,, -no,,,no,yes,yes,glm,yes,, +no,,,no,yes,no,glm,yes,5522.14783298078, +no,,,no,yes,yes,glm,no,6760.55418752675, +no,,,no,yes,yes,glm,yes,7283.7324312413, yes,no,no,yes,no,no,glm,no,6540.03,uniform -yes,no,no,yes,no,no,glm,yes,,uniform +yes,no,no,yes,no,no,glm,yes,3344.3654692436,uniform yes,no,no,yes,no,yes,glm,no,13949.47,uniform -yes,no,no,yes,no,yes,glm,yes,,uniform -yes,no,no,yes,yes,no,glm,no,,uniform -yes,no,no,yes,yes,no,glm,yes,,uniform +yes,no,no,yes,no,yes,glm,yes,13950.2470267613,uniform +yes,no,no,yes,yes,no,glm,no,7489.57698442313,uniform +yes,no,no,yes,yes,no,glm,yes,4003.44946683213,uniform yes,no,no,yes,yes,yes,glm,no,,uniform yes,no,no,yes,yes,yes,glm,yes,,uniform yes,no,yes,yes,no,no,glm,no,7457.22,uniform yes,no,yes,yes,no,no,glm,yes,,uniform yes,no,yes,yes,no,yes,glm,no,13956.11,uniform yes,no,yes,yes,no,yes,glm,yes,,uniform -yes,no,yes,yes,yes,no,glm,no,,uniform -yes,no,yes,yes,yes,no,glm,yes,,uniform -yes,no,yes,yes,yes,yes,glm,no,,uniform +yes,no,yes,yes,yes,no,glm,no,7250.42854438999,uniform +yes,no,yes,yes,yes,no,glm,yes,7323.31146498067,uniform +yes,no,yes,yes,yes,yes,glm,no,8838.90163837759,uniform yes,no,yes,yes,yes,yes,glm,yes,,uniform yes,yes,no,yes,no,no,glm,no,7812.98,uniform -yes,yes,no,yes,no,no,glm,yes,,uniform +yes,yes,no,yes,no,no,glm,yes,7588.18829799233,uniform yes,yes,no,yes,no,yes,glm,no,13979.34,uniform yes,yes,no,yes,no,yes,glm,yes,14076.06,uniform yes,yes,no,yes,yes,no,glm,no,,uniform -yes,yes,no,yes,yes,no,glm,yes,,uniform +yes,yes,no,yes,yes,no,glm,yes,7337.73019675242,uniform yes,yes,no,yes,yes,yes,glm,no,14526.21,uniform -yes,yes,no,yes,yes,yes,glm,yes,,uniform +yes,yes,no,yes,yes,yes,glm,yes,9900.42594930185,uniform yes,yes,yes,yes,no,no,glm,no,6935.09,uniform -yes,yes,yes,yes,no,no,glm,yes,,uniform +yes,yes,yes,yes,no,no,glm,yes,7187.80978134713,uniform yes,yes,yes,yes,no,yes,glm,no,13959.7,uniform -yes,yes,yes,yes,no,yes,glm,yes,,uniform -yes,yes,yes,yes,yes,no,glm,no,,uniform 
-yes,yes,yes,yes,yes,no,glm,yes,,uniform +yes,yes,yes,yes,no,yes,glm,yes,14029.7491348079,uniform +yes,yes,yes,yes,yes,no,glm,no,6851.6486517034,uniform +yes,yes,yes,yes,yes,no,glm,yes,7360.56737317448,uniform yes,yes,yes,yes,yes,yes,glm,no,12213.75,uniform yes,yes,yes,yes,yes,yes,glm,yes,14608.01,uniform yes,,,no,no,no,glm,no,6910.55,uniform @@ -76,23 +76,23 @@ yes,,,no,no,no,glm,yes,7522.39,uniform yes,,,no,no,yes,glm,no,10947.29,uniform yes,,,no,no,yes,glm,yes,13127.09,uniform yes,,,no,yes,no,glm,no,6818.07,uniform -yes,,,no,yes,no,glm,yes,,uniform +yes,,,no,yes,no,glm,yes,7332.93411844658,uniform yes,,,no,yes,yes,glm,no,11204.72,uniform -yes,,,no,yes,yes,glm,yes,,uniform -yes,no,no,yes,no,no,glm,no,,weighted +yes,,,no,yes,yes,glm,yes,13576.7239549785,uniform +yes,no,no,yes,no,no,glm,no,2975.31032486582,weighted yes,no,no,yes,no,no,glm,yes,,weighted yes,no,no,yes,no,yes,glm,no,14599.08,weighted -yes,no,no,yes,no,yes,glm,yes,,weighted -yes,no,no,yes,yes,no,glm,no,,weighted -yes,no,no,yes,yes,no,glm,yes,,weighted +yes,no,no,yes,no,yes,glm,yes,14603.4590747121,weighted +yes,no,no,yes,yes,no,glm,no,2464.57690507132,weighted +yes,no,no,yes,yes,no,glm,yes,3250.81680876078,weighted yes,no,no,yes,yes,yes,glm,no,14755.47,weighted yes,no,no,yes,yes,yes,glm,yes,,weighted yes,no,yes,yes,no,no,glm,no,,weighted -yes,no,yes,yes,no,no,glm,yes,,weighted -yes,no,yes,yes,no,yes,glm,no,,weighted -yes,no,yes,yes,no,yes,glm,yes,,weighted -yes,no,yes,yes,yes,no,glm,no,,weighted -yes,no,yes,yes,yes,no,glm,yes,,weighted +yes,no,yes,yes,no,no,glm,yes,2590.62592762889,weighted +yes,no,yes,yes,no,yes,glm,no,9675.48221267936,weighted +yes,no,yes,yes,no,yes,glm,yes,8293.60265791466,weighted +yes,no,yes,yes,yes,no,glm,no,1391.68730130601,weighted +yes,no,yes,yes,yes,no,glm,yes,1581.56156941945,weighted yes,no,yes,yes,yes,yes,glm,no,14929.74,weighted yes,no,yes,yes,yes,yes,glm,yes,15301.54,weighted yes,yes,no,yes,no,no,glm,no,4232.97,weighted @@ -100,11 +100,11 @@ yes,yes,no,yes,no,no,glm,yes,,weighted yes,yes,no,yes,no,yes,glm,no,14620.97,weighted yes,yes,no,yes,no,yes,glm,yes,8271.31,weighted yes,yes,no,yes,yes,no,glm,no,3283.42,weighted -yes,yes,no,yes,yes,no,glm,yes,,weighted +yes,yes,no,yes,yes,no,glm,yes,1391.03697537675,weighted yes,yes,no,yes,yes,yes,glm,no,15299.22,weighted yes,yes,no,yes,yes,yes,glm,yes,15071.18,weighted -yes,yes,yes,yes,no,no,glm,no,,weighted -yes,yes,yes,yes,no,no,glm,yes,,weighted +yes,yes,yes,yes,no,no,glm,no,3345.78672775113,weighted +yes,yes,yes,yes,no,no,glm,yes,3139.79714994284,weighted yes,yes,yes,yes,no,yes,glm,no,14598.53,weighted yes,yes,yes,yes,no,yes,glm,yes,11698.06,weighted yes,yes,yes,yes,yes,no,glm,no,1606.94,weighted @@ -114,8 +114,8 @@ yes,yes,yes,yes,yes,yes,glm,yes,15384.86,weighted yes,,,no,no,no,glm,no,2534.39,weighted yes,,,no,no,no,glm,yes,,weighted yes,,,no,no,yes,glm,no,9539.07,weighted -yes,,,no,no,yes,glm,yes,,weighted -yes,,,no,yes,no,glm,no,,weighted -yes,,,no,yes,no,glm,yes,,weighted +yes,,,no,no,yes,glm,yes,13585.3897712745,weighted +yes,,,no,yes,no,glm,no,2900.40294535745,weighted +yes,,,no,yes,no,glm,yes,1438.26047055841,weighted yes,,,no,yes,yes,glm,no,11466.83,weighted yes,,,no,yes,yes,glm,yes,14109.15,weighted From ef6216165666d6b8ef7def5e8e14939285553096 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Sun, 13 Feb 2022 11:50:38 +0900 Subject: [PATCH 16/23] Remove redundant csv --- search_algorithms/birthday_df_prob.csv | 121 ------------------------- 1 file changed, 121 deletions(-) delete mode 100644 search_algorithms/birthday_df_prob.csv diff --git 
a/search_algorithms/birthday_df_prob.csv b/search_algorithms/birthday_df_prob.csv deleted file mode 100644 index b243dda..0000000 --- a/search_algorithms/birthday_df_prob.csv +++ /dev/null @@ -1,121 +0,0 @@ -DayOfWeekTrend,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfYearTrend,HolidayTrend,LongTermTrend,Regression,SeasonalTrend,elpd,DayOfWeekWeights -no,no,no,yes,no,no,glm,no,4578.51, -no,no,no,yes,no,no,glm,yes,2387.24057250632, -no,no,no,yes,no,yes,glm,no,7125.9942028813, -no,no,no,yes,no,yes,glm,yes,, -no,no,no,yes,yes,no,glm,no,, -no,no,no,yes,yes,no,glm,yes,2920.39632324059, -no,no,no,yes,yes,yes,glm,no,, -no,no,no,yes,yes,yes,glm,yes,5311.27329352511, -no,no,yes,yes,no,no,glm,no,5442.89912071062, -no,no,yes,yes,no,no,glm,yes,4399.63544787709, -no,no,yes,yes,no,yes,glm,no,, -no,no,yes,yes,no,yes,glm,yes,, -no,no,yes,yes,yes,no,glm,no,, -no,no,yes,yes,yes,no,glm,yes,5604.70609557286, -no,no,yes,yes,yes,yes,glm,no,7197.64047500906, -no,no,yes,yes,yes,yes,glm,yes,7162.06930181051, -no,yes,no,yes,no,no,glm,no,, -no,yes,no,yes,no,no,glm,yes,, -no,yes,no,yes,no,yes,glm,no,7179.98, -no,yes,no,yes,no,yes,glm,yes,7379.32679859518, -no,yes,no,yes,yes,no,glm,no,5455.05228337242, -no,yes,no,yes,yes,no,glm,yes,5674.10498513126, -no,yes,no,yes,yes,yes,glm,no,7228.88, -no,yes,no,yes,yes,yes,glm,yes,, -no,yes,yes,yes,no,no,glm,no,5232.3273276717, -no,yes,yes,yes,no,no,glm,yes,4406.46562859129, -no,yes,yes,yes,no,yes,glm,no,6703.35639574843, -no,yes,yes,yes,no,yes,glm,yes,6859.312500912, -no,yes,yes,yes,yes,no,glm,no,5257.83852126472, -no,yes,yes,yes,yes,no,glm,yes,, -no,yes,yes,yes,yes,yes,glm,no,5882.44, -no,yes,yes,yes,yes,yes,glm,yes,5888.01, -no,,,no,no,no,glm,no,5232.03, -no,,,no,no,no,glm,yes,5585.64, -no,,,no,no,yes,glm,no,6717.52, -no,,,no,no,yes,glm,yes,7249.78141003667, -no,,,no,yes,no,glm,no,5258.05, -no,,,no,yes,no,glm,yes,5522.14783298078, -no,,,no,yes,yes,glm,no,6760.55418752675, -no,,,no,yes,yes,glm,yes,7283.7324312413, -yes,no,no,yes,no,no,glm,no,6540.03,uniform -yes,no,no,yes,no,no,glm,yes,3344.3654692436,uniform -yes,no,no,yes,no,yes,glm,no,13949.47,uniform -yes,no,no,yes,no,yes,glm,yes,13950.2470267613,uniform -yes,no,no,yes,yes,no,glm,no,7489.57698442313,uniform -yes,no,no,yes,yes,no,glm,yes,4003.44946683213,uniform -yes,no,no,yes,yes,yes,glm,no,,uniform -yes,no,no,yes,yes,yes,glm,yes,,uniform -yes,no,yes,yes,no,no,glm,no,7457.22,uniform -yes,no,yes,yes,no,no,glm,yes,,uniform -yes,no,yes,yes,no,yes,glm,no,13956.11,uniform -yes,no,yes,yes,no,yes,glm,yes,,uniform -yes,no,yes,yes,yes,no,glm,no,7250.42854438999,uniform -yes,no,yes,yes,yes,no,glm,yes,7323.31146498067,uniform -yes,no,yes,yes,yes,yes,glm,no,8838.90163837759,uniform -yes,no,yes,yes,yes,yes,glm,yes,,uniform -yes,yes,no,yes,no,no,glm,no,7812.98,uniform -yes,yes,no,yes,no,no,glm,yes,7588.18829799233,uniform -yes,yes,no,yes,no,yes,glm,no,13979.34,uniform -yes,yes,no,yes,no,yes,glm,yes,14076.06,uniform -yes,yes,no,yes,yes,no,glm,no,,uniform -yes,yes,no,yes,yes,no,glm,yes,7337.73019675242,uniform -yes,yes,no,yes,yes,yes,glm,no,14526.21,uniform -yes,yes,no,yes,yes,yes,glm,yes,9900.42594930185,uniform -yes,yes,yes,yes,no,no,glm,no,6935.09,uniform -yes,yes,yes,yes,no,no,glm,yes,7187.80978134713,uniform -yes,yes,yes,yes,no,yes,glm,no,13959.7,uniform -yes,yes,yes,yes,no,yes,glm,yes,14029.7491348079,uniform -yes,yes,yes,yes,yes,no,glm,no,6851.6486517034,uniform -yes,yes,yes,yes,yes,no,glm,yes,7360.56737317448,uniform -yes,yes,yes,yes,yes,yes,glm,no,12213.75,uniform -yes,yes,yes,yes,yes,yes,glm,yes,14608.01,uniform 
-yes,,,no,no,no,glm,no,6910.55,uniform -yes,,,no,no,no,glm,yes,7522.39,uniform -yes,,,no,no,yes,glm,no,10947.29,uniform -yes,,,no,no,yes,glm,yes,13127.09,uniform -yes,,,no,yes,no,glm,no,6818.07,uniform -yes,,,no,yes,no,glm,yes,7332.93411844658,uniform -yes,,,no,yes,yes,glm,no,11204.72,uniform -yes,,,no,yes,yes,glm,yes,13576.7239549785,uniform -yes,no,no,yes,no,no,glm,no,2975.31032486582,weighted -yes,no,no,yes,no,no,glm,yes,,weighted -yes,no,no,yes,no,yes,glm,no,14599.08,weighted -yes,no,no,yes,no,yes,glm,yes,14603.4590747121,weighted -yes,no,no,yes,yes,no,glm,no,2464.57690507132,weighted -yes,no,no,yes,yes,no,glm,yes,3250.81680876078,weighted -yes,no,no,yes,yes,yes,glm,no,14755.47,weighted -yes,no,no,yes,yes,yes,glm,yes,,weighted -yes,no,yes,yes,no,no,glm,no,,weighted -yes,no,yes,yes,no,no,glm,yes,2590.62592762889,weighted -yes,no,yes,yes,no,yes,glm,no,9675.48221267936,weighted -yes,no,yes,yes,no,yes,glm,yes,8293.60265791466,weighted -yes,no,yes,yes,yes,no,glm,no,1391.68730130601,weighted -yes,no,yes,yes,yes,no,glm,yes,1581.56156941945,weighted -yes,no,yes,yes,yes,yes,glm,no,14929.74,weighted -yes,no,yes,yes,yes,yes,glm,yes,15301.54,weighted -yes,yes,no,yes,no,no,glm,no,4232.97,weighted -yes,yes,no,yes,no,no,glm,yes,,weighted -yes,yes,no,yes,no,yes,glm,no,14620.97,weighted -yes,yes,no,yes,no,yes,glm,yes,8271.31,weighted -yes,yes,no,yes,yes,no,glm,no,3283.42,weighted -yes,yes,no,yes,yes,no,glm,yes,1391.03697537675,weighted -yes,yes,no,yes,yes,yes,glm,no,15299.22,weighted -yes,yes,no,yes,yes,yes,glm,yes,15071.18,weighted -yes,yes,yes,yes,no,no,glm,no,3345.78672775113,weighted -yes,yes,yes,yes,no,no,glm,yes,3139.79714994284,weighted -yes,yes,yes,yes,no,yes,glm,no,14598.53,weighted -yes,yes,yes,yes,no,yes,glm,yes,11698.06,weighted -yes,yes,yes,yes,yes,no,glm,no,1606.94,weighted -yes,yes,yes,yes,yes,no,glm,yes,2161.1,weighted -yes,yes,yes,yes,yes,yes,glm,no,15313.68,weighted -yes,yes,yes,yes,yes,yes,glm,yes,15384.86,weighted -yes,,,no,no,no,glm,no,2534.39,weighted -yes,,,no,no,no,glm,yes,,weighted -yes,,,no,no,yes,glm,no,9539.07,weighted -yes,,,no,no,yes,glm,yes,13585.3897712745,weighted -yes,,,no,yes,no,glm,no,2900.40294535745,weighted -yes,,,no,yes,no,glm,yes,1438.26047055841,weighted -yes,,,no,yes,yes,glm,no,11466.83,weighted -yes,,,no,yes,yes,glm,yes,14109.15,weighted From 827e5af8875d579ca7328a46fad690e8634155c2 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Sun, 13 Feb 2022 12:38:58 +0900 Subject: [PATCH 17/23] full elpd for birthday --- search_algorithms/birthday_df.csv | 42 +++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/search_algorithms/birthday_df.csv b/search_algorithms/birthday_df.csv index b243dda..424562f 100644 --- a/search_algorithms/birthday_df.csv +++ b/search_algorithms/birthday_df.csv @@ -2,33 +2,33 @@ DayOfWeekTrend,DayOfYearHeirarchicalVariance,DayOfYearNormalVariance,DayOfYearTr no,no,no,yes,no,no,glm,no,4578.51, no,no,no,yes,no,no,glm,yes,2387.24057250632, no,no,no,yes,no,yes,glm,no,7125.9942028813, -no,no,no,yes,no,yes,glm,yes,, -no,no,no,yes,yes,no,glm,no,, +no,no,no,yes,no,yes,glm,yes,7130.04565498837, +no,no,no,yes,yes,no,glm,no,4814.53621096012, no,no,no,yes,yes,no,glm,yes,2920.39632324059, -no,no,no,yes,yes,yes,glm,no,, +no,no,no,yes,yes,yes,glm,no,7160.41136143656, no,no,no,yes,yes,yes,glm,yes,5311.27329352511, no,no,yes,yes,no,no,glm,no,5442.89912071062, no,no,yes,yes,no,no,glm,yes,4399.63544787709, -no,no,yes,yes,no,yes,glm,no,, -no,no,yes,yes,no,yes,glm,yes,, -no,no,yes,yes,yes,no,glm,no,, 
+no,no,yes,yes,no,yes,glm,no,7172.32354113026, +no,no,yes,yes,no,yes,glm,yes,5077.25395483504, +no,no,yes,yes,yes,no,glm,no,5462.1771392874, no,no,yes,yes,yes,no,glm,yes,5604.70609557286, no,no,yes,yes,yes,yes,glm,no,7197.64047500906, no,no,yes,yes,yes,yes,glm,yes,7162.06930181051, -no,yes,no,yes,no,no,glm,no,, -no,yes,no,yes,no,no,glm,yes,, +no,yes,no,yes,no,no,glm,no,5459.36244870143, +no,yes,no,yes,no,no,glm,yes,5655.77885492243, no,yes,no,yes,no,yes,glm,no,7179.98, no,yes,no,yes,no,yes,glm,yes,7379.32679859518, no,yes,no,yes,yes,no,glm,no,5455.05228337242, no,yes,no,yes,yes,no,glm,yes,5674.10498513126, no,yes,no,yes,yes,yes,glm,no,7228.88, -no,yes,no,yes,yes,yes,glm,yes,, +no,yes,no,yes,yes,yes,glm,yes,7415.97216378276, no,yes,yes,yes,no,no,glm,no,5232.3273276717, no,yes,yes,yes,no,no,glm,yes,4406.46562859129, no,yes,yes,yes,no,yes,glm,no,6703.35639574843, no,yes,yes,yes,no,yes,glm,yes,6859.312500912, no,yes,yes,yes,yes,no,glm,no,5257.83852126472, -no,yes,yes,yes,yes,no,glm,yes,, +no,yes,yes,yes,yes,no,glm,yes,5601.3917445384, no,yes,yes,yes,yes,yes,glm,no,5882.44, no,yes,yes,yes,yes,yes,glm,yes,5888.01, no,,,no,no,no,glm,no,5232.03, @@ -45,21 +45,21 @@ yes,no,no,yes,no,yes,glm,no,13949.47,uniform yes,no,no,yes,no,yes,glm,yes,13950.2470267613,uniform yes,no,no,yes,yes,no,glm,no,7489.57698442313,uniform yes,no,no,yes,yes,no,glm,yes,4003.44946683213,uniform -yes,no,no,yes,yes,yes,glm,no,,uniform -yes,no,no,yes,yes,yes,glm,yes,,uniform +yes,no,no,yes,yes,yes,glm,no,13842.3779653104,uniform +yes,no,no,yes,yes,yes,glm,yes,9187.51611159167,uniform yes,no,yes,yes,no,no,glm,no,7457.22,uniform -yes,no,yes,yes,no,no,glm,yes,,uniform +yes,no,yes,yes,no,no,glm,yes,7519.51964700884,uniform yes,no,yes,yes,no,yes,glm,no,13956.11,uniform -yes,no,yes,yes,no,yes,glm,yes,,uniform +yes,no,yes,yes,no,yes,glm,yes,9439.97693562743,uniform yes,no,yes,yes,yes,no,glm,no,7250.42854438999,uniform yes,no,yes,yes,yes,no,glm,yes,7323.31146498067,uniform yes,no,yes,yes,yes,yes,glm,no,8838.90163837759,uniform -yes,no,yes,yes,yes,yes,glm,yes,,uniform +yes,no,yes,yes,yes,yes,glm,yes,14525.9429937942,uniform yes,yes,no,yes,no,no,glm,no,7812.98,uniform yes,yes,no,yes,no,no,glm,yes,7588.18829799233,uniform yes,yes,no,yes,no,yes,glm,no,13979.34,uniform yes,yes,no,yes,no,yes,glm,yes,14076.06,uniform -yes,yes,no,yes,yes,no,glm,no,,uniform +yes,yes,no,yes,yes,no,glm,no,7702.93949007661,uniform yes,yes,no,yes,yes,no,glm,yes,7337.73019675242,uniform yes,yes,no,yes,yes,yes,glm,no,14526.21,uniform yes,yes,no,yes,yes,yes,glm,yes,9900.42594930185,uniform @@ -80,14 +80,14 @@ yes,,,no,yes,no,glm,yes,7332.93411844658,uniform yes,,,no,yes,yes,glm,no,11204.72,uniform yes,,,no,yes,yes,glm,yes,13576.7239549785,uniform yes,no,no,yes,no,no,glm,no,2975.31032486582,weighted -yes,no,no,yes,no,no,glm,yes,,weighted +yes,no,no,yes,no,no,glm,yes,-13704.5269219521,weighted yes,no,no,yes,no,yes,glm,no,14599.08,weighted yes,no,no,yes,no,yes,glm,yes,14603.4590747121,weighted yes,no,no,yes,yes,no,glm,no,2464.57690507132,weighted yes,no,no,yes,yes,no,glm,yes,3250.81680876078,weighted yes,no,no,yes,yes,yes,glm,no,14755.47,weighted -yes,no,no,yes,yes,yes,glm,yes,,weighted -yes,no,yes,yes,no,no,glm,no,,weighted +yes,no,no,yes,yes,yes,glm,yes,15322.3678699749,weighted +yes,no,yes,yes,no,no,glm,no,1563.68297823065,weighted yes,no,yes,yes,no,no,glm,yes,2590.62592762889,weighted yes,no,yes,yes,no,yes,glm,no,9675.48221267936,weighted yes,no,yes,yes,no,yes,glm,yes,8293.60265791466,weighted @@ -96,7 +96,7 @@ yes,no,yes,yes,yes,no,glm,yes,1581.56156941945,weighted 
yes,no,yes,yes,yes,yes,glm,no,14929.74,weighted yes,no,yes,yes,yes,yes,glm,yes,15301.54,weighted yes,yes,no,yes,no,no,glm,no,4232.97,weighted -yes,yes,no,yes,no,no,glm,yes,,weighted +yes,yes,no,yes,no,no,glm,yes,2397.22787212426,weighted yes,yes,no,yes,no,yes,glm,no,14620.97,weighted yes,yes,no,yes,no,yes,glm,yes,8271.31,weighted yes,yes,no,yes,yes,no,glm,no,3283.42,weighted @@ -112,7 +112,7 @@ yes,yes,yes,yes,yes,no,glm,yes,2161.1,weighted yes,yes,yes,yes,yes,yes,glm,no,15313.68,weighted yes,yes,yes,yes,yes,yes,glm,yes,15384.86,weighted yes,,,no,no,no,glm,no,2534.39,weighted -yes,,,no,no,no,glm,yes,,weighted +yes,,,no,no,no,glm,yes,2169.49585534679,weighted yes,,,no,no,yes,glm,no,9539.07,weighted yes,,,no,no,yes,glm,yes,13585.3897712745,weighted yes,,,no,yes,no,glm,no,2900.40294535745,weighted From 9f514a495c8d5c62c725ecd13402fc14d209e909 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Tue, 15 Feb 2022 14:30:46 +0900 Subject: [PATCH 18/23] Add roach model, score search, and notebook template --- examples/roach/roach.csv | 263 ++++++++++++++++++ examples/roach/roach.json | 1 + examples/roach/roach.m.stan | 120 ++++++++ .../bayesian_probabilistic_search.py | 83 +++++- search_algorithms/elpd.R | 2 +- .../model_space_exploration.ipynb | 106 +++++++ 6 files changed, 561 insertions(+), 14 deletions(-) create mode 100644 examples/roach/roach.csv create mode 100644 examples/roach/roach.json create mode 100644 examples/roach/roach.m.stan create mode 100644 search_algorithms/model_space_exploration.ipynb diff --git a/examples/roach/roach.csv b/examples/roach/roach.csv new file mode 100644 index 0000000..db478db --- /dev/null +++ b/examples/roach/roach.csv @@ -0,0 +1,263 @@ +y,roach1,treatment,senior,exposure2 +153,308,1,0,0.8 +127,331.25,1,0,0.6 +7,1.67,1,0,1 +7,3,1,0,1 +0,2,1,0,1.14285714285714 +0,0,1,0,1 +73,70,1,0,0.8 +24,64.56,1,0,1.14285714285714 +2,1,0,0,1 +2,14,0,0,1.14285714285714 +0,138.25,0,0,1 +21,16,0,0,1 +0,97,0,0,1 +179,98,0,0,0.8 +136,44,0,0,1 +104,450,0,0,0.8 +2,36.67,0,0,0.8 +5,75,0,0,1 +1,2,0,0,1 +203,342.5,0,0,1 +32,22.5,0,0,1 +1,4,0,0,1 +135,94,0,0,1 +59,132.13,0,0,0.857142857142857 +29,80,0,0,1 +120,95,0,0,1 +44,34.13,0,0,0.8 +1,19,0,0,1 +2,25,0,0,1 +193,57.5,0,0,1 +13,3.11,0,0,1 +37,10.5,0,0,1 +2,47,0,0,1 +0,2.63,0,0,1 +3,204,0,0,1 +0,0,0,0,1 +0,0,0,0,1 +15,246.25,0,0,0.8 +11,76.22,0,0,1.57142857142857 +19,15,0,0,1.42857142857143 +0,14.88,0,0,1 +19,51,0,0,1 +4,15,0,0,1 +122,52.5,0,0,0.8 +48,13.75,0,0,1 +0,1.17,0,0,1 +0,1,0,0,1 +3,2,0,0,1 +0,0,0,0,1 +9,2,0,0,1 +0,2,0,0,1.14285714285714 +0,30,0,0,1.14285714285714 +0,7,0,0,1 +12,18,0,0,1 +0,2,0,0,0.8 +357,266,0,0,1 +11,174,1,0,0.914285714285714 +60,252,1,0,1 +0,0.88,1,0,1.28571428571429 +159,372.5,1,0,1 +50,42,1,0,1 +48,19,1,0,0.771428571428571 +178,263,1,0,1 +4,34,1,0,1.42857142857143 +6,1.75,1,0,1 +0,213.5,1,0,0.771428571428571 +33,13.13,1,0,1 +127,154,1,0,1 +4,21,1,0,1 +63,220,1,0,1 +88,38.32,1,0,1 +5,352.5,1,0,0.6 +0,7,1,0,1 +0,4,1,0,1 +62,59.5,1,0,1 +4,7.5,1,0,1 +150,112.88,1,0,1.28571428571429 +38,172,1,0,1 +0,13,1,0,1 +3,18,1,0,1 +1,0,1,0,1 +14,27,1,0,1 +77,148,1,0,1 +42,32,1,0,1.14285714285714 +21,28,1,0,1 +1,0,1,0,1 +45,28,1,0,1 +0,0,1,1,0.8 +0,14,1,1,0.8 +0,5,1,1,1 +0,0,1,1,1 +0,104,1,1,0.685714285714286 +183,27,1,1,1 +28,132,1,1,1 +49,258,1,1,1 +1,1,1,1,1 +0,2,1,1,1 +0,6,1,1,1 +3,3,1,1,0.8 +0,3,1,1,0.857142857142857 +0,0,1,1,1 +0,0,1,1,1 +0,0,1,1,1 +18,1.25,1,1,1 +0,0,1,1,1 +0,16,1,1,1 +5,68,1,1,0.4 +0,1,1,1,0.8 +19,18,1,1,1.14285714285714 +5,123.67,1,1,1 +0,2,1,1,0.8 +27,82.5,1,1,1 +0,0,1,1,0.2 +0,1.25,1,1,1 
+77,171,1,1,1 +1,91.88,1,1,1 +3,5,1,1,1 +2,7,1,1,1 +0,0,1,1,1 +0,4,1,1,1 +22,53.75,1,1,1 +102,138.06,1,1,1.14285714285714 +0,1,1,1,1.14285714285714 +0,1.25,1,1,1 +0,28,1,1,0.8 +0,15,1,1,1 +0,0.88,1,1,0.8 +0,0,1,1,1 +0,33,1,1,1 +4,136.25,1,1,4.28571428571429 +12,127.5,1,1,0.8 +2,2,1,1,0.8 +0,2,1,1,0.8 +0,0,1,1,1 +1,3,1,1,1 +0,46,1,1,1.28571428571429 +40,68,1,1,1 +0,0,1,1,1 +1,49,1,1,1 +2,27.13,1,1,1 +27,45.5,1,0,1 +0,0,1,0,0.457142857142857 +2,4,1,0,1 +0,0,1,0,2.42857142857143 +0,1,1,0,1.42857142857143 +0,0,1,0,1 +0,0,1,0,1 +3,10.5,1,0,1.14285714285714 +1,0,1,0,1 +20,3,1,0,1.14285714285714 +0,0,1,0,1 +0,0,1,0,1 +0,0,1,0,1 +0,3.75,1,0,0.8 +0,0,1,0,1.14285714285714 +0,0,1,0,1 +0,0,0,0,0.8 +0,0.88,0,0,1.57142857142857 +0,2,1,0,1 +53,81,1,0,1.14285714285714 +69,31,1,0,2.28571428571429 +15,10,1,0,1 +0,10.23,1,0,0.571428571428571 +2,0,0,0,0.8 +4,13,1,0,0.857142857142857 +6,1,1,0,0.8 +8,33.25,1,0,0.857142857142857 +0,0,0,0,2.28571428571429 +0,53,0,0,0.8 +0,5,1,0,1 +18,157,1,0,0.857142857142857 +38,23.33,1,0,1 +0,6,1,0,1.02857142857143 +2,10,1,0,0.6 +18,100,1,0,0.857142857142857 +34,55,1,0,1 +1,0,1,0,0.8 +109,16.25,1,0,1 +5,7.78,1,0,1.48571428571429 +15,53,1,0,1 +0,2,1,0,1 +64,73,1,0,1 +0,0,1,0,0.8 +1,0,1,0,0.8 +0,0,1,0,1 +1,3.18,1,0,1 +3,5,0,0,0.857142857142857 +5,3,0,0,0.8 +7,0,0,0,1 +18,10,0,0,0.8 +1,1,0,0,1 +0,0,1,0,1 +0,1,1,0,0.8 +3,12,1,0,1 +3,17,1,0,1 +0,16,1,0,1 +19,2.5,1,0,1.42857142857143 +0,0,1,0,1.71428571428571 +8,21.88,1,0,0.771428571428571 +26,173,1,0,0.8 +50,111,1,0,1 +15,35,1,0,1.85714285714286 +0,0,1,0,1 +19,3,1,0,1 +5,2.1,1,0,1 +17,0,1,0,1 +121,49,1,0,1 +1,1.25,1,0,1 +0,1,1,0,0.8 +0,1,1,0,1 +0,0,1,0,1 +0,0,1,0,2.28571428571429 +4,54,1,0,0.8 +1,4.2,1,0,1.42857142857143 +14,51.25,1,0,1.02857142857143 +1,30,1,0,1.71428571428571 +25,196,0,0,1.14285714285714 +0,0,0,0,1.14285714285714 +14,2,0,0,1 +0,1.5,0,0,1 +59,96.25,0,0,1.14285714285714 +243,241.5,0,0,0.8 +80,140,0,0,1.42857142857143 +69,18,0,0,0.914285714285714 +14,3,0,0,1.14285714285714 +9,0.54,0,0,1.28571428571429 +38,82,0,0,0.6 +37,19,0,0,1 +48,18.75,0,0,0.8 +293,51,0,0,1.14285714285714 +7,0,0,0,1 +10,1,0,0,1.42857142857143 +19,0,0,0,1 +24,5.44,0,0,1 +91,0,0,0,1 +1,3,0,1,1 +0,0,0,1,1 +0,0,0,1,1 +0,0,0,1,1 +0,0,0,1,1 +148,28.75,0,1,1.14285714285714 +3,0,0,1,0.857142857142857 +26,3,0,1,1 +12,2,0,1,1 +77,135,0,1,1 +0,0,0,1,1 +7,68.25,0,1,1 +0,1,0,1,1 +1,0,0,1,1 +0,0,0,1,0.6 +17,0,0,1,1.02857142857143 +0,1,0,1,1 +7,1,0,1,1 +11,2.5,0,1,0.685714285714286 +6,51.25,0,1,0.8 +50,13.13,0,1,1 +1,0,0,1,0.8 +0,0,0,1,1.48571428571429 +0,0,0,1,1 +0,0,0,1,1 +171,0,0,1,1 +8,0,0,1,1 diff --git a/examples/roach/roach.json b/examples/roach/roach.json new file mode 100644 index 0000000..d69a631 --- /dev/null +++ b/examples/roach/roach.json @@ -0,0 +1 @@ 
+{"y":[153,127,7,7,0,0,73,24,2,2,0,21,0,179,136,104,2,5,1,203,32,1,135,59,29,120,44,1,2,193,13,37,2,0,3,0,0,15,11,19,0,19,4,122,48,0,0,3,0,9,0,0,0,12,0,357,11,60,0,159,50,48,178,4,6,0,33,127,4,63,88,5,0,0,62,4,150,38,0,3,1,14,77,42,21,1,45,0,0,0,0,0,183,28,49,1,0,0,3,0,0,0,0,18,0,0,5,0,19,5,0,27,0,0,77,1,3,2,0,0,22,102,0,0,0,0,0,0,0,4,12,2,0,0,1,0,40,0,1,2,27,0,2,0,0,0,0,3,1,20,0,0,0,0,0,0,0,0,0,53,69,15,0,2,4,6,8,0,0,0,18,38,0,2,18,34,1,109,5,15,0,64,0,1,0,1,3,5,7,18,1,0,0,3,3,0,19,0,8,26,50,15,0,19,5,17,121,1,0,0,0,0,4,1,14,1,25,0,14,0,59,243,80,69,14,9,38,37,48,293,7,10,19,24,91,1,0,0,0,0,148,3,26,12,77,0,7,0,1,0,17,0,7,11,6,50,1,0,0,0,171,8],"roach1":[308,331.25,1.67,3,2,0,70,64.56,1,14,138.25,16,97,98,44,450,36.67,75,2,342.5,22.5,4,94,132.13,80,95,34.13,19,25,57.5,3.11,10.5,47,2.63,204,0,0,246.25,76.22,15,14.88,51,15,52.5,13.75,1.17,1,2,0,2,2,30,7,18,2,266,174,252,0.88,372.5,42,19,263,34,1.75,213.5,13.13,154,21,220,38.32,352.5,7,4,59.5,7.5,112.88,172,13,18,0,27,148,32,28,0,28,0,14,5,0,104,27,132,258,1,2,6,3,3,0,0,0,1.25,0,16,68,1,18,123.67,2,82.5,0,1.25,171,91.88,5,7,0,4,53.75,138.06,1,1.25,28,15,0.88,0,33,136.25,127.5,2,2,0,3,46,68,0,49,27.13,45.5,0,4,0,1,0,0,10.5,0,3,0,0,0,3.75,0,0,0,0.88,2,81,31,10,10.23,0,13,1,33.25,0,53,5,157,23.33,6,10,100,55,0,16.25,7.78,53,2,73,0,0,0,3.18,5,3,0,10,1,0,1,12,17,16,2.5,0,21.88,173,111,35,0,3,2.1,0,49,1.25,1,1,0,0,54,4.2,51.25,30,196,0,2,1.5,96.25,241.5,140,18,3,0.54,82,19,18.75,51,0,1,0,5.44,0,3,0,0,0,0,28.75,0,3,2,135,0,68.25,1,0,0,0,1,1,2.5,51.25,13.13,0,0,0,0,0,0],"treatment":[1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,1,1,1,1,1,0,1,1,1,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"senior":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"exposure2":[0.8,0.6,1,1,1.1429,1,0.8,1.1429,1,1.1429,1,1,1,0.8,1,0.8,0.8,1,1,1,1,1,1,0.8571,1,1,0.8,1,1,1,1,1,1,1,1,1,1,0.8,1.5714,1.4286,1,1,1,0.8,1,1,1,1,1,1,1.1429,1.1429,1,1,0.8,1,0.9143,1,1.2857,1,1,0.7714,1,1.4286,1,0.7714,1,1,1,1,1,0.6,1,1,1,1,1.2857,1,1,1,1,1,1,1.1429,1,1,1,0.8,0.8,1,1,0.6857,1,1,1,1,1,1,0.8,0.8571,1,1,1,1,1,1,0.4,0.8,1.1429,1,0.8,1,0.2,1,1,1,1,1,1,1,1,1.1429,1.1429,1,0.8,1,0.8,1,1,4.2857,0.8,0.8,0.8,1,1,1.2857,1,1,1,1,1,0.4571,1,2.4286,1.4286,1,1,1.1429,1,1.1429,1,1,1,0.8,1.1429,1,0.8,1.5714,1,1.1429,2.2857,1,0.5714,0.8,0.8571,0.8,0.8571,2.2857,0.8,1,0.8571,1,1.0286,0.6,0.8571,1,0.8,1,1.4857,1,1,1,0.8,0.8,1,1,0.8571,0.8,1,0.8,1,1,0.8,1,1,1,1.4286,1.7143,0.7714,0.8,1,1.8571,1,1,1,1,1,1,0.8,1,1,2.2857,0.8,1.4286,1.0286,1.7143,1.1429,1.1429,1,1,1.1429,0.8,1.4286,0.9143,1.1429,1.2857,0.6,1,0.8,1.1429,1,1.4286,1,1,1,1,1,1,1,1,1.1429,0.8571,1,1,1,1,1,1,1,0.6,1.0286,1,1,0.6857,0.8,1,0.8,1.4857,1,1,1,1],"N":[262]}
 diff --git a/examples/roach/roach.m.stan b/examples/roach/roach.m.stan new file mode 100644 index 0000000..3b01584 --- /dev/null +++ b/examples/roach/roach.m.stan @@ -0,0 +1,120 @@ +data { + # y roach1 treatment senior exposure2 + int N; + int y[N]; + real roach1[N]; + int treatment[N]; + int senior[N]; + real exposure2[N]; +} + +transformed data { + vector[N] roach1_vec = to_vector(roach1); + #roach1_vec = roach1_vec ./ 100 + vector[N] treatment_vec = to_vector(treatment); + vector[N] senior_vec = to_vector(senior); + vector[N] exposure2_vec = to_vector(exposure2); +} + +model { + y ~ Regression(); +} + +module "glm" Regression() { + transformed parameters { + vector[N] offset = OffsetType(); + vector[N] mu = Roach() + Treatment() + Senior(); + } + Likelihood(); +} + +module "Poisson" Likelihood() { + generated quantities { + vector[N] log_lik; + for (n in 1:N) log_lik[n] = poisson_lpmf(y[n] | exp(mu + RandomEffects() + offset)); + } + y ~ poisson(exp(mu + RandomEffects() + offset)); +} + +module "NegBinomial" Likelihood() { + parameters { + real phi; + } + generated quantities { + vector[N] log_lik; + for (n in 1:N) log_lik[n] = neg_binomial_2_log_lpmf(y[n] | mu + offset, phi); + + } + PhiPrior(); + y ~ neg_binomial_2_log(exp(mu + offset), phi); +} + +module "normal" PhiPrior() { + phi ~ normal(0, 3); +} + +module "cauchy" PhiPrior() { + phi ~ cauchy(0, 3); +} + +module "log" OffsetType() { + return log(exposure2_vec); +} + +module "identity" OffsetType() { + return exposure2_vec; +} + +module "sqrt" Roach() { + parameters { + real roach_coeff; + } + return roach_coeff * sqrt(roach1_vec); +} + +module "identity" Roach() { + parameters { + real roach_coeff; + } + return roach_coeff * roach1_vec; +} + +module "no" Roach() { + return rep_vector(0, N); +} + +module "yes" Treatment() { + parameters { + real treatment_coeff; + } + return treatment_coeff * treatment_vec; +} + +module "no" Treatment() { + return rep_vector(0, N); +} + +module "yes" Senior() { + parameters { + real senior_coeff; + } + return senior_coeff * senior_vec; +} + +module "no" Senior() { + return rep_vector(0, N); +} + +module "yes" RandomEffects() { + parameters { + vector[N] random_effect; + } + model { + random_effect ~ normal(0, 5); + } + return random_effect; +} + +module "no" RandomEffects() { + return rep_vector(0, N) ; +} \ No newline at end of file diff --git a/search_algorithms/bayesian_probabilistic_search.py b/search_algorithms/bayesian_probabilistic_search.py index 74b5372..45be2c0 100644 --- a/search_algorithms/bayesian_probabilistic_search.py +++ b/search_algorithms/bayesian_probabilistic_search.py @@ -70,8 +70,65 @@ def plot_signatures(df): fig.tight_layout() fig.savefig(f"sigplot_{signature}.png") +################ +def bayesian_probabilistic_score_based_search(model_path, data_path, model_df, num_iterations=10): + model_count = model_df.shape[0] + model_df["probability"] = 1.0 / model_count + model_df["selected"] = False + # start bayesian probabilistic search + Implementation_score_dict = {} + # 이중 dictionary: signature가 1st key -> value: dictionary => 해당 dictionary는 {implementation: [ELPD sum, number of occurences]} + # 1st iteration에서만 dictionary의 key들을 채워넣을 것. 
그 후 2번째 iteration부터는 value(dictionary)의 value들만 update + # Implementation_score_dict의 각 signature,implementation [0,0] 초기화 + for i in range(model_count): + for signature in list(model_df.drop(columns=["probability", "selected"]).columns): + implementation = model_df.loc[i, signature] + Implementation_score_dict[signature][implementation] = [0,0] + # iteration 시작 + for iteration in range(num_iterations): + # draw a model based on a probability distribution + draw = model_df.sample(weights=model_df.probability) + model_df.loc[draw.index, "selected"] = True + # compute the elpd value of the randomly drawn model + draw_string = elpd_df.row_to_string(draw.drop(columns=["probability", "selected"])) + model_dict = elpd_df.model_string_to_dict(draw_string) + draw_string = ",".join([f"{key}:{val}" for key, val in model_dict.items()]) + if not np.isnan(elpd_df.search_df(model_df, model_dict).elpd.values[0]): + elpd = elpd_df.search_df(model_df, model_dict).elpd.values[0] + print(f"using saved ELPD value {elpd}") + else: + print("calculating elpd value...") + elpd = calculate_elpd(model_path, draw_string, data_path) + # elpd = random.randint(500, 12000) + print(f"calculated ELPD value {elpd}, saving to df") + model_df = elpd_df.upsert_model(model_df, model_dict, elpd=elpd) + # Update the probability distribution (in the model space) + # (2) for each signature, compute its score (store it in a dictionary. ex) implementation: [ELPD score sum, number of occurences]) + for signature in list(model_df.drop(columns=["probability", "selected"]).columns): + implementation = model_df.loc[draw.index, signature] + Implementation_score_dict[signature][implementation] = [Implementation_score_dict[signature][implementation][0]+elpd,Implementation_score_dict[signature][implementation][1]+1] + # (3) calculate the score of each model + Score = [] + for i in range(model_count): + model_score = 0 + for signature in list(model_df.drop(columns=["probability", "selected"]).columns): + if Implementation_score_dict[signature][model_df.loc[i,signature]][1] > 0: + model_score +=Implementation_score_dict[signature][model_df.loc[i,signature]][0]/Implementation_score_dict[signature][model_df.loc[i,signature]][1] + else: + sum_scores_already_computed_implementations = sum([val[0]/val[1] if val[1] > 0 else 0 for val in Implementation_score_dict[signature].values]) + num_zero_occurence_implementations = sum([1 if val[1] == 0 else 0 for val in Implementation_score_dict[signature].values]) + model_score += sum_scores_already_computed_implementations/num_zero_occurence_implementations + Score.append(model_score) + # (4) calculate the sum of the scores of models + Score_sum = sum(Score) + # (5) update the probability of each model + for i in range(model_count): + model_df.loc[i, "probability"] = Score[i] / Score_sum + final_model = draw.drop(columns=["probability", "selected"]) + return final_model + def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_iterations=10): # model df must contain all the models @@ -88,10 +145,9 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera for iter in range(1, num_iterations + 1): print("-" * 20) print(f"iteration {iter}") - model_df = model_df.loc[model_df["elpd"] == np.nan] - #draw = model_df.sample(weights=model_df.probability) - draw = model_df.loc[model_df["elpd"] == np.nan] - + draw = model_df.sample(weights=model_df.probability) + #draw = model_df.loc[model_df["elpd"].isna()].sample() + print(draw) model_df.loc[draw.index, "selected"] = True 
@@ -141,17 +197,13 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera
 
         model_df["probability"] += update_arr
 
-
-
-
-
         print(model_df)
         plot_probabilities(model_df, iter)
 
         previous_iteration_elpd = elpd
         previons_iteration_model_dict = model_dict
 
-    elpd_df.save_csv(model_df.drop(columns=["probability", "selected"]), "birthday_df.csv")
+    elpd_df.save_csv(model_df.drop(columns=["probability", "selected"]), model_df_path)
     elpd_df.save_csv(model_df, "bayesian_update_results.csv")
 
     plot_signatures(model_df)
@@ -159,7 +211,12 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera
 
 if __name__ == "__main__":
     example_dir = pathlib.Path(__file__).resolve().parents[1].absolute().joinpath("examples")
-    birthday_model_path = example_dir.joinpath("birthday/birthday.m.stan")
-    birthday_data_path = example_dir.joinpath("birthday/births_usa_1969.json")
-    birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df.csv")
-    bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=20)
+    # birthday_model_path = example_dir.joinpath("birthday/birthday.m.stan")
+    # birthday_data_path = example_dir.joinpath("birthday/births_usa_1969.json")
+    # birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df.csv")
+    # bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=20)
+
+    roach_model_path = example_dir.joinpath("roach/roach.m.stan")
+    roach_data_path = example_dir.joinpath("roach/roach.json")
+    roach_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("roach_df.csv")
+    bayesian_probabilstic_search(roach_model_path, roach_data_path, roach_df_path, num_iterations=90)
diff --git a/search_algorithms/elpd.R b/search_algorithms/elpd.R
index 763644c..8b21e3b 100755
--- a/search_algorithms/elpd.R
+++ b/search_algorithms/elpd.R
@@ -27,7 +27,7 @@ print(cmdstan_version())
 #' Model and fit
 model <- cmdstan_model(stan_file = stanfile, quiet=TRUE)
 
-fit <- model$sample(data = standata, iter_warmup=100, iter_sampling=100,
+fit <- model$sample(data = standata, iter_warmup=1000, iter_sampling=1000,
                     chains=4, parallel_chains=4, seed=1)
 
 loo <- fit$loo()
diff --git a/search_algorithms/model_space_exploration.ipynb b/search_algorithms/model_space_exploration.ipynb
new file mode 100644
index 0000000..8626cb7
--- /dev/null
+++ b/search_algorithms/model_space_exploration.ipynb
@@ -0,0 +1,106 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Model Space Exploration"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Introduction\n",
+    "- Introduce the problem setting (model space exploration)\n",
+    "- Develop heuristics for efficient exploration of the model space\n",
+    "- Example Model Introduction: Birthday, Roach"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Greedy\n",
+    "benchmark reference; include only a brief description"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Combinatorial Structure-based search\n",
+    "### **Assumption verification**\n",
+    "#### Chain structure\n",
+    "\n",
+    "- Generate multiple chains and examine the elpd trend\n",
+    "\n",
+    "#### Dynamic Stepsize\n",
+    "\n",
+    "- Check whether a linear approximation is feasible in a local neighborhood\n",
+    "- Use a piecewise linear approximation, trying several interval values (2, 3, ...)\n",
+    "- Examine the approximation error to check whether the change is consistent\n",
+    "\n",
+    "### Chain - Apex Predator Search\n",
+    "### Chain - search-dynamic step size\n",
+    "- Examine the elpd trend produced by the stepsize algorithm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Bayesian Approach\n",
+    "### Bayesian - Probability based search\n",
+    "### Bayesian - Score based search"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Comparing results\n",
+    "- Graphical comparison: bar graph \n",
+    "- Value-based comparison: e.g. objective function (ELPD value of the best model)\n",
+    "- Sensitivity Analysis (tuning hyperparameters of algorithms): algorithm-wise vs value-wise\n",
+    "\n",
+    "Run each algorithm n times and report the mean or the best out of n\n",
+    "\n",
+    "|algorithm | Best Model ELPD| Iteration to reach best model|\n",
+    "|---|---|---|\n",
+    "|Greedy |.|.|\n",
+    "|Apex |.|.|\n",
+    "|Stepsize(p=0.7) |.|.|\n",
+    "|Stepsize(p=0.5) |.|.|\n",
+    "|... |.|.|\n",
+    "|Probability(K=20)|.|.|\n",
+    "|Probability(K=10)|.|.|\n",
+    "|Score(K=20) |.|.|\n",
+    "|Score(K=10) |.|.|\n",
+    "\n",
+    "- bayesian algorithms: compare probability distributions (scatter plots), expectation, trend lines\n",
+    "- iteration-ELPD plot, iteration-expected_ELPD plot\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Limitations\n",
+    "\n",
+    "For experimental convenience, ELPD was computed only once per model, so noise in the ELPD estimates is not reflected\n",
+    "\n",
+    "## Further Research Topics\n",
+    "## Conclusion"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
From 64587d8c693043b18973c0af800b01b0381b6b0f Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Wed, 16 Feb 2022 16:46:31 +0900
Subject: [PATCH 19/23] Upload roach csv

---
 search_algorithms/elpd.R       |  2 +-
 search_algorithms/roach_df.csv | 95 ++++++++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 search_algorithms/roach_df.csv

diff --git a/search_algorithms/elpd.R b/search_algorithms/elpd.R
index 8b21e3b..0cd88b2 100755
--- a/search_algorithms/elpd.R
+++ b/search_algorithms/elpd.R
@@ -27,7 +27,7 @@ print(cmdstan_version())
 #' Model and fit
 model <- cmdstan_model(stan_file = stanfile, quiet=TRUE)
 
-fit <- model$sample(data = standata, iter_warmup=1000, iter_sampling=1000,
+fit <- model$sample(data = standata, iter_warmup=2000, iter_sampling=2000,
                     chains=4, parallel_chains=4, seed=1)
 
 loo <- fit$loo()
diff --git a/search_algorithms/roach_df.csv b/search_algorithms/roach_df.csv
new file mode 100644
index 0000000..ba02a93
--- /dev/null
+++ b/search_algorithms/roach_df.csv
@@ -0,0 +1,95 @@
+Likelihood,OffsetType,PhiPrior,Regression,Roach,Senior,Treatment,elpd,RandomEffects
+NegBinomial,identity,cauchy,glm,identity,no,no,-3.31010792100001e+31,
+NegBinomial,identity,cauchy,glm,identity,no,yes,-7.16927294e+30,
+NegBinomial,identity,cauchy,glm,identity,yes,no,-2.43405178e+30,
+NegBinomial,identity,cauchy,glm,identity,yes,yes,-1.68181454e+30,
+NegBinomial,identity,cauchy,glm,no,no,no,-340655.530325107,
+NegBinomial,identity,cauchy,glm,no,no,yes,-341838.654549512,
+NegBinomial,identity,cauchy,glm,no,yes,no,-350645.829702514,
+NegBinomial,identity,cauchy,glm,no,yes,yes,-349907.987100588,
+NegBinomial,identity,cauchy,glm,sqrt,no,no,-9.13841323e+29,
+NegBinomial,identity,cauchy,glm,sqrt,no,yes,-388842.073623669,
+NegBinomial,identity,cauchy,glm,sqrt,yes,no,-367462.837335601,
+NegBinomial,identity,cauchy,glm,sqrt,yes,yes,-401002.792038126,
+NegBinomial,identity,normal,glm,identity,no,no,-1.3719282699999999e+31, +NegBinomial,identity,normal,glm,identity,no,yes,-3.113420653e+30, +NegBinomial,identity,normal,glm,identity,yes,no,-1.030154914e+29, +NegBinomial,identity,normal,glm,identity,yes,yes,-9.79779348e+29, +NegBinomial,identity,normal,glm,no,no,no,-340486.548462806, +NegBinomial,identity,normal,glm,no,no,yes,-343714.050764389, +NegBinomial,identity,normal,glm,no,yes,no,-351177.220844878, +NegBinomial,identity,normal,glm,no,yes,yes,-349933.937096783, +NegBinomial,identity,normal,glm,sqrt,no,no,-2.12691371e+30, +NegBinomial,identity,normal,glm,sqrt,no,yes,-391174.375445099, +NegBinomial,identity,normal,glm,sqrt,yes,no,-368469.35067934, +NegBinomial,identity,normal,glm,sqrt,yes,yes,-397507.916638694, +NegBinomial,log,cauchy,glm,identity,no,no,-1.989662341e+30, +NegBinomial,log,cauchy,glm,identity,no,yes,-1.036372456e+31, +NegBinomial,log,cauchy,glm,identity,yes,no,-2.010450599e+30, +NegBinomial,log,cauchy,glm,identity,yes,yes,-1.316857915e+30, +NegBinomial,log,cauchy,glm,no,no,no,-380448.39559676, +NegBinomial,log,cauchy,glm,no,no,yes,-349323.948504472, +NegBinomial,log,cauchy,glm,no,yes,no,-362629.516623754, +NegBinomial,log,cauchy,glm,no,yes,yes,-346290.65076816, +NegBinomial,log,cauchy,glm,sqrt,no,no,-2.011964369e+30, +NegBinomial,log,cauchy,glm,sqrt,no,yes,-4.24865429e+27, +NegBinomial,log,cauchy,glm,sqrt,yes,no,-6.8444834e+30, +NegBinomial,log,cauchy,glm,sqrt,yes,yes,-8.00599797e+28, +NegBinomial,log,normal,glm,identity,no,no,-1.388299267e+31, +NegBinomial,log,normal,glm,identity,no,yes,-3.142916662e+30, +NegBinomial,log,normal,glm,identity,yes,no,-4.26636544e+30, +NegBinomial,log,normal,glm,identity,yes,yes,-7.39602322e+29, +NegBinomial,log,normal,glm,no,no,no,-381118.002197806, +NegBinomial,log,normal,glm,no,no,yes,-347746.221334564, +NegBinomial,log,normal,glm,no,yes,no,-365104.637394631, +NegBinomial,log,normal,glm,no,yes,yes,-348172.505046149, +NegBinomial,log,normal,glm,sqrt,no,no,-2.873839709e+30, +NegBinomial,log,normal,glm,sqrt,no,yes,-2.664405278e+29, +NegBinomial,log,normal,glm,sqrt,yes,no,-1.105371683e+30, +NegBinomial,log,normal,glm,sqrt,yes,yes,-3.83683482e+29, +Poisson,identity,,glm,identity,no,no,-4449717.02186598,no +Poisson,identity,,glm,identity,no,yes,-4108411.88453109,no +Poisson,identity,,glm,identity,yes,no,-4353577.95890777,no +Poisson,identity,,glm,identity,yes,yes,-4109790.87588076,no +Poisson,identity,,glm,no,no,yes,-3488538.13162666,no +Poisson,identity,,glm,no,yes,no,-4155804.86816874,no +Poisson,identity,,glm,no,yes,yes,-3518018.38972167,no +Poisson,identity,,glm,sqrt,no,no,-4104049.0911004,no +Poisson,identity,,glm,sqrt,no,yes,-4323732.54688323,no +Poisson,identity,,glm,sqrt,yes,no,-4294809.62504604,no +Poisson,identity,,glm,sqrt,yes,yes,-4358462.35023429,no +Poisson,identity,,glm,identity,no,no,-8311816.51254774,yes +Poisson,identity,,glm,identity,no,yes,-8834548.71091253,yes +Poisson,identity,,glm,identity,yes,no,-9006201.83601147,yes +Poisson,identity,,glm,identity,yes,yes,-9023548.67094582,yes +Poisson,identity,,glm,no,no,no,-8353275.77795275,yes +Poisson,identity,,glm,no,no,yes,-8668584.79171507,yes +Poisson,identity,,glm,no,yes,no,-8831608.42588385,yes +Poisson,identity,,glm,no,yes,yes,-8949564.95534642,yes +Poisson,identity,,glm,sqrt,no,no,-8299482.55390966,yes +Poisson,identity,,glm,sqrt,no,yes,-8861837.42205457,yes +Poisson,identity,,glm,sqrt,yes,no,-8895855.40238863,yes +Poisson,identity,,glm,sqrt,yes,yes,-9133041.26893658,yes +Poisson,log,,glm,identity,no,no,-5806768.42789003,no 
+Poisson,log,,glm,identity,no,yes,-4677882.35234932,no +Poisson,log,,glm,identity,yes,no,-5159621.73009443,no +Poisson,log,,glm,identity,yes,yes,-4591177.8475937,no +Poisson,log,,glm,no,no,yes,-4023304.22398634,no +Poisson,log,,glm,no,yes,no,-5209924.47530365,no +Poisson,log,,glm,no,yes,yes,-4009839.24889643,no +Poisson,log,,glm,sqrt,no,no,-5075745.88721257,no +Poisson,log,,glm,sqrt,no,yes,-4895470.90332615,no +Poisson,log,,glm,sqrt,yes,no,-4919220.00747565,no +Poisson,log,,glm,sqrt,yes,yes,-4834552.27589766,no +Poisson,log,,glm,identity,no,no,-8709396.93540791,yes +Poisson,log,,glm,identity,no,yes,-8894288.2431318,yes +Poisson,log,,glm,identity,yes,no,-8906591.44123111,yes +Poisson,log,,glm,identity,yes,yes,-9150180.41319338,yes +Poisson,log,,glm,no,no,no,-8769770.22548375,yes +Poisson,log,,glm,no,no,yes,-8745912.44654967,yes +Poisson,log,,glm,no,yes,no,-8825991.88615349,yes +Poisson,log,,glm,no,yes,yes,-8901520.76074608,yes +Poisson,log,,glm,sqrt,no,no,-8510746.93459638,yes +Poisson,log,,glm,sqrt,no,yes,-9112105.26498401,yes +Poisson,log,,glm,sqrt,yes,no,-8995079.90780736,yes +Poisson,log,,glm,sqrt,yes,yes,-9193132.58758593,yes From 3eaccedbf2afe65b162b2c2de722d3b22cc17c09 Mon Sep 17 00:00:00 2001 From: Dashadower Date: Fri, 18 Feb 2022 11:19:48 +0900 Subject: [PATCH 20/23] Update prob search --- .../bayesian_probabilistic_search.py | 52 +++++++++++++++++-- search_algorithms/roach_df.csv | 2 +- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/search_algorithms/bayesian_probabilistic_search.py b/search_algorithms/bayesian_probabilistic_search.py index 45be2c0..d6b5220 100644 --- a/search_algorithms/bayesian_probabilistic_search.py +++ b/search_algorithms/bayesian_probabilistic_search.py @@ -10,6 +10,7 @@ import matplotlib.pyplot as plt from sklearn.linear_model import LinearRegression +expectation_values = [] def plot_probabilities(df, iteration): plt.figure() @@ -27,6 +28,12 @@ def plot_probabilities(df, iteration): ax1.set_xlabel("probability") ax1.set_ylabel("ELPD") filtered = df[~df.elpd.isna() & df.selected] + #filtered = filtered[filtered.elpd > -1e+5] + if filtered.shape[0] == 0: + return + + #filtered.loc[:, "elpd"] = filtered.elpd / 1e+30 + #filtered = filtered[~filtered.elpd.isna() & filtered.elpd > -1e+5] linear_regressor = LinearRegression() linear_regressor.fit(filtered.probability.values.reshape(-1, 1), filtered.elpd.values.reshape(-1, 1)) ax1.scatter(filtered.probability, filtered.elpd) @@ -36,6 +43,8 @@ def plot_probabilities(df, iteration): ax2.set_xlabel("probability") ax2.set_ylabel("ELPD") filtered = df[~df.elpd.isna()] + expectation_values.append(np.dot(filtered.probability.values, filtered.elpd.values)) + #filtered.loc[:, "elpd"] = filtered.elpd / 1e+30 linear_regressor = LinearRegression() linear_regressor.fit(filtered.probability.values.reshape(-1, 1), filtered.elpd.values.reshape(-1, 1)) ax2.scatter(filtered.probability, filtered.elpd) @@ -73,7 +82,8 @@ def plot_signatures(df): ################ -def bayesian_probabilistic_score_based_search(model_path, data_path, model_df, num_iterations=10): +def bayesian_probabilistic_score_based_search(model_path, data_path, model_df_path, num_iterations=10): + model_df = elpd_df.read_csv(model_df_path) model_count = model_df.shape[0] model_df["probability"] = 1.0 / model_count model_df["selected"] = False @@ -85,6 +95,8 @@ def bayesian_probabilistic_score_based_search(model_path, data_path, model_df, n for i in range(model_count): for signature in list(model_df.drop(columns=["probability", "selected"]).columns): 
implementation = model_df.loc[i, signature]
+            if signature not in Implementation_score_dict:
+                Implementation_score_dict[signature] = {}
             Implementation_score_dict[signature][implementation] = [0,0]
     # start iterating
     for iteration in range(num_iterations):
@@ -108,6 +120,7 @@ def bayesian_probabilistic_score_based_search(model_path, data_path, model_df, n
         # (2) for each signature, compute its score (store it in a dictionary, e.g. implementation: [ELPD score sum, number of occurrences])
         for signature in list(model_df.drop(columns=["probability", "selected"]).columns):
             implementation = model_df.loc[draw.index, signature]
+            implementation = implementation.values[0]
             Implementation_score_dict[signature][implementation] = [Implementation_score_dict[signature][implementation][0]+elpd,Implementation_score_dict[signature][implementation][1]+1]
         # (3) calculate the score of each model
         Score = []
@@ -117,16 +130,35 @@ def bayesian_probabilistic_score_based_search(model_path, data_path, model_df, n
                 if Implementation_score_dict[signature][model_df.loc[i,signature]][1] > 0:
                     model_score +=Implementation_score_dict[signature][model_df.loc[i,signature]][0]/Implementation_score_dict[signature][model_df.loc[i,signature]][1]
                 else:
-                    sum_scores_already_computed_implementations = sum([val[0]/val[1] if val[1] > 0 else 0 for val in Implementation_score_dict[signature].values])
-                    num_zero_occurence_implementations = sum([1 if val[1] == 0 else 0 for val in Implementation_score_dict[signature].values])
+                    sum_scores_already_computed_implementations = sum([val[0]/val[1] if val[1] > 0 else 0 for val in Implementation_score_dict[signature].values()])
+                    num_zero_occurence_implementations = sum([1 if val[1] == 0 else 0 for val in Implementation_score_dict[signature].values()])
                     model_score += sum_scores_already_computed_implementations/num_zero_occurence_implementations
             Score.append(model_score)
+
+        min_score = min(Score)
+        if min_score < 0:
+            for x in range(len(Score)):
+                Score[x] = Score[x] - min_score
         # (4) calculate the sum of the scores of models
         Score_sum = sum(Score)
         # (5) update the probability of each model
         for i in range(model_count):
             model_df.loc[i, "probability"] = Score[i] / Score_sum
+
+        plot_probabilities(model_df, iteration)
+
+
+    plot_signatures(model_df)
+    plt.figure()
+    print(expectation_values)
+    plt.plot(list(range(len(expectation_values))), expectation_values, "ro")
+    linear_regressor = LinearRegression()
+    linear_regressor.fit(np.array(list(range(len(expectation_values)))).reshape(-1, 1), np.array(expectation_values).reshape(-1, 1))
+    plt.plot(np.array(list(range(len(expectation_values)))), linear_regressor.predict(np.array(list(range(len(expectation_values)))).reshape(-1, 1)), color="red")
+    plt.savefig("score_based_iteration-expectation_plot.png")
     final_model = draw.drop(columns=["probability", "selected"])
+    elpd_df.save_csv(model_df, "score_based_results.csv")
     return final_model
@@ -207,16 +239,26 @@ def bayesian_probabilstic_search(model_path, data_path, model_df_path, num_itera
     elpd_df.save_csv(model_df, "bayesian_update_results.csv")
 
     plot_signatures(model_df)
+    plt.figure()
+    print(expectation_values)
+    plt.plot(list(range(len(expectation_values))), expectation_values, "ro")
+    linear_regressor = LinearRegression()
+    linear_regressor.fit(np.array(list(range(len(expectation_values)))).reshape(-1, 1), np.array(expectation_values).reshape(-1, 1))
+    plt.plot(np.array(list(range(len(expectation_values)))), linear_regressor.predict(np.array(list(range(len(expectation_values)))).reshape(-1, 1)), color="red")
+    plt.savefig("probability_based_iteration-expectation_plot.png")
 
 
 if __name__ == "__main__":
     example_dir = pathlib.Path(__file__).resolve().parents[1].absolute().joinpath("examples")
+
     # birthday_model_path = example_dir.joinpath("birthday/birthday.m.stan")
     # birthday_data_path = example_dir.joinpath("birthday/births_usa_1969.json")
     # birthday_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("birthday_df.csv")
-    # bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=20)
+    #bayesian_probabilstic_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=20)
+    #bayesian_probabilistic_score_based_search(birthday_model_path, birthday_data_path, birthday_df_path, num_iterations=20)
 
     roach_model_path = example_dir.joinpath("roach/roach.m.stan")
     roach_data_path = example_dir.joinpath("roach/roach.json")
     roach_df_path = pathlib.Path(__file__).resolve().parent.absolute().joinpath("roach_df.csv")
-    bayesian_probabilstic_search(roach_model_path, roach_data_path, roach_df_path, num_iterations=90)
+    # bayesian_probabilstic_search(roach_model_path, roach_data_path, roach_df_path, num_iterations=20)
+    bayesian_probabilistic_score_based_search(roach_model_path, roach_data_path, roach_df_path, num_iterations=20)
diff --git a/search_algorithms/roach_df.csv b/search_algorithms/roach_df.csv
index ba02a93..303cf04 100644
--- a/search_algorithms/roach_df.csv
+++ b/search_algorithms/roach_df.csv
@@ -11,7 +11,7 @@ NegBinomial,identity,cauchy,glm,sqrt,no,no,-9.13841323e+29,
 NegBinomial,identity,cauchy,glm,sqrt,no,yes,-388842.073623669,
 NegBinomial,identity,cauchy,glm,sqrt,yes,no,-367462.837335601,
 NegBinomial,identity,cauchy,glm,sqrt,yes,yes,-401002.792038126,
-NegBinomial,identity,normal,glm,identity,no,no,-1.3719282699999999e+31,
+NegBinomial,identity,normal,glm,identity,no,no,-1.37192827e+31,
 NegBinomial,identity,normal,glm,identity,no,yes,-3.113420653e+30,
 NegBinomial,identity,normal,glm,identity,yes,no,-1.030154914e+29,
 NegBinomial,identity,normal,glm,identity,yes,yes,-9.79779348e+29,
From 72fdd62c612dcfec087b7c8dd9df1feab245bb39 Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Tue, 22 Feb 2022 10:39:55 +0900
Subject: [PATCH 21/23] Update notebook

---
 search_algorithms/model_space_exploration.ipynb | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/search_algorithms/model_space_exploration.ipynb b/search_algorithms/model_space_exploration.ipynb
index 8626cb7..f22dbbe 100644
--- a/search_algorithms/model_space_exploration.ipynb
+++ b/search_algorithms/model_space_exploration.ipynb
@@ -41,8 +41,8 @@
     "- Use a piecewise linear approximation, trying several interval values (2, 3, ...)\n",
     "- Examine the approximation error to check whether the change is consistent\n",
     "\n",
-    "### Chain - Apex Predator Search\n",
-    "### Chain - search-dynamic step size\n",
+    "### Chain - Apex Predator Search - 김신영\n",
+    "### Chain - search-dynamic step size - 차상윤\n",
     "- Examine the elpd trend produced by the stepsize algorithm"
    ]
   },
@@ -51,8 +51,8 @@
    "metadata": {},
    "source": [
     "## Bayesian Approach\n",
-    "### Bayesian - Probability based search\n",
-    "### Bayesian - Score based search"
+    "### Bayesian - Probability based search - 김신영\n",
+    "### Bayesian - Score based search - 차상윤"
    ]
   },
   {
From 639adf97b70227f0c90fd5cca2baea7ec2dcc7d1 Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Tue, 22 Feb 2022 12:08:44 +0900
Subject: [PATCH 22/23] Update chain algorithms and notebook

---
 .../Chain_Generation_and_Search.py         | 85 ++++++++++---------
 search_algorithms/apex-predator-search.py  | 62 +++++++-------
 .../model_space_exploration.ipynb          |  3 +
 3 files changed, 79 insertions(+), 71 deletions(-)

diff --git a/search_algorithms/Chain_Generation_and_Search.py b/search_algorithms/Chain_Generation_and_Search.py
index 15961d2..30dd6a6 100644
--- a/search_algorithms/Chain_Generation_and_Search.py
+++ b/search_algorithms/Chain_Generation_and_Search.py
@@ -2,6 +2,9 @@
 import random
 import subprocess
 import sys
+import pathlib
+import elpd_df
+
 
 # Generate a chain given the Top_level_Signature_Hierarchy information
 # The following function returns a chain of models where for any i, the ith model's model complexity is strictly higher than that of the (i+1)th model
@@ -45,42 +48,16 @@ def Chain_Generation(Top_level_Signature_Hierarchy):
 
 
 
-def text_command(args):
-    """Run a shell command, return its stdout as a String or throw an exception if it fails."""
-
-    try:
-        result = subprocess.run(args, text=True, check=True,
-                                stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
-
-        stdout = result.stdout.strip()
-        return stdout
-    except subprocess.CalledProcessError as exc:
-        sys.exit("Error in `mstan`: \"" + exc.output.strip() + "\"")
-
-
-class ModelEvaluator:
-    def __init__(self, dataFile):
-        self.dataFile = dataFile
-
-    def score(self, modelPath):
-        """Return the numerical score for the Stan program at the given filepath"""
-        stdout_result = text_command(["Rscript", "elpd.R", modelPath, self.dataFile])
-        return float(stdout_result.split('\n')[-1].strip())
-
-
 # Compute the ELPD value of a model
 # model is a list where its ith element is a string that represents the implementation for the ith top-level signature
-def ELPD(model, data_file):
+def ELPD(model):
     # use the elements of 'model' (type: list) to obtain the full 'name' of the model
     # Then use 'STAN' to compute ELPD of the model (based on the 'model name' obtained above)
-    model_code_args = ["mstan", "-f", "birthday.m.stan", "concrete-model", "-s", ",".join(model) + ",Regression:glm",]
-    model_code = text_command(model_code_args)
-    with open("temp_stanmodel.stan", "w") as f:
-        f.write(model_code)
-    result = ModelEvaluator(data_file).score("temp_stanmodel.stan")
-    print(f"model: {','.join(model)} ELPD:{result}")
-    return result
+    df = elpd_df.read_csv("roach_df.csv")
+    result = elpd_df.search_df(df, model_string=",".join(model))
+    return result.elpd.values[0]
 
 
 # Chain is a list whose elements are individual models. Each model is a list that consists of the implementations of the top-level signatures.
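Aside: `ELPD` above now resolves a model string against the pre-computed roach_df.csv from patch 19 via the repository's `elpd_df` helpers. Below is a plain-pandas sketch of that lookup with a hypothetical helper name (`lookup_elpd`); the real `elpd_df.search_df` may behave differently.

import pandas as pd

def lookup_elpd(df, model_string):
    # Filter the frame down to the row(s) whose columns match every
    # "Signature:implementation" pair. If several rows match (e.g. Poisson
    # models with and without RandomEffects), the first one is returned.
    mask = pd.Series(True, index=df.index)
    for pair in model_string.split(","):
        sig, impl = pair.split(":")
        mask &= df[sig] == impl
    return df.loc[mask, "elpd"].values[0]

df = pd.read_csv("roach_df.csv")  # schema as in patch 19
print(lookup_elpd(df, "Likelihood:NegBinomial,OffsetType:identity,PhiPrior:cauchy,Roach:no,Senior:no,Treatment:no"))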
@@ -90,7 +67,7 @@ def ELPD(model, data_file):
 # and finally returns the model with the highest ELPD value and its ELPD value (among those whose ELPD values were computed)
 # Suppose that the chain is given as an input and the models are sorted in decreasing order of model complexity,
 # i.e. the 1st model in the chain has the highest model complexity and the last model has the lowest complexity.
-def Chain_Search(Chain,K,data_file_dir, alpha=0.5):
+def Chain_Search(Chain,K, alpha=0.5):
     n = len(Chain)
     # if the number of models in the chain is smaller than or equal to K, we can compute the ELPD value of each model and choose the one with the highest value
     if n <=K:
@@ -106,23 +83,28 @@ def Chain_Search(Chain,K,data_file_dir, alpha=0.5):
     ELPD_values_obtained = []
     while num_ELPD_computed < K and cur_ind < n:
         # compute the ELPD value of the current iteration's model
-        cur_iter_ELPD = ELPD(Chain[cur_ind], data_file_dir)
+        print(f"current index: {cur_ind}")
+        cur_iter_ELPD = ELPD(Chain[cur_ind])
         # update the ELPD-computed model indices, the ELPD values, and the number of ELPD values computed, respectively
         ELPD_computed_model_indices.append(cur_ind)
         ELPD_values_obtained.append(cur_iter_ELPD)
         num_ELPD_computed+=1
+        if num_ELPD_computed == K:
+            break
         step_size = 1 # set the default step size to 1
         # if this is neither the 1st nor the 2nd iteration, the step size should be modified
         if not (num_ELPD_computed == 1 or num_ELPD_computed==2):
             step_size_Uniform = (n-1-cur_ind)/(K-num_ELPD_computed)
             step_size_LB = step_size_Uniform*(1-alpha)
             step_size_UB = step_size_Uniform*(1+alpha)
+            print(ELPD_values_obtained)
             ELPD_slope_cur_iter = abs((ELPD_values_obtained[-1]-ELPD_values_obtained[-2]))/(ELPD_computed_model_indices[-1]-ELPD_computed_model_indices[-2])
             ELPD_slope_previous_iter = abs((ELPD_values_obtained[-2]-ELPD_values_obtained[-3]))/(ELPD_computed_model_indices[-2]-ELPD_computed_model_indices[-3])
             step_size_candidate = step_size_Uniform*(ELPD_slope_previous_iter/ELPD_slope_cur_iter)
             step_size = min(step_size_UB,min(step_size_LB,step_size_candidate))
+            print(step_size)
         # update the current index (which will be the index of the model in the next iteration)
-        cur_ind += step_size
+        cur_ind += max(1, round(step_size))
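Aside: the adaptive jump in the loop above can be isolated as a small helper. This is a sketch only, with the hypothetical name `next_step_size`, assuming `num_computed < K` as the loop guarantees. Note that the patch computes `min(step_size_UB, min(step_size_LB, step_size_candidate))`, so the step can never exceed the lower bound; the sketch instead clamps the candidate to [LB, UB], which appears to be the intent.

def next_step_size(n, K, cur_ind, num_computed, indices, elpds, alpha=0.5):
    # First two evaluations: no slope information yet, move one model at a time.
    if num_computed < 3:
        return 1
    # Uniform step: spread the remaining budget evenly over the remaining models.
    uniform = (n - 1 - cur_ind) / (K - num_computed)
    lb, ub = uniform * (1 - alpha), uniform * (1 + alpha)
    # Slow down where ELPD is changing fast, speed up where it is flat.
    slope_cur = abs(elpds[-1] - elpds[-2]) / (indices[-1] - indices[-2])
    slope_prev = abs(elpds[-2] - elpds[-3]) / (indices[-2] - indices[-3])
    candidate = uniform * (slope_prev / slope_cur) if slope_cur > 0 else ub
    return max(1, round(min(ub, max(lb, candidate))))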
     # find the best model (the model with the highest ELPD value) among the models whose ELPD values were computed.
     highest_ELPD_value_model_chain_search_ind = np.argmax(ELPD_values_obtained)
     Final_best_ELPD_model_ind, Final_best_ELPD_val = ELPD_computed_model_indices[highest_ELPD_value_model_chain_search_ind], ELPD_values_obtained[highest_ELPD_value_model_chain_search_ind]
@@ -135,20 +117,39 @@ def Chain_Search(Chain,K,data_file_dir, alpha=0.5):
 Top_level_Signature_Hierarchy = [
     ["DayOfWeekTrend:yes,DayOfWeekWeights:weighted","DayOfWeekTrend:yes,DayOfWeekWeights:uniform","DayOfWeekTrend:no"],
     ["DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:yes","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:yes,DayOfYearNormalVariance:no","DayOfYearTrend:yes,DayOfYearHeirarchicalVariance:no,DayOfYearNormalVariance:no","DayOfYearTrend:no"],
-    ["HolidayTrend:yes","HolidayTrend:yes"],
+    ["HolidayTrend:yes","HolidayTrend:no"],
     ["LongTermTrend:yes","LongTermTrend:no"],
     ["SeasonalTrend:yes","SeasonalTrend:no"]
 ]
 
-chain = Chain_Generation(Top_level_Signature_Hierarchy)
-for v in chain:
-    print(v)
-data_file_dir = "examples/birthday/births_usa_1969.json"
-
-K = 3
+Top_level_Signature_Hierarchy_Roach = [
+    ["Likelihood:NegBinomial,PhiPrior:normal", "Likelihood:NegBinomial,PhiPrior:cauchy", "Likelihood:Poisson,RandomEffects:yes", "Likelihood:Poisson,RandomEffects:no"],
+    ["OffsetType:log", "OffsetType:identity"],
+    ["Roach:sqrt", "Roach:identity", "Roach:no"],
+    ["Treatment:yes", "Treatment:no"],
+    ["Senior:yes", "Senior:no"]
+]
 
-#best_model, best_elpd = Chain_Search(Chain=chain, K=K, data_file_dir=data_file_dir)
-#print(best_model, best_elpd)
+#chain = Chain_Generation(Top_level_Signature_Hierarchy)
+chain = Chain_Generation(Top_level_Signature_Hierarchy_Roach)
+print(f"chain length: {len(chain)}")
+for val in chain:
+    model_string=",".join(val)
+    df = elpd_df.read_csv("roach_df.csv")
+    result = elpd_df.search_df(df, model_string=model_string)
+    print(val, result)
+
+print("-" * 10)
+K = 5
+alpha = 0.9
+# alpha: in [0, 1]; the larger the value, the closer the steps stay to the uniform stepsize (i.e. relatively even stepsizes).
+# If the linear approximation is believed to hold over a wide region (a diverse model space), use a large alpha to take bigger jumps.
+# If the model space is homogeneous, make alpha small and search conservatively (subtle model changes may lead to large elpd differences).
+# The difference in stepsize only really becomes observable with values at the two extremes, e.g. 0.1 vs 0.9.
+# Try various values for k (m >= k >= 3).
+
+best_model, best_elpd = Chain_Search(Chain=chain, K=K, alpha=alpha)
+print(best_model, best_elpd)
 
 
 # Original Chain Generation Algorithm (Proposed earlier)
diff --git a/search_algorithms/apex-predator-search.py b/search_algorithms/apex-predator-search.py
index f6d92c2..06d61da 100644
--- a/search_algorithms/apex-predator-search.py
+++ b/search_algorithms/apex-predator-search.py
@@ -1,4 +1,6 @@
 import subprocess
+import pathlib
+import elpd_df
 
 def text_command(args):
     """Run a shell command, return its stdout as a String or throw an exception if it fails."""
@@ -14,17 +16,25 @@ def text_command(args):
 
 
 class ModelEvaluator:
-    def __init__(self, dataFile):
-        self.dataFile = dataFile
+    def __init__(self, df_path):
+        self.df_path = df_path
+        self.df = elpd_df.read_csv(df_path)
 
-    def score(self, modelPath):
+    def score(self, model_string):
         """Return the numerical score for the Stan program at the given filepath"""
-        stdout_result = text_command(["Rscript", "elpd.R", modelPath, self.dataFile])
-        return float(stdout_result.split('\n')[-1].strip())
+        elpd = elpd_df.search_df(self.df, model_string=model_string).elpd.values[0]
+        return elpd
 
-model_file_name = "examples/birthday/birthday.m.stan"
-args = ["mstan", "-f", model_file_name, "get-highest-models"]
+example_dir = pathlib.Path(__file__).resolve().parents[1].absolute().joinpath("examples")
+model_file_path = example_dir.joinpath("birthday/birthday.m.stan")
+model_df_path = "birthday_df.csv"
+
+model_file_path = example_dir.joinpath("roach/roach.m.stan")
+model_df_path = "roach_df.csv"
+
+
+args = ["mstan", "-f", model_file_path, "get-highest-models"]
 result = subprocess.run(args, text=True, check=True,
                         stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
@@ -34,32 +44,26 @@ def score(self, modelPath):
 results = {}
 for model in stdout:
-    model_code_args = ["mstan", "-f", model_file_name, "concrete-model", "-s", model,]
-    print(model_code_args)
-    model_code = subprocess.run(model_code_args, text=True, check=True,
-                                stderr=subprocess.STDOUT, stdout=subprocess.PIPE).stdout.strip()
-    with open("temp_stanmodel.stan", "w") as f:
-        f.write(model_code)
-
-    result = ModelEvaluator("examples/birthday/births_usa_1969.json").score("temp_stanmodel.stan")
+    result = ModelEvaluator(model_df_path).score(model)
     results[model] = result
-print(results)
-
+for key, val in results.items():
+    print(f"{key} : {val}")
 
-hierarchy_info = [
-    ["DayofWeekTrend:yes,DayofWeekWeights:weighted", "DayofWeekTrend:yes,DayofWeekWeights:uniform", "DayofWeekTrend:no"],
-    ["DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:yes","DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:no","DayofYearTrend:yes,DayofHierarchicalVariance:no,DayofYearNormalVariance:yes", "DayofYearTrend:no"]
-    ["HolidayTrend:yes", "HolidayTrend:no"],
-    ["LongTermTrend:yes", "LongTermTrend:no"]
-    ["SeasonTrend:yes", "SeasonTrend:no"]
-    #...
-] # n, n-1, ... 1
 
+# hierarchy_info = [
+#     ["DayofWeekTrend:yes,DayofWeekWeights:weighted", "DayofWeekTrend:yes,DayofWeekWeights:uniform", "DayofWeekTrend:no"],
+#     ["DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:yes","DayofYearTrend:yes,DayofHierarchicalVariance:yes,DayofYearNormalVariance:no","DayofYearTrend:yes,DayofHierarchicalVariance:no,DayofYearNormalVariance:yes", "DayofYearTrend:no"]
+#     ["HolidayTrend:yes", "HolidayTrend:no"],
+#     ["LongTermTrend:yes", "LongTermTrend:no"]
+#     ["SeasonTrend:yes", "SeasonTrend:no"]
+#     #...
+# ] # n, n-1, ... 1
 
-current_model = ["DayofWeek:Yes", "HolidayTrend:Yes"]
+# current_model = ["DayofWeek:Yes", "HolidayTrend:Yes"]
 
-chain = []
-chain.append(",".join(current_model))
-current_model[0] = hierarchy_info[0][1]
-chain.append(",".join(current_model))
\ No newline at end of file
+# chain = []
+# chain.append(",".join(current_model))
+# current_model[0] = hierarchy_info[0][1]
+# chain.append(",".join(current_model))
\ No newline at end of file
diff --git a/search_algorithms/model_space_exploration.ipynb b/search_algorithms/model_space_exploration.ipynb
index f22dbbe..15f1ba0 100644
--- a/search_algorithms/model_space_exploration.ipynb
+++ b/search_algorithms/model_space_exploration.ipynb
@@ -64,6 +64,9 @@
     "- Value-based comparison: e.g. objective function (ELPD value of the best model)\n",
     "- Sensitivity Analysis (tuning hyperparameters of algorithms): algorithm-wise vs value-wise\n",
     "\n",
+    "**When comparing the bayesian algorithm against the chain algorithm, build several chains so that both are compared at a similar computational budget**\n",
+    "Additionally: k = 5 vs k = 10 vs k = 15 vs k = 20\n",
+    "\n",
     "Run each algorithm n times and report the mean or the best out of n\n",
     "\n",
     "|algorithm | Best Model ELPD| Iteration to reach best model|\n",
From a4f090db6d95e6648c1d5556d47d5636e1bbd3f4 Mon Sep 17 00:00:00 2001
From: Dashadower
Date: Tue, 22 Feb 2022 12:15:56 +0900
Subject: [PATCH 23/23] Update notebook

---
 .../model_space_exploration.ipynb | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/search_algorithms/model_space_exploration.ipynb b/search_algorithms/model_space_exploration.ipynb
index 15f1ba0..da9b26d 100644
--- a/search_algorithms/model_space_exploration.ipynb
+++ b/search_algorithms/model_space_exploration.ipynb
@@ -64,8 +64,8 @@
     "- Value-based comparison: e.g. objective function (ELPD value of the best model)\n",
     "- Sensitivity Analysis (tuning hyperparameters of algorithms): algorithm-wise vs value-wise\n",
     "\n",
-    "**When comparing the bayesian algorithm against the chain algorithm, build several chains so that both are compared at a similar computational budget**\n",
-    "Additionally: k = 5 vs k = 10 vs k = 15 vs k = 20\n",
+    "**When comparing the bayesian algorithm against the chain algorithm, build several chains so that both are compared at a similar computational budget (additionally: k = 5 vs k = 10 vs k = 15 vs k = 20)**\n",
+    "\n",
     "\n",
     "Run each algorithm n times and report the mean or the best out of n\n",
     "\n",
@@ -82,7 +82,17 @@
     "|Score(K=10) |.|.|\n",
     "\n",
     "- bayesian algorithms: compare probability distributions (scatter plots), expectation, trend lines\n",
-    "- iteration-ELPD plot, iteration-expected_ELPD plot\n"
+    "- iteration-ELPD plot, iteration-expected_ELPD plot\n",
+    "\n",
+    "chain algorithm: excels at picking a \"good enough\" model under a limited computation budget, but requires accurate hierarchy information.\n",
+    "\n",
+    "bayesian algorithm: with a more generous computation budget, it can characterize per-model performance over the whole model space. It requires no hierarchy information or special assumptions about the model space.\n",
+    "\n",
+    "- case 1: when computational resources are limited\n",
+    "    - use the chain algorithm\n",
+    "- case 2: when computational resources are plentiful\n",
+    "    - 1: use the bayesian algorithm\n",
+    "    - 2: decompose the whole model space into several chains and apply the chain algorithm to each"
    ]
   },
   {
@@ -93,6 +103,8 @@
     "\n",
     "For experimental convenience, ELPD was computed only once per model, so noise in the ELPD estimates is not reflected\n",
     "\n",
+    "For the chain algorithm, hierarchy information within each signature is required, but determining the hierarchy for every model is not easy.\n",
+    "\n",
     "## Further Research Topics\n",
     "## Conclusion"
   ]