digits: Train random forest

master
neingeist 9 years ago
parent c657e68de6
commit da678177cb

@ -103,6 +103,19 @@ class SVM(StatModel):
def predict(self, samples):
return self.model.predict_all(samples).ravel()
class RForest(StatModel):
def __init__(self):
self.params = dict( max_depth=20 )
self.model = cv2.RTrees()
def train(self, samples, responses):
self.model = cv2.RTrees()
self.model.train(samples, cv2.CV_ROW_SAMPLE, responses,
params=self.params)
def predict(self, samples):
predictions = map(self.model.predict, samples)
return predictions
def evaluate_model(model, digits, samples, labels):
resp = model.predict(samples)
@ -171,6 +184,12 @@ if __name__ == '__main__':
samples_train, samples_test = np.split(samples, [train_n])
labels_train, labels_test = np.split(labels, [train_n])
print 'training Random Forest...'
model = RForest()
model.train(samples_train, labels_train)
vis = evaluate_model(model, digits_test, samples_test, labels_test)
cv2.imshow('Random Forest test', vis)
print 'training KNearest...'
model = KNearest(k=4)
@ -192,4 +211,5 @@ if __name__ == '__main__':
print 'saving SVM as "digits_svm.dat"...'
model.save('digits_svm.dat')
cv2.waitKey(0)

Loading…
Cancel
Save