|
|
|
@ -103,6 +103,19 @@ class SVM(StatModel):
|
|
|
|
|
def predict(self, samples):
|
|
|
|
|
return self.model.predict_all(samples).ravel()
|
|
|
|
|
|
|
|
|
|
class RForest(StatModel):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.params = dict( max_depth=20 )
|
|
|
|
|
self.model = cv2.RTrees()
|
|
|
|
|
|
|
|
|
|
def train(self, samples, responses):
|
|
|
|
|
self.model = cv2.RTrees()
|
|
|
|
|
self.model.train(samples, cv2.CV_ROW_SAMPLE, responses,
|
|
|
|
|
params=self.params)
|
|
|
|
|
|
|
|
|
|
def predict(self, samples):
|
|
|
|
|
predictions = map(self.model.predict, samples)
|
|
|
|
|
return predictions
|
|
|
|
|
|
|
|
|
|
def evaluate_model(model, digits, samples, labels):
|
|
|
|
|
resp = model.predict(samples)
|
|
|
|
@ -171,6 +184,12 @@ if __name__ == '__main__':
|
|
|
|
|
samples_train, samples_test = np.split(samples, [train_n])
|
|
|
|
|
labels_train, labels_test = np.split(labels, [train_n])
|
|
|
|
|
|
|
|
|
|
print 'training Random Forest...'
|
|
|
|
|
|
|
|
|
|
model = RForest()
|
|
|
|
|
model.train(samples_train, labels_train)
|
|
|
|
|
vis = evaluate_model(model, digits_test, samples_test, labels_test)
|
|
|
|
|
cv2.imshow('Random Forest test', vis)
|
|
|
|
|
|
|
|
|
|
print 'training KNearest...'
|
|
|
|
|
model = KNearest(k=4)
|
|
|
|
@ -192,4 +211,5 @@ if __name__ == '__main__':
|
|
|
|
|
print 'saving SVM as "digits_svm.dat"...'
|
|
|
|
|
model.save('digits_svm.dat')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cv2.waitKey(0)
|
|
|
|
|