import matplotlib.pyplot as plt
from scipy import stats
import numpy as np
#optional theme for jupyter plots
#from jupyterthemes import jtplot
#jtplot.style(theme='monokai')
#jtplot.reset()
#Como no tengo datos reales me invento una distribucion normal para simular un cluster con 150 estrellas
mean, cov, size = ([0.0, 0.0]), ([0.1, 0.1], [0.1, 0.1]), 150
x = np.random.multivariate_normal(mean, cov, size)[:,0]
y = np.random.multivariate_normal(mean, cov, size)[:,1]
#Agrego un background uniforme de 5000 estrellas
x = np.concatenate((x, np.random.uniform(low=-10, high=5, size=5000)))
y = np.concatenate((y, np.random.uniform(low=-10, high=5, size=5000)))
Como no quiero enredarme con plt.hist2d que devuelve también un plot, uso la libreria de scipy.stats que hace lo mismo y se llama 'binned_statistic_2d'
Aparte que tiene mas funciones que solo contar, que es el hist2d. Se puede seleccionar una statistic cualquiera de las predefinidas o colocarle una funcion de numpy tambien.
bx , by = 11, 11 #selecciono el numero de bins por eje IMPORTANTE: tiene que se impar para que todo funcione :D
##### Por el momento solo funciona con una grilla cuadrada. Estoy viendo por que no funciona si es irregular
stat, x_edges, y_edges, binnumber = stats.binned_statistic_2d(x, y, values='None', statistic='count', bins=[bx,by], expand_binnumbers=True)
dx = x_edges[1]-x_edges[0]
dy = y_edges[1]-y_edges[0]
#Ploteo como se ven los puntos
plt.scatter(x,y, s=1)
plt.xlabel('x', size=17)
plt.ylabel('y', size=17)
plt.show()
Una cosa tecnica antes del siguiente paso:
El paso importante es colocar ${\tt expand-binnumbers=True}$
binnumbers es un arreglo de (2,N) dimenisones que te dice en que parte de la grilla esta cada punto que has ingresado
Si no es ${\tt True}$, retornara el mismo valor pero sumado, y aun no se como descifrarlo y volverlo a la grilla :(
#Ploteo ahora como se ve cada bin
#binned_statistic tiene la gracia de dar un id para cada punto usado en binnnumber
#El plot es bien feo, pero funciona para separar los bines unos de otros
compressed_binnumber = np.sum(binnumber, axis=0) #esto suma el valor de ambos ejes: si una estrella esta en la grilla (5,4) retornara 9
for i, bines in enumerate(set(compressed_binnumber)):
plt.scatter(x[compressed_binnumber == bines],y[compressed_binnumber == bines], s=9)
for i, edges in enumerate(x_edges):
plt.vlines(x_edges[i], np.min(y), np.max(y), color='k', alpha=0.4)
for i, edges in enumerate(y_edges):
plt.hlines(y_edges[i], np.min(x), np.max(x), color='k', alpha=0.4)
plt.show()
#Aqui la magia de argmax y unravel_index que retornan de la grilla (dado por stat.shape) el maximo
argmax = np.unravel_index(np.argmax(stat), stat.shape)
argmax
#Solo seleccionamos las estrellas que por cada coordenada (x,y) tengan el mismo valor que el argmax
#Se le suma +1 porque Python cuenta a veces de 0 y a veces de 1. No discutire por algo asi, pero esta feo.
f_max_bin = (binnumber[0,:] == argmax[0]+1) & (binnumber[1,:] == argmax[1]+1)
#Filtramos por la condicion anterior para tener solo las estrellas del maximo
x_max = x[f_max_bin]
y_max = y[f_max_bin]
for i, bines in enumerate(set(compressed_binnumber)):
plt.scatter(x[compressed_binnumber == bines],y[compressed_binnumber == bines], s=9)
for i, edges in enumerate(x_edges):
plt.vlines(x_edges[i], np.min(y), np.max(y), color='r', alpha=0.4)
for i, edges in enumerate(y_edges):
plt.hlines(y_edges[i], np.min(x), np.max(x), color='r', alpha=0.4)
plt.scatter(x_max, y_max, color='k', s=9)
plt.show()
#Comparamos con lo que daria el histograma2d
#Cuidado con imshow que normalmente tiene el inicio de los ejes en upper, lo que hace una rotacion no deseada, origin='lower' lo arregla
plt.imshow(stat, extent=[np.min(x), np.max(x), np.min(y), np.max(y)], aspect='auto', origin='lower')
plt.show()
print('Nr. estrellas en max_bin: %d' %np.max(stat))
print('Nr. estrellas en la seleccion.x %d' %x_max.size)
print('Nr. estrellas en la seleccion.y %d' %y_max.size)
#Y ahora los dos juntos
plt.imshow(stat, extent=[np.min(x), np.max(x), np.min(y), np.max(y)], aspect='auto', origin='lower')
plt.scatter(x_max, y_max, color='k', s=9)
plt.show()
#Corroboramos con la matriz stat
stat