digits: Train random forest
This commit is contained in:
		
							parent
							
								
									c657e68de6
								
							
						
					
					
						commit
						da678177cb
					
				
					 1 changed files with 20 additions and 0 deletions
				
			
		
							
								
								
									
										20
									
								
								digits.py
									
										
									
									
									
								
							
							
						
						
									
										20
									
								
								digits.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -103,6 +103,19 @@ class SVM(StatModel):
 | 
			
		|||
    def predict(self, samples):
 | 
			
		||||
        return self.model.predict_all(samples).ravel()
 | 
			
		||||
 | 
			
		||||
class RForest(StatModel):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.params = dict( max_depth=20 )
 | 
			
		||||
        self.model = cv2.RTrees()
 | 
			
		||||
 | 
			
		||||
    def train(self, samples, responses):
 | 
			
		||||
        self.model = cv2.RTrees()
 | 
			
		||||
        self.model.train(samples, cv2.CV_ROW_SAMPLE, responses,
 | 
			
		||||
              params=self.params)
 | 
			
		||||
 | 
			
		||||
    def predict(self, samples):
 | 
			
		||||
        predictions = map(self.model.predict, samples)
 | 
			
		||||
        return predictions
 | 
			
		||||
 | 
			
		||||
def evaluate_model(model, digits, samples, labels):
 | 
			
		||||
    resp = model.predict(samples)
 | 
			
		||||
| 
						 | 
				
			
			@ -171,6 +184,12 @@ if __name__ == '__main__':
 | 
			
		|||
    samples_train, samples_test = np.split(samples, [train_n])
 | 
			
		||||
    labels_train, labels_test = np.split(labels, [train_n])
 | 
			
		||||
 | 
			
		||||
    print 'training Random Forest...'
 | 
			
		||||
 | 
			
		||||
    model = RForest()
 | 
			
		||||
    model.train(samples_train, labels_train)
 | 
			
		||||
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
 | 
			
		||||
    cv2.imshow('Random Forest test', vis)
 | 
			
		||||
 | 
			
		||||
    print 'training KNearest...'
 | 
			
		||||
    model = KNearest(k=4)
 | 
			
		||||
| 
						 | 
				
			
			@ -192,4 +211,5 @@ if __name__ == '__main__':
 | 
			
		|||
    print 'saving SVM as "digits_svm.dat"...'
 | 
			
		||||
    model.save('digits_svm.dat')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    cv2.waitKey(0)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue