digits: Train random forest
This commit is contained in:
parent
c657e68de6
commit
da678177cb
1 changed files with 20 additions and 0 deletions
20
digits.py
20
digits.py
|
@ -103,6 +103,19 @@ class SVM(StatModel):
|
|||
def predict(self, samples):
|
||||
return self.model.predict_all(samples).ravel()
|
||||
|
||||
class RForest(StatModel):
|
||||
def __init__(self):
|
||||
self.params = dict( max_depth=20 )
|
||||
self.model = cv2.RTrees()
|
||||
|
||||
def train(self, samples, responses):
|
||||
self.model = cv2.RTrees()
|
||||
self.model.train(samples, cv2.CV_ROW_SAMPLE, responses,
|
||||
params=self.params)
|
||||
|
||||
def predict(self, samples):
|
||||
predictions = map(self.model.predict, samples)
|
||||
return predictions
|
||||
|
||||
def evaluate_model(model, digits, samples, labels):
|
||||
resp = model.predict(samples)
|
||||
|
@ -171,6 +184,12 @@ if __name__ == '__main__':
|
|||
samples_train, samples_test = np.split(samples, [train_n])
|
||||
labels_train, labels_test = np.split(labels, [train_n])
|
||||
|
||||
print 'training Random Forest...'
|
||||
|
||||
model = RForest()
|
||||
model.train(samples_train, labels_train)
|
||||
vis = evaluate_model(model, digits_test, samples_test, labels_test)
|
||||
cv2.imshow('Random Forest test', vis)
|
||||
|
||||
print 'training KNearest...'
|
||||
model = KNearest(k=4)
|
||||
|
@ -192,4 +211,5 @@ if __name__ == '__main__':
|
|||
print 'saving SVM as "digits_svm.dat"...'
|
||||
model.save('digits_svm.dat')
|
||||
|
||||
|
||||
cv2.waitKey(0)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue