import numpy as np
import math
try:
    from . import matlabport
except ImportError:
    import matlabport

def elev_azim(x, y, z):
    """Return ``(elevation, azimuth)`` in degrees for the vector (x, y, z).

    Elevation is the angle of the vector above the x-y plane, i.e.
    ``atan2(z, horizontal_distance)``. Azimuth is ``atan2(y, x)``: the angle
    from the +x axis toward the +y axis, in [-180, 180] degrees.
    """
    horizontal = math.sqrt(x**2 + y**2)
    return (
        math.degrees(math.atan2(z, horizontal)),
        math.degrees(math.atan2(y, x)),
    )

def R1(angle):
    """Return the 3x3 rotation matrix about the first (x) axis.

    ``angle`` is in radians (it is fed to ``np.cos``/``np.sin``). The matrix
    has +sin in row 2 / column 3 and -sin in row 3 / column 2, the sign
    convention of a coordinate-frame (passive) rotation.
    """
    c = np.cos(angle)
    s = np.sin(angle)
    return np.array([
        [1, 0, 0],
        [0, c, s],
        [0, -s, c],
    ])

def R2(angle):
    """Return the 3x3 rotation matrix about the second (y) axis.

    ``angle`` is in radians. The matrix has -sin in row 1 / column 3 and
    +sin in row 3 / column 1, the coordinate-frame (passive) sign convention
    matching ``R1`` and ``R3``.
    """
    c = np.cos(angle)
    s = np.sin(angle)
    return np.array([
        [c, 0, -s],
        [0, 1, 0],
        [s, 0, c],
    ])

def R3(angle):
    """Return the 3x3 rotation matrix about the third (z) axis.

    ``angle`` is in radians. The matrix has +sin in row 1 / column 2 and
    -sin in row 2 / column 1, the coordinate-frame (passive) sign convention
    matching ``R1`` and ``R2``.
    """
    c = np.cos(angle)
    s = np.sin(angle)
    return np.array([
        [c, s, 0],
        [-s, c, 0],
        [0, 0, 1],
    ])

def get_N(epsilon, delta_epsilon, delta_psi):
    """Return the nutation matrix N = R1(-epsilon - delta_epsilon) R3(-delta_psi) R1(epsilon).

    All angles are in radians. The argument names suggest the mean obliquity
    and the nutation in obliquity / longitude (presumably; confirm against the
    angle source).
    """
    inner = R3(-delta_psi) @ R1(epsilon)
    return R1(-epsilon - delta_epsilon) @ inner

def get_P(z, Theta, zeta_0):
    """Return the precession matrix P = R3(-z) R2(Theta) R3(-zeta_0).

    All angles are in radians. The argument names presumably follow the
    usual precession angles (z, theta, zeta); confirm against the angle source.
    """
    inner = R2(Theta) @ R3(-zeta_0)
    return R3(-z) @ inner

def get_S1():
    """Return the 3x3 mirror matrix S1 = diag(-1, 1, 1) as a nested list.

    It flips the sign of the first axis and is applied as the final step of
    ``get_r_g``. A new list object is created on every call.
    """
    mirror = [[int(row == col) for col in range(3)] for row in range(3)]
    mirror[0][0] = -1
    return mirror

def get_r_g(Phi, Lambda, GAST, N, P, r_i0):
    """Rotate the inertial vector ``r_i0`` into the observer's local frame.

    Computes S1 R2(pi/2 - Phi) R3(Lambda) R3(GAST) N P r_i0, applying the
    matrices right to left: precession ``P``, nutation ``N``, Earth rotation
    by ``GAST``, then the observer's longitude ``Lambda`` and colatitude
    ``pi/2 - Phi``, and finally the first-axis mirror S1. All angles are in
    radians.
    """
    vec = P @ r_i0
    vec = N @ vec
    vec = R3(GAST) @ vec
    vec = R3(Lambda) @ vec
    vec = R2(np.pi/2 - Phi) @ vec
    return get_S1() @ vec

