%% Define a beam and visualize it propagating
% Define the "optical system" using ABCD ray-transfer matrices. For this first example we
% use only a free-space propagation matrix with a symbolic length d, and let the beam
% propagate that distance.

w0  = 1E-3;         % 1mm beam waist
lam = 421E-9;       % wavelength in m
zR  = Zr(w0, lam);  % Rayleigh range in m (for reference; not used below)
z0  = 0;            % location of waist in m (for reference; the call below passes 0 directly)

syms d;             % Define 'd' as a symbolic variable (propagation distance in m)
M = propagator(d); 

% Do all the ABCD and q-parameter math. R (wavefront radius of curvature) and w (beam
% radius) come back as symbolic expressions in d. The first argument 0 places the input
% plane at the beam waist.
[R, w] = q1_inv_func(0, w0, lam, M);

% Plot the beam envelope +/- w(d) for d from 0 to 10 m (linspace default: 100 points)
plotResult(w, d, linspace(0,10))

%% Add a lens to the system
% The beam starts at its waist, propagates 1 m, passes a thin lens with f = 0.5 m, and
% then propagates a symbolic distance d.

w0  = 1E-3;         % 1mm beam waist
lam = 421E-9;       % wavelength in m
zR  = Zr(w0, lam);  % Rayleigh range in m (for reference; not used below)
z0  = 0;            % location of waist in m (for reference; the call below passes 0 directly)

syms d;             % Define 'd' as a symbolic variable (distance after the lens, in m)
% Matrices are listed last-to-first: propagator(1) acts first, then lens(.5), then propagator(d)
M = matmultiplier(propagator(d), lens(.5), propagator(1));
            
[R, w] = q1_inv_func(0, w0, lam, M);

% Plot the beam envelope for d from 0 to 1 m after the lens (50 points)
plotResult(w, d, linspace(0,1,50))

%% Add another lens to make a beam expander that enlarges the beam waist by a factor of 5 and then collimates it
% Layout (in the order the light sees it): propagate d1, lens f1, propagate d2, lens f2,
% propagate d3.
% Step 1: choose f1 so the beam radius at the second lens is expansion_factor*w0.
% Step 2: for each f1 candidate, choose f2 so the wavefront right after the second lens
%         is flat (1/R = 0), i.e. the beam is collimated.

w0  = 1E-3;         % 1mm beam waist
lam = 421E-9;       % wavelength 
zR  = Zr(w0, lam);  % Rayleigh range in m
z0  = 0;            % location of waist in m
expansion_factor = 5;

syms d1 d2 d3 f1 f2

M = matmultiplier(propagator(d3), lens(f2), propagator(d2), lens(f1), propagator(d1));

[R, w] = q1_inv_func(0, w0, lam, M);

% Beam radius at the second lens: 1 m to the first lens (d1 = 1), evaluated at d3 = 0
w      = subs(w, {d1, d3}, {1, 0});

% Solve for f1 giving the requested expansion, then fix the lens separation d2 = 1 m
f1_eq = solve(w - expansion_factor*w0 == 0, f1);
f1_eq_subs          = subs(f1_eq, d2, 1);
% Convert solutions to numeric values
f1_val_numeric      = double(f1_eq_subs);
fprintf('f1 = %.2f m or %.2f m, for a lens separation of 1 meter\n', f1_val_numeric(1), f1_val_numeric(2))

% Find the collimating f2 for every f1 candidate. The substitution goes into a separate
% variable on purpose: overwriting R would remove f1 from it, so every later iteration
% would silently reuse the first f1 and report an f2 that does not match its f1.
f2_vals = cell(size(f1_val_numeric));
for k = 1:numel(f1_val_numeric)
    f1_val     = f1_val_numeric(k);
    R_k        = subs(R, {d1, d2, d3, f1}, {1, 1, 0, f1_val});
    f2_vals{k} = solve(1/R_k == 0, f2);
    fprintf('f2 = %.2f, for a collimated beam, 5x the original waist, after propagating 1m to the first lens of f1 = %.2f, and propagating another 1m to the second lens\n', double(f2_vals{k}), double(f1_val))
end

% Build the final system from an (f1, f2) pair taken from the SAME candidate index
chosen_idx    = 1;
chosen_f1_val = f1_val_numeric(chosen_idx);
f2_val        = f2_vals{chosen_idx};

fprintf('Chosen f1 = %.2f m \n', chosen_f1_val)

M = matmultiplier(propagator(d3), lens(f2_val), propagator(1), lens(chosen_f1_val), propagator(1));

[R, w] = q1_inv_func(0, w0, lam, M);

plotResult(w,d3)

% Ratio of the beam radius right at the second lens (d3 = 0) to the input waist
estimated_expansion_factor = subs(w, d3, 0)/ w0;
fprintf('beam is w = %.2f x w0\n', double(estimated_expansion_factor))

