Major corrections to gradient descent code and associated plotting.

This commit is contained in:
Karthik 2025-04-28 11:46:54 +02:00
parent cf5d3fd209
commit b47d81e68d
2 changed files with 101 additions and 34 deletions

View File

@ -15,7 +15,7 @@ function plotLive(psi,Params,Transf,Observ,SimulationMode)
n = abs(psi).^2; n = abs(psi).^2;
switch SimulationMode switch SimulationMode
case {'ImaginaryTimeEvolution', 'EnergyMinimization'} case 'ImaginaryTimeEvolution'
%Plotting %Plotting
fig = figure(1); fig = figure(1);
fig.WindowState = 'maximized'; fig.WindowState = 'maximized';
@ -76,6 +76,68 @@ function plotLive(psi,Params,Transf,Observ,SimulationMode)
ylabel('$\mu$', 'FontSize', 14); xlabel('Time steps', 'FontSize', 14); ylabel('$\mu$', 'FontSize', 14); xlabel('Time steps', 'FontSize', 14);
title('Chemical Potential', 'FontSize', 14); title('Chemical Potential', 'FontSize', 14);
grid on grid on
case 'EnergyMinimization'
%Plotting
fig = figure(1);
fig.WindowState = 'maximized';
clf
% set(gcf,'Position', [100, 100, 1600, 900])
t = tiledlayout(2, 3, 'TileSpacing', 'compact', 'Padding', 'compact'); % 2x3 grid
nexttile;
nxz = squeeze(trapz(n*dy,2));
nyz = squeeze(trapz(n*dx,1));
nxy = squeeze(trapz(n*dz,3));
plotxz = pcolor(x,z,nxz');
set(plotxz, 'EdgeColor', 'none');
cbar1 = colorbar;
cbar1.Label.Interpreter = 'latex';
% cbar1.Ticks = []; % Disable the ticks
colormap(gca, Helper.Colormaps.plasma())
xlabel('$x$ ($\mu$m)', 'Interpreter', 'latex', 'FontSize', 14)
ylabel('$z$ ($\mu$m)', 'Interpreter', 'latex', 'FontSize', 14)
title('$|\Psi(x,z)|^2$', 'Interpreter', 'latex', 'FontSize', 14)
nexttile;
plotyz = pcolor(y,z,nyz');
set(plotyz, 'EdgeColor', 'none');
cbar1 = colorbar;
cbar1.Label.Interpreter = 'latex';
% cbar1.Ticks = []; % Disable the ticks
colormap(gca, Helper.Colormaps.plasma())
xlabel('$y$ ($\mu$m)', 'Interpreter', 'latex', 'FontSize', 14)
ylabel('$z$ ($\mu$m)', 'Interpreter', 'latex', 'FontSize', 14)
title('$|\Psi(y,z)|^2$', 'Interpreter', 'latex', 'FontSize', 14)
nexttile;
plotxy = pcolor(x,y,nxy');
set(plotxy, 'EdgeColor', 'none');
cbar1 = colorbar;
cbar1.Label.Interpreter = 'latex';
% cbar1.Ticks = []; % Disable the ticks
colormap(gca, Helper.Colormaps.plasma())
xlabel('$x$ ($\mu$m)', 'Interpreter', 'latex', 'FontSize', 14)
ylabel('$y$ ($\mu$m)', 'Interpreter', 'latex', 'FontSize', 14)
title('$|\Psi(x,y)|^2$', 'Interpreter', 'latex', 'FontSize', 14)
nexttile;
plot(Observ.theta,'-b')
ylabel('$\theta_{opt}$', 'FontSize', 14); xlabel('Time steps', 'FontSize', 14);
title('Optimal angle', 'FontSize', 14);
grid on
nexttile;
plot(Observ.EVec,'-b')
ylabel('$E_{tot}$', 'FontSize', 14); xlabel('Time steps', 'FontSize', 14);
title('Total Energy', 'FontSize', 14);
grid on
nexttile;
plot(Observ.mucVec,'-b')
ylabel('$\mu$', 'FontSize', 14); xlabel('Time steps', 'FontSize', 14);
title('Chemical Potential', 'FontSize', 14);
grid on
case 'RealTimeEvolution' case 'RealTimeEvolution'

View File

@ -94,6 +94,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
i = 1; i = 1;
Observ.residual = 1; Observ.residual = 1;
Observ.res = 1; Observ.res = 1;
Observ.theta = 0;
% Initialize the PrematureExitFlag to false % Initialize the PrematureExitFlag to false
PrematureExitFlag = false; PrematureExitFlag = false;
@ -109,24 +110,23 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Compute gradient % Compute gradient
J = compute_gradient(psi, Params, Transf, VDk, V); J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential % Calculate chemical potential
muchem = real(psi(:)' * J(:)) / sum(abs(psi(:)).^2); muchem = real(inner_product(psi, J, Transf)) / inner_product(psi, psi, Transf);
% Calculate residual % Calculate residual
residual = J - (muchem * psi); residual = J - (muchem * psi);
% Compute energy difference between the last two saved energy values % Energy convergence check every 100 steps
if i == 1 if mod(i,100) == 0 && length(Observ.EVec) > 1
energydifference = NaN;
elseif mod(i,100) == 0 && length(Observ.EVec) > 1
energydifference = abs(Observ.EVec(end) - Observ.EVec(end-1)); energydifference = abs(Observ.EVec(end) - Observ.EVec(end-1));
if energydifference <= epsilon
disp('Tolerance reached: Energy difference is below the specified epsilon.');
PrematureExitFlag = true;
break;
end
end end
% Convergence check - if energy difference is below set tolerance, then exit
if energydifference <= epsilon if i >= this.MaxIterationsForGD
disp('Tolerance reached: Energy difference is below the specified epsilon.');
PrematureExitFlag = true; % Set flag to indicate premature exit
break;
elseif i >= this.MaxIterationsForGD % If set maximum number of iterations reached, then exit
disp('Maximum number of iterations for CGD reached.'); disp('Maximum number of iterations for CGD reached.');
PrematureExitFlag = true; % Set flag to indicate premature exit PrematureExitFlag = true;
break; break;
end end
@ -135,33 +135,34 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
d = -residual; d = -residual;
else % Compute beta via Polak-Ribiere and create new direction else % Compute beta via Polak-Ribiere and create new direction
residual_new = residual; residual_new = residual;
beta = compute_beta(residual_new, residual_old); beta = compute_beta(residual_new, residual_old, Transf);
d = -residual_new + beta * p_old; d = -residual_new + beta * p_old;
end end
residual_old = residual; residual_old = residual;
p = d - (real(d(:)' * psi(:) * Transf.dx * Transf.dy * Transf.dz) .* psi); proj = inner_product(d, psi, Transf);
p = d - (proj * psi);
p_old = p; p_old = p;
% Compute optimal theta to generate psi in the direction of minimum in the energy landscape % Compute optimal theta to generate psi in the direction of minimum in the energy landscape
theta = compute_optimal_theta(p, muchem, psi, Params, Transf, VDk, V); theta = compute_optimal_theta(p, muchem, psi, Params, Transf, VDk, V);
% Update solution % Update solution
gamma = 1 / sqrt(sum(abs(p(:)).^2) * Transf.dx * Transf.dy * Transf.dz); gamma = 1 / sqrt(inner_product(p, p, Transf));
psi = (cos(theta).*psi) + (gamma*sin(theta).*p); psi = (cos(theta).*psi) + (sin(theta).*(p*gamma));
% Normalize psi % Normalize psi
Norm = sum(abs(psi(:)).^2) * Transf.dx * Transf.dy * Transf.dz; psi = sqrt(Params.N) * psi / sqrt(abs(inner_product(psi, psi, Transf)));
psi = sqrt(Params.N) * psi / sqrt(Norm);
i = i + 1; i = i + 1;
% Calculate chemical potential with new psi % Calculate chemical potential with new psi
muchem = real(psi(:)' * J(:)) / sum(abs(psi(:)).^2); J = compute_gradient(psi, Params, Transf, VDk, V);
muchem = real(inner_product(psi, J, Transf)) / inner_product(psi, psi, Transf);
if mod(i,100) == 0 if mod(i,100) == 0
% Collect Energy value % Collect Energy value
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V); E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / inner_product(psi, psi, Transf);
E = E/Norm;
Observ.EVec = [Observ.EVec E]; Observ.EVec = [Observ.EVec E];
% Collect Chemical potential value % Collect Chemical potential value
@ -171,13 +172,15 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
res = this.Calculator.calculateNormalizedResiduals(psi,Params,Transf,VDk,V,muchem); res = this.Calculator.calculateNormalizedResiduals(psi,Params,Transf,VDk,V,muchem);
Observ.residual = [Observ.residual res]; Observ.residual = [Observ.residual res];
Observ.res_idx = Observ.res_idx + 1; Observ.res_idx = Observ.res_idx + 1;
% Collect calculated thetas
Observ.theta = [Observ.theta theta];
% Live plotter % Live plotter
if this.PlotLive if this.PlotLive
Plotter.plotLive(psi,Params,Transf,Observ,this.SimulationMode) Plotter.plotLive(psi,Params,Transf,Observ,this.SimulationMode)
drawnow drawnow
end end
save(sprintf(strcat(this.SaveDirectory, '/Run_%03i/psi_gs.mat'),Params.njob),'psi','muchem','Observ','Transf','Params','VDk','V'); save(sprintf(strcat(this.SaveDirectory, '/Run_%03i/psi_gs.mat'),Params.njob),'psi','muchem','Observ','Transf','Params','VDk','V');
end end
end end
@ -190,13 +193,11 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
end end
% Change in Energy % Change in Energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V); E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / inner_product(psi, psi, Transf);
E = E/Norm;
Observ.EVec = [Observ.EVec E]; Observ.EVec = [Observ.EVec E];
disp('Saving data...');
save(sprintf(strcat(this.SaveDirectory, '/Run_%03i/psi_gs.mat'),Params.njob),'psi','muchem','Observ','Transf','Params','VDk','V'); save(sprintf(strcat(this.SaveDirectory, '/Run_%03i/psi_gs.mat'),Params.njob),'psi','muchem','Observ','Transf','Params','VDk','V');
disp('Save complete!'); disp('Completed and saved!');
end end
end end
@ -261,15 +262,19 @@ function theta = compute_optimal_theta(p, muchem, psi, Params, Transf, VDk, V)
Hpsi = compute_gradient(psi, Params, Transf, VDk, V); Hpsi = compute_gradient(psi, Params, Transf, VDk, V);
Hp = compute_gradient(p, Params, Transf, VDk, V); Hp = compute_gradient(p, Params, Transf, VDk, V);
g = compute_g(psi, p, Params, VDk); g = compute_g(psi, p, Params, VDk);
gamma = 1 / sqrt(sum(abs(p(:)).^2) * Transf.dx * Transf.dy * Transf.dz); gamma = 1 / sqrt(inner_product(p, p, Transf));
numerator = gamma * real(p(:)' * Hpsi(:) * Transf.dx * Transf.dy * Transf.dz); numerator = gamma * real(inner_product(p, Hpsi, Transf));
denominator = muchem - (gamma^2 * (real(p(:)' * Hp(:) * Transf.dx * Transf.dy * Transf.dz) + real(g(:)' * p(:) * Transf.dx * Transf.dy * Transf.dz))); denominator = muchem - (gamma^2 * (real(inner_product(p, Hp, Transf)) + real(inner_product(g, p, Transf))));
theta = numerator / denominator; theta = numerator / denominator;
end end
% Optimal step size via Polak-Ribiere % Optimal step size via Polak-Ribiere
function beta = compute_beta(residual_new, residual_old) function beta = compute_beta(residual_new, residual_old, Transf)
beta = (residual_new(:)' * (residual_new(:) - residual_old(:))) / sum(abs(residual_old(:)).^2); beta = inner_product(residual_new, (residual_new - residual_old), Transf) / inner_product(residual_old, residual_old, Transf);
beta = max(0,beta); beta = max(0, real(beta));
end end
function s = inner_product(u, v, Transf)
s = sum(conj(u(:)) .* v(:)) * Transf.dx * Transf.dy * Transf.dz;
end