def get_r_g_total(Phi, Lambda, GAST, epsilon, delta_epsilon, delta_psi, z, Theta, zeta_0, r_i0):
    """Build the nutation and precession matrices, then rotate ``r_i0`` into the local frame.

    Convenience wrapper: ``epsilon``/``delta_epsilon``/``delta_psi`` go to
    ``get_N``, ``z``/``Theta``/``zeta_0`` go to ``get_P``, and the result is
    passed to ``get_r_g`` together with ``Phi``, ``Lambda`` and ``GAST``.
    All angles are in radians.
    """
    nutation = get_N(epsilon, delta_epsilon, delta_psi)
    precession = get_P(z, Theta, zeta_0)
    return get_r_g(Phi, Lambda, GAST, nutation, precession, r_i0)

def to_local(x_array, y_array, z_array, lat, lon, jd_array, fr_array):
    """
    Convert arrays of x, y and z coordinates in inertial frame to local geodetic coordinates.

    lat and lon describe the position of the local geodetic coordinate system
    (the observer) and are used as radians.
    jd and fr are julian day and fractional part (jd + fr is actual time).

    Each inertial vector is rotated by precession, nutation, Earth rotation
    (GAST) and the observer's longitude/latitude into the local frame (see
    ``get_r_g``), whose third axis points to the observer's zenith. The
    observer's position, expressed in that same frame, is then subtracted so
    the results are observer-centred.

    Returns:
        Five numpy arrays: local x, y, z coordinates, and elevation and
        azimuth in degrees (computed by ``elev_azim``).
    """
    # Spherical Earth (eccentricity 0) with equatorial radius 6378.137,
    # presumably km; input coordinates must then be in km as well.
    semi_major = 6378.137
    ecc = 0
    sin_lat = math.sin(lat)
    n_radius = semi_major / math.sqrt(1 - ((ecc**2) * (sin_lat**2)))
    obs_earth_fixed = np.array([
        n_radius * math.cos(lat) * math.cos(lon),
        n_radius * math.cos(lat) * math.sin(lon),
        n_radius * (1 - (ecc**2)) * sin_lat,
    ])
    # The target vectors returned by get_r_g are already rotated into the
    # local frame by S1 R2(pi/2 - lat) R3(lon). Apply the same rotation to the
    # observer's Earth-fixed position so both vectors share a frame before
    # subtraction. Subtracting the raw Earth-fixed coordinates would mix
    # frames. For this spherical Earth the result is (0, 0, n_radius).
    x_obs_l, y_obs_l, z_obs_l = get_S1() @ (R2(np.pi/2 - lat) @ (R3(lon) @ obs_earth_fixed))

    x_local, y_local, z_local = [], [], []
    elev, azim = [], []
    for x, y, z, jd, fr in zip(x_array, y_array, z_array, jd_array, fr_array):
        # NOTE(review): jd / 36525 is used as the time argument for all
        # matlabport calls; confirm whether they expect centuries since J2000
        # (i.e. (jd - 2451545.0) / 36525) or something else.
        t = jd / 36525

        gast, _ = matlabport.jul2gast(fr * 24, t)
        gast_rad = math.radians(gast * 15)  # gast * 15 converts hours to degrees

        epsilon, delta_epsilon, delta_psi = matlabport.nutwink(t)
        z_prez, Theta, zeta_0 = matlabport.prezwink(t)
        # matlabport returns degrees; the rotation helpers expect radians.
        epsilon = math.radians(epsilon)
        delta_epsilon = math.radians(delta_epsilon)
        delta_psi = math.radians(delta_psi)
        z_prez = math.radians(z_prez)
        Theta = math.radians(Theta)
        zeta_0 = math.radians(zeta_0)

        N = get_N(epsilon, delta_epsilon, delta_psi)
        P = get_P(z_prez, Theta, zeta_0)
        x_l, y_l, z_l = get_r_g(lat, lon, gast_rad, N, P, np.array([x, y, z]))

        # Make the vector observer-centred (both terms are in the local frame).
        x_l -= x_obs_l
        y_l -= y_obs_l
        z_l -= z_obs_l

        x_local.append(x_l)
        y_local.append(y_l)
        z_local.append(z_l)

        elev_deg, azim_deg = elev_azim(x_l, y_l, z_l)
        elev.append(elev_deg)
        azim.append(azim_deg)
    return np.array(x_local), np.array(y_local), np.array(z_local), np.array(elev), np.array(azim)