% Relative change of the beam radius over 10 m after the second lens (small if well collimated)
beam_size_change = ((subs(w,d3,10) - subs(w,d3,0)) / subs(w,d3,0)) * 100;
fprintf('Over 10 m after second lens, beam changes by %.5f percent\n', double(beam_size_change)) 

%% Function definitions

% Input ray vector: height and angle of a single ray
function ret = ray(y, theta)
    % Stack a ray's height and angle into the column vector used with ABCD matrices.
    %
    % Parameters
    % ----------
    % y : float or integer in meters
    %     Height of the ray (vertical offset from the optical axis).
    % theta : float or integer in radians
    %     Angle of the ray (its divergence angle).
    %
    % Returns
    % -------
    % ret : 2x1 column vector
    %     [y; theta]

    ret = vertcat(y, theta);
end

% Ray transfer (ABCD) matrix of an ideal thin lens with focal length f
function ret = lens(f)
    % Build the thin-lens ABCD matrix.
    %
    % Parameters
    % ----------
    % f : float, integer or sym, in meters
    %     Focal length of the thin lens.
    %
    % Returns
    % -------
    % ret : 2x2 matrix
    %     [1,    0;
    %      -1/f, 1]

    optical_power = 1/f;   % refractive power; enters the C entry with a minus sign
    ret = [1, 0; -optical_power, 1];
end

% Ray transfer (ABCD) matrix for free-space propagation over a distance d
function ret = propagator(d)
    % Build the free-space propagation ABCD matrix.
    %
    % Parameters
    % ----------
    % d : float, integer or sym
    %     Distance the light travels along the z-axis.
    %
    % Returns
    % -------
    % ret : 2x2 matrix
    %     [1, d;
    %      0, 1]

    first_col  = [1; 0];
    second_col = [d; 1];
    ret = [first_col, second_col];
end

% Multiply ABCD matrices together. mat is the last matrix the light interacts with.
function ret = matmultiplier(mat, varargin)
    % Combine a chain of ABCD matrices into one system matrix.
    %
    % The result is mat * varargin{1} * varargin{2} * ..., evaluated left to right.
    % A ray vector is multiplied from the right, so the rightmost argument acts first.
    %
    % Parameters
    % ----------
    % mat : 2x2 ABCD matrix
    %     Last matrix the light interacts with.
    % varargin : 2x2 ABCD matrices
    %     Remaining elements, listed from the latest to the earliest in the order the
    %     light meets them.
    %
    % Returns
    % -------
    % ret : 2x2 matrix
    %     The ABCD matrix describing the whole optical system.

    ret = mat;
    % Iterating over a 1xN cell yields one 1x1 cell per column; a 1x0 cell yields none
    for stage = varargin
        ret = ret * stage{1};
    end
end

% Gaussian beam parameters: Rayleigh range from waist and wavelength
function zr = Zr(wo, lam)
    % Compute the Rayleigh range zr = pi * wo^2 / lam.
    %
    % Element-wise operators are used so that arrays of waists and/or wavelengths
    % (same size, or compatible for implicit expansion) give an array of Rayleigh
    % ranges. Scalar inputs behave exactly as before.
    %
    % Parameters
    % ----------
    % wo : float, integer, sym, or array
    %     Beam waist radius in meters.
    % lam : float, integer, sym, or array
    %     Wavelength of light in meters.
    %
    % Returns
    % -------
    % zr : float, sym, or array
    %     Rayleigh range for the given beam waist and wavelength, in meters.

    zr = pi * wo.^2 ./ lam;
end

% Calculate beam waist radius from Rayleigh range and wavelength
function w0 = W0(zr, lam)
    % Compute the beam waist radius w0 = sqrt(lam * zr / pi), the inverse of Zr.
    %
    % Element-wise operators are used so that arrays of Rayleigh ranges and/or
    % wavelengths (same size, or compatible for implicit expansion) are accepted.
    % Scalar inputs behave exactly as before.
    %
    % Parameters
    % ----------
    % zr : float, integer, sym, or array
    %     Rayleigh range in meters.
    % lam : float, integer, sym, or array
    %     Wavelength of light in meters.
    %
    % Returns
    % -------
    % w0 : float, sym, or array
    %     Beam waist radius in meters.

    w0 = sqrt(lam .* zr ./ pi);
end

