"""ICA Example 2: Whitening, scatter plots, and projection histograms.

Python equivalent of matlabex2.m
Requires: numpy, matplotlib, scipy
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm, inv

POINTS = 1000
# Set to an int for a reproducible run; None draws fresh entropy each time.
SEED = None
HIST_BINS = 30
# The MATLAB original scales the whitened data by 2, so the whitened
# covariance is WHITEN_SCALE**2 * I (= 4 * I) rather than the identity.
# Set to 1 for unit-variance whitening.
WHITEN_SCALE = 2


def _plot_with_projections(x, y, title, bins=HIST_BINS):
    """Scatter ``y`` against ``x`` with marginal histograms left and below.

    The left histogram shares the scatter's y-axis and the bottom histogram
    shares its x-axis, so every bar lines up with the scatter coordinates it
    counts.

    Returns:
        (fig, ax_main, ax_left, ax_bot): the figure and its three axes.
    """
    fig = plt.figure(figsize=(6, 6))
    # Leave gaps between panels so the scatter's tick labels and title do not
    # collide with the histogram panels or the figure edge.
    ax_main = fig.add_axes([0.3, 0.3, 0.65, 0.62])
    ax_left = fig.add_axes([0.05, 0.3, 0.17, 0.62], sharey=ax_main)
    ax_bot = fig.add_axes([0.3, 0.05, 0.65, 0.17], sharex=ax_main)

    ax_main.plot(x, y, ".")
    ax_main.set_title(title)

    ax_left.hist(y, bins=bins, orientation="horizontal")
    ax_left.invert_xaxis()  # bars grow leftward, away from the scatter
    ax_bot.hist(x, bins=bins)

    # Hide the shared-axis ticks on the histogram panels only. Shared axes
    # share their tick locator, so set_yticks([]) / set_xticks([]) here would
    # also strip the ticks from the scatter panel.
    ax_left.tick_params(axis="y", left=False, labelleft=False)
    ax_bot.tick_params(axis="x", bottom=False, labelbottom=False)
    return fig, ax_main, ax_left, ax_bot


# Define two uniform random variables (integer values in [-50, 49]).
rng = np.random.default_rng(SEED)
A = np.round(rng.uniform(0, 99, POINTS)) - 50
B = np.round(rng.uniform(0, 99, POINTS)) - 50

fig, ax = plt.subplots()
ax.plot(A, B, ".")
ax.set_xlim(-80, 80)
ax.set_ylim(-80, 80)
ax.set_title("Original sources A and B")
ax.set_aspect("equal")
plt.show()

# Mix linearly
M1 = 0.54 * A - 0.84 * B
M2 = 0.42 * A + 0.27 * B

fig, ax = plt.subplots()
ax.plot(M1, M2, ".")
ax.set_ylim(ax.get_xlim())
ax.set_title("Two linear mixtures")
ax.set_aspect("equal")
plt.show()

# Whiten the data: center each mixture, then multiply by C^(-1/2), where C is
# the mixtures' covariance (np.cov treats each row of x as one variable).
x = np.vstack([M1, M2])
c = np.cov(x)
sq = inv(sqrtm(c))
mx = x.mean(axis=1, keepdims=True)
xx = WHITEN_SCALE * sq @ (x - mx)
print(
    f"Covariance after whitening (expected {WHITEN_SCALE ** 2} * identity):\n",
    np.cov(xx).round(3),
)

fig, ax = plt.subplots()
ax.plot(xx[0], xx[1], ".")
ax.set_title("Whitened data")
ax.set_aspect("equal")
plt.show()

# Projections of whitened data (Gaussian-looking)
_plot_with_projections(xx[0], xx[1], "Whitened data with projections")
plt.show()

# Projections of original sources (non-Gaussian)
_plot_with_projections(A, B, "Original sources with projections")
plt.show()
