#!/usr/bin/env python
import numpy as np
import matplotlib.cm as cm
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit


plt.ion()
# ---------------------------------------------------------------
# Part 1: chi-squared manifold of a simple linear toy model.
#
# Model: y_n = a + n*b, fitted to 10 draws from a unit Gaussian,
# using the sample index n as a pseudo x-position.
# ---------------------------------------------------------------
np.random.seed(1)  # fixed seed so the toy data are reproducible
# Draw 10 values from a unit Gaussian.
Data = np.random.normal(0.0, 1.0, 10)
# Range of parameter a (offset).
a_min = -2.5
a_max =  2.5
# Range of parameter b (slope).
b_min = -1.0
b_max =  1.0
# Number of steps of the grid per axis.
Steps = 51
# Allocate grid as matrix.
Grid  = np.zeros([Steps, Steps])
# Pseudo x-positions: the sample index n.
positions = np.arange(len(Data))
# Try all parameter combinations; the sum over data points is
# vectorised with numpy, only the parameter grid is looped over.
for s1 in range(Steps):
    for s2 in range(Steps):
        # Current parameter combination at grid position (s1, s2).
        a = a_min + (a_max - a_min)*float(s1)/float(Steps-1)
        b = b_min + (b_max - b_min)*float(s2)/float(Steps-1)

        # Evaluate chi-squared for this (a, b).
        residuals = Data - a - positions*b
        # Row index is flipped so that b increases upwards in imshow.
        Grid[Steps-1-s2, s1] = np.sum(residuals*residuals)

# NOTE(review): the figure must be created *before* clearing the axes.
# Calling plt.cla() first implicitly creates figure 1, and the
# figsize argument of a later plt.figure(1, ...) call is then
# silently ignored for the already-existing figure.
plt.figure(1, figsize=(8, 3))
plt.cla()
mini  = np.min(Grid)  # minimal chi2 value, anchors the colour scale
image = plt.imshow(Grid, vmin=mini, vmax=mini+20.0,
                         extent=[a_min, a_max, b_min, b_max])
plt.colorbar(image)
plt.xlabel(r'$a$', fontsize=24)
plt.ylabel(r'$b$', fontsize=24)
plt.savefig('example-chi2-manifold.png')
plt.draw()
raw_input( "Press Enter to continue... " )  # Python 2; use input() under Python 3


def func(x, a, b):
    """Toy model y(x) = a + b^2 * x.

    Because the slope depends on b only through b*b, the parameter
    pairs (a, b) and (a, -b) fit equally well -- this is what makes
    the chi-squared manifold bimodal.
    """
    slope = b*b  # even in b -> two equivalent minima
    return a + slope*x

# ---------------------------------------------------------------
# Part 2: robustness of curve_fit on a bimodal chi-squared surface.
#
# The model a + b*b*x is even in b, so every solution (a, b) has a
# mirror solution (a, -b); curve_fit converges to one mode or the
# other depending on the initial guess.
# ---------------------------------------------------------------
# Create toy data for curve_fit.
xdata = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
ydata = np.array([0.1, 0.9, 2.2, 2.8, 3.9, 5.1])
sigma = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

# Compute chi-square manifold.
Steps = 101  # grid size per axis
Chi2Manifold = np.zeros([Steps, Steps])  # allocate grid
amin = -7.0  # minimal value of a covered by grid
amax = +5.0  # maximal value of a covered by grid
bmin = -4.0  # minimal value of b covered by grid
bmax = +4.0  # maximal value of b covered by grid
for s1 in range(Steps):
    for s2 in range(Steps):
        # Current values of (a,b) at grid position (s1,s2).
        a = amin + (amax - amin)*float(s1)/(Steps-1)
        b = bmin + (bmax - bmin)*float(s2)/(Steps-1)
        # Evaluate chi-squared; the sum over data points is
        # vectorised with numpy.
        residuals = (ydata - func(xdata, a, b))/sigma
        # Row index is flipped so that b increases upwards in imshow.
        Chi2Manifold[Steps-1-s2, s1] = np.sum(residuals*residuals)

# Plot grid.
# NOTE(review): figure 1 already exists (created in part 1), so the
# figsize argument of plt.figure(1, ...) is silently ignored; resize
# it explicitly and clear the whole figure (including the stale
# colorbar axes) with clf() instead of cla().
fig = plt.figure(1, figsize=(8, 4.5))
fig.set_size_inches(8, 4.5, forward=True)
fig.clf()
plt.subplots_adjust(left=0.09, bottom=0.09, top=0.97, right=0.99)
# Plot chi-square manifold.
image = plt.imshow(Chi2Manifold, vmax=50.0,
              extent=[amin, amax, bmin, bmax])
# Plot where curve-fit is going to for a couple of initial guesses;
# b always starts at -3.5, each orange segment joins start and end.
for a_initial in -6.0, -4.0, -2.0, 0.0, 2.0, 4.0:
    # Initial guess.
    x0   = np.array([a_initial, -3.5])
    xFit = curve_fit(func, xdata, ydata, x0, sigma)[0]
    plt.plot([x0[0], xFit[0]], [x0[1], xFit[1]], 'o-', ms=4,
                 markeredgewidth=0, lw=2, color='orange')
plt.colorbar(image)  # make colorbar
plt.xlim(amin, amax)
plt.ylim(bmin, bmax)
plt.xlabel(r'$a$', fontsize=24)
plt.ylabel(r'$b$', fontsize=24)
plt.savefig('demo-robustness-curve-fit.png')

plt.draw()
raw_input( "Press Enter to continue... " )  # Python 2; use input() under Python 3

from scipy.optimize import fmin as simplex
# "fmin" is not a sensible name for an optimisation package.
# Rename fmin to "simplex"

# Objective function to be minimised by the Simplex algorithm.
# params ... array holding the current values of the fit parameters.
# X      ... array holding x-positions of observed data.
# Y      ... array holding y-values of observed data.
# Err    ... array holding errors of observed data.
def func(params, X, Y, Err):
    """Chi-squared of the quadratic model y(x) = a + b*x + c*x^2."""
    # Unpack the current fit parameters.
    a, b, c = params[0], params[1], params[2]
    # Accumulate chi-squared over all observed points.
    chi2 = 0.0
    for x, y_obs, err in zip(X, Y, Err):
        # Polynomial model evaluated at this x-position.
        y_model = a + b*x + c*x*x
        resid = y_obs - y_model
        chi2 += (resid*resid)/(err*err)
    return chi2

# Observed toy data (same values as above, as plain Python lists).
xdata = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
ydata = [0.1, 0.9, 2.2, 2.8, 3.9, 5.1]
sigma = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

# Initial guess for the parameters [a, b, c].
x0    = [0.0, 0.0, 0.0]

# Apply downhill Simplex algorithm and print the best-fit parameters.
# NOTE(review): the original bare `print expr` statement is a syntax
# error on Python 3; the parenthesised form below prints identically
# on Python 2 and works on Python 3.
print(simplex(func, x0, args=(xdata, ydata, sigma), full_output=0))

raw_input( "Press Enter to continue... " )  # Python 2; use input() under Python 3