function [z_out, zr_out] = q1_func(z, w0, lam, mat)
    % Propagate the complex beam parameter q through an ABCD system.
    %
    % The two outputs are the real and imaginary parts of
    %     q_out = (A*q_in + B) / (C*q_in + D),   with   q_in = z + 1i*zr,
    % written out explicitly so they also work when the matrix entries are symbolic.
    %
    % Parameters
    % ----------
    % z : float, integer or sym
    %     Axial offset of the input plane from the beam waist, in meters
    %     (q_in = z + 1i*zr, so z > 0 means the waist lies before the input plane).
    % w0 : float or integer
    %     Radial waist size in meters (of the embedded Gaussian, i.e. W0/M).
    % lam : float or integer
    %     Wavelength of light in meters.
    % mat : 2x2 matrix (numeric or symbolic)
    %     The ABCD matrix describing the optical system.
    %
    % Returns
    % -------
    % z_out : float or sym
    %     real(q_out): offset of the output plane from the output waist, in meters.
    % zr_out : float or sym
    %     imag(q_out): Rayleigh range of the beam after the optical system.

    % Unpack the ABCD entries
    A = mat(1, 1);  B = mat(1, 2);
    C = mat(2, 1);  D = mat(2, 2);

    zr = Zr(w0, lam);

    % |q_in|^2 and |C*q_in + D|^2 are shared by the real and imaginary parts
    q_abs2 = z^2 + zr^2;
    denom  = C^2 * q_abs2 + 2 * C * D * z + D^2;

    z_out  = (A * C * q_abs2 + z * (A * D + B * C) + B * D) / denom;
    zr_out = (zr * (A * D - B * C)) / denom;
end

function [R, w] = q1_inv_func(z, w0, lam, mat)
    % Wavefront curvature and beam radius after an ABCD system, via 1/q.
    %
    % Computes the real and imaginary parts of
    %     1/q_out = (C*q_in + D) / (A*q_in + B),   with   q_in = z + 1i*zr,
    % then reads off R and w from 1/q_out = 1/R - 1i*lam/(pi*w^2).
    % Everything is written out explicitly, so symbolic matrix entries give symbolic
    % R and w.
    %
    % Parameters
    % ----------
    % z : float, integer or sym
    %     Axial offset of the input plane from the beam waist, in meters
    %     (z = 0 means the input plane is at the waist).
    % w0 : float or integer
    %     Radial waist size in meters (of the embedded Gaussian, i.e. W0/M).
    % lam : float or integer
    %     Wavelength of light in meters.
    % mat : 2x2 matrix (numeric or symbolic)
    %     The ABCD matrix describing the optical system.
    %
    % Returns
    % -------
    % R : float or sym
    %     Radius of curvature of the wavefront in meters (Inf for a flat wavefront
    %     when the real part is exactly 0 in double arithmetic).
    % w : float or sym
    %     Radius of the beam in meters.

    % Unpack the ABCD entries
    A = mat(1, 1);  B = mat(1, 2);
    C = mat(2, 1);  D = mat(2, 2);

    zr = Zr(w0, lam);

    % |q_in|^2 and |A*q_in + B|^2 are shared by the real and imaginary parts
    q_abs2 = z^2 + zr^2;
    denom  = A^2 * q_abs2 + 2 * A * B * z + B^2;

    inv_q_re = (A * C * q_abs2 + z * (A * D + B * C) + B * D) / denom;
    inv_q_im = -zr * (A * D - B * C) / denom;

    R = 1/inv_q_re;
    w = sqrt(-lam / inv_q_im / pi);
end

function plotResult(func, var, rang)
    % Plot a symbolic beam-radius expression as a filled +/- envelope, in mm.
    %
    % Parameters
    % ----------
    % func : sym expression (or symfun) of one variable
    %     Beam radius in meters after the last optical element, as a function of var.
    % var : symbolic variable
    %     Variable in func that will be plotted along the horizontal axis.
    % rang : numeric vector (optional)
    %     Values of var along the optical axis to plot, in meters. Row or column
    %     vectors are accepted. Default range is 0 to 3 with step size 0.01.

    % Set default range if not provided
    if nargin < 3
        rang = 0:0.01:3;
    end
    % Force a row vector: the polygon built for fill() concatenates rang horizontally,
    % which turns a column vector into an Nx2 matrix instead of one outline
    rang = reshape(rang, 1, []);

    % Create a numeric function from the symbolic one
    func_handle = matlabFunction(func, 'Vars', var);

    % Evaluate the beam radius (in mm) and its mirror image
    y_vals = func_handle(rang) * 1E3;
    % An expression that does not depend on var evaluates to a single value;
    % expand it so it lines up with rang
    if isscalar(y_vals)
        y_vals = repmat(y_vals, size(rang));
    end
    y_neg_vals = -y_vals;

    % Create a figure
    figure;
    set(gcf, 'Position', [100 100 950 750]);

    % Plot the function and its negative
    plot(rang, y_vals, 'b'); hold on;
    plot(rang, y_neg_vals, 'b');

    % Fill the area between the curves with translucent blue color
    fill([rang, fliplr(rang)], [y_vals, fliplr(y_neg_vals)], 'b', 'FaceAlpha', 0.2, 'EdgeColor', 'none');

    % Style the axes AFTER plotting. With hold off, the first plot() call resets the
    % axes properties (NextPlot 'replace'), so settings applied before it were discarded.
    set(gca, 'FontSize', 16, 'Box', 'On', 'Linewidth', 2);

    % Add grid and labels
    grid on;
    xlabel('Optical Axis (m)', 'FontSize', 16);
    ylabel('Beam size (mm)', 'FontSize', 16);
end