Description:
I am encountering a TypeError when trying to use a deserialized RandomForestClassifier. The error occurs when calling predictProbability on the rehydrated classifier, specifically in the DecisionTreeClassifier:
```
DecisionTreeClassifier.js:58 Uncaught (in promise) TypeError: this.root.classify(...).maxRowIndex is not a function
    at _DecisionTreeClassifier.predict (DecisionTreeClassifier.js:58:10)
    at _RandomForestClassifier.predictionValues (RandomForestBase.js:288:48)
    at _RandomForestClassifier.predictProbability (RandomForestClassifier.js:111:35)
```
Steps to Reproduce:
- Serialize a RandomForestClassifier using toJSON():

  ```ts
  const randomForestClassifier = new RandomForestClassifier(/* options */);
  const classifierJson = randomForestClassifier.toJSON();
  ```

- Deserialize the RandomForestClassifier using load():

  ```ts
  const classifier = RandomForestClassifier.load(
    {
      isClassifier: classifierModel.baseModel.isClassifier,
      noOOB: classifierModel.baseModel.noOOB,
      selectionMethod: classifierModel.baseModel.selectionMethod,
      useSampleBagging: classifierModel.baseModel.useSampleBagging,
      maxFeatures: classifierModel.baseModel.maxFeatures,
      nEstimators: classifierModel.baseModel.estimators.length,
      seed: classifierModel.baseModel.seed,
      replacement: classifierModel.baseModel.replacement,
      treeOptions: classifierModel.baseModel.treeOptions,
    },
    classifierJson as RandomForestClassifierModel,
  );
  classifier.indexes = classifierModel.baseModel.indexes;
  classifier.n = classifierModel.baseModel.n;
  classifier.estimators = classifierModel.baseModel.estimators;
  ```

- Call predictProbability() on the deserialized RandomForestClassifier:

  ```ts
  const predictions = classifier.predictProbability(testData, 1);
  ```
Expected Behavior:
The deserialized RandomForestClassifier should behave like the original: its estimators (decision trees) should be rehydrated as DecisionTreeClassifier instances so that predict() works correctly.
Actual Behavior:
The estimators array contains objects, but they lack the predict() method, and calling predictProbability() throws the following error:

```
TypeError: this.root.classify(...).maxRowIndex is not a function
```
Cause:
It seems that during deserialization, the estimators (which are instances of DecisionTreeClassifier) are being restored as plain JSON objects, losing their methods such as predict() and classify(). This leads to the failure when the classifier tries to call methods on these objects.
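The mechanism described above can be reproduced in isolation. The sketch below uses a hypothetical Tree class (an assumption for illustration, not the ml-cart DecisionTreeClassifier) to show that a JSON round-trip keeps an instance's data fields but drops its prototype methods, so calling one afterwards throws a TypeError.

```typescript
// Hypothetical stand-in for a tree estimator; NOT the ml-cart class.
class Tree {
  threshold: number;

  constructor(threshold: number) {
    this.threshold = threshold;
  }

  // Prototype method: it lives on Tree.prototype, not on the instance data.
  predict(x: number): number {
    return x > this.threshold ? 1 : 0;
  }
}

const original = new Tree(0.5);

// Round-trip through JSON, as happens when a model is saved and reloaded.
const restored = JSON.parse(JSON.stringify(original));

console.log(original instanceof Tree); // true
console.log(restored instanceof Tree); // false
console.log(typeof restored.predict); // the string "undefined"
console.log(restored.threshold); // 0.5 (the data survives, the methods do not)

try {
  restored.predict(0.7);
} catch (err) {
  // Calling a missing method fails with a TypeError ("... is not a function").
  console.log((err as Error).name); // TypeError
}
```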
Possible Fix:
To fix this, each estimator in the estimators array needs to be rehydrated as an instance of DecisionTreeClassifier after deserialization. For example:
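```ts
// DecisionTreeClassifier comes from ml-cart, the decision-tree package ml-random-forest builds on
import { DecisionTreeClassifier } from 'ml-cart';

// Rehydrate each serialized estimator into a real DecisionTreeClassifier instance
classifier.estimators = classifierModel.baseModel.estimators.map((estimator) =>
  DecisionTreeClassifier.load(estimator),
);
```

Environment:
- Package: ml-random-forest
- Version: 2.1.0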
Additional Context:
The deserialization works fine up to the point where it reaches the estimators. The issue arises when trying to call methods on these deserialized objects, which are missing their expected methods.
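A minimal sketch of the rehydration pattern suggested under Possible Fix. Tree and Forest are hypothetical classes that only stand in for DecisionTreeClassifier and RandomForestClassifier; they are not the library's implementations, and their JSON shapes are assumptions. The key point is that the parent's load() rebuilds every nested estimator through the child's own load() instead of assigning the parsed JSON directly.

```typescript
// Hypothetical classes illustrating nested rehydration; NOT ml-random-forest / ml-cart code.

interface TreeJSON {
  name: string;
  threshold: number;
}

class Tree {
  threshold: number;

  constructor(threshold: number) {
    this.threshold = threshold;
  }

  predict(x: number): number {
    return x > this.threshold ? 1 : 0;
  }

  toJSON(): TreeJSON {
    return { name: "Tree", threshold: this.threshold };
  }

  // Rebuild a real instance (with prototype methods) from plain JSON data.
  static load(json: TreeJSON): Tree {
    if (json.name !== "Tree") {
      throw new RangeError(`Invalid model: ${json.name}`);
    }
    return new Tree(json.threshold);
  }
}

interface ForestJSON {
  name: string;
  estimators: TreeJSON[];
}

class Forest {
  estimators: Tree[];

  constructor(estimators: Tree[]) {
    this.estimators = estimators;
  }

  // Fraction of trees voting for class 1.
  predictProbability(x: number): number {
    const votes = this.estimators.reduce((sum, tree) => sum + tree.predict(x), 0);
    return votes / this.estimators.length;
  }

  toJSON(): ForestJSON {
    return { name: "Forest", estimators: this.estimators.map((t) => t.toJSON()) };
  }

  // Each nested estimator goes through Tree.load, so it regains predict().
  static load(json: ForestJSON): Forest {
    return new Forest(json.estimators.map((e) => Tree.load(e)));
  }
}

const forest = new Forest([new Tree(0.2), new Tree(0.5), new Tree(0.8)]);
const saved = JSON.stringify(forest); // JSON.stringify calls Forest.toJSON
const reloaded = Forest.load(JSON.parse(saved) as ForestJSON);

console.log(reloaded.estimators.every((t) => t instanceof Tree)); // true
console.log(reloaded.predictProbability(0.6)); // 0.6666666666666666
```

Assigning `JSON.parse(saved).estimators` straight to `forest.estimators` would skip the per-estimator load() and leave plain objects without predict(), which is the situation described in this issue.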