print out better messages

master
neingeist 11 years ago
parent 90b5fb26ec
commit 4cd1ade5d3

@ -17,17 +17,18 @@ yn = [6, 5, 7, 10, 11, 12, 14]
# For each data point the vertical error/residual is x*b1 + b2 - y. We want to # For each data point the vertical error/residual is x*b1 + b2 - y. We want to
# minimize the sum of the squared residuals (least squares). # minimize the sum of the squared residuals (least squares).
S = sum((xn[i] * b1 + b2 - yn[i]) ** 2 for i in range(0, len(xn))) S = sum((xn[i] * b1 + b2 - yn[i]) ** 2 for i in range(0, len(xn)))
print(S) print("Function to minimize: S = {}".format(S))
# Minimize S by setting its partial derivatives to zero. # Minimize S by setting its partial derivatives to zero.
d1 = diff(S, b1) d1 = diff(S, b1)
d2 = diff(S, b2) d2 = diff(S, b2)
solutions = solve([d1, d2], [b1, b2]) solutions = solve([d1, d2], [b1, b2])
print("S is minimal for b1 = {}, b2 = {}".format(solutions[b1], solutions[b2]))
# Construct fitted line from the solutions # Construct fitted line from the solutions
x = Symbol('x') x = Symbol('x')
fitted_line = solutions[b1] * x + solutions[b2] fitted_line = solutions[b1] * x + solutions[b2]
print(fitted_line) print("Fitted line: y = {}".format(fitted_line))
# Construct something we can plot with matplotlib # Construct something we can plot with matplotlib
fitted_line_func = lambdify(x, fitted_line, modules=['numpy']) fitted_line_func = lambdify(x, fitted_line, modules=['numpy'])

Loading…
Cancel
Save