MAJOR corrections to gradient descent code

This commit is contained in:
Karthik 2025-04-25 15:32:58 +02:00
parent 4918ee1ec0
commit eabf2c1a75
2 changed files with 28 additions and 22 deletions

View File

@ -15,11 +15,12 @@ function plotLive(psi,Params,Transf,Observ,SimulationMode)
n = abs(psi).^2; n = abs(psi).^2;
switch SimulationMode switch SimulationMode
case 'ImaginaryTimeEvolution' case {'ImaginaryTimeEvolution', 'EnergyMinimization'}
%Plotting %Plotting
figure(1); fig = figure(1);
fig.WindowState = 'maximized';
clf clf
set(gcf,'Position', [100, 100, 1600, 900]) % set(gcf,'Position', [100, 100, 1600, 900])
t = tiledlayout(2, 3, 'TileSpacing', 'compact', 'Padding', 'compact'); % 2x3 grid t = tiledlayout(2, 3, 'TileSpacing', 'compact', 'Padding', 'compact'); % 2x3 grid
nexttile; nexttile;
@ -79,9 +80,10 @@ function plotLive(psi,Params,Transf,Observ,SimulationMode)
case 'RealTimeEvolution' case 'RealTimeEvolution'
%Plotting %Plotting
figure(1); fig = figure(1);
fig.WindowState = 'maximized';
clf clf
set(gcf,'Position', [100, 100, 1600, 900]) % set(gcf,'Position', [100, 100, 1600, 900])
t = tiledlayout(2, 3, 'TileSpacing', 'compact', 'Padding', 'compact'); % 2x3 grid t = tiledlayout(2, 3, 'TileSpacing', 'compact', 'Padding', 'compact'); % 2x3 grid
nexttile; nexttile;

View File

@ -14,7 +14,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Live Plotter % Live Plotter
if this.PlotLive if this.PlotLive
Plotter.plotLive(psi,Params,Transf,Observ) Plotter.plotLive(psi,Params,Transf,Observ,this.SimulationMode)
drawnow drawnow
end end
@ -24,11 +24,12 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Compute gradient % Compute gradient
J = compute_gradient(psi, Params, Transf, VDk, V); J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential and norm % Calculate chemical potential
muchem = sum(real(conj(psi(:)) .* J(:))) / sum(abs(psi(:)).^2); muchem = sum(real(psi(:)' .* J(:))) / sum(abs(psi(:)).^2);
% Calculate residual and check convergence % Calculate residual and check convergence
residual = sum(abs(J(:) - (muchem * psi(:))).^2) * Transf.dx * Transf.dy * Transf.dz; residual = sum(abs(J(:) - (muchem * psi(:))).^2) * Transf.dx * Transf.dy * Transf.dz;
if residual < epsilon if residual < epsilon
fprintf('Convergence reached at iteration %d\n', idx); fprintf('Convergence reached at iteration %d\n', idx);
break; break;
@ -58,7 +59,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
Observ.res_idx = Observ.res_idx + 1; Observ.res_idx = Observ.res_idx + 1;
if this.PlotLive if this.PlotLive
Plotter.plotLive(psi,Params,Transf,Observ) Plotter.plotLive(psi,Params,Transf,Observ,this.SimulationMode)
drawnow drawnow
end end
@ -87,7 +88,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
case 'NonLinearCGD' case 'NonLinearCGD'
% Convergence Criteria: % Convergence Criteria:
epsilon = 1E-13; epsilon = 1E-14;
% Iteration Counter: % Iteration Counter:
i = 1; i = 1;
@ -99,7 +100,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Live plotter % Live plotter
if this.PlotLive if this.PlotLive
Plotter.plotLive(psi,Params,Transf,Observ) Plotter.plotLive(psi,Params,Transf,Observ,this.SimulationMode)
drawnow drawnow
end end
@ -108,7 +109,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Compute gradient % Compute gradient
J = compute_gradient(psi, Params, Transf, VDk, V); J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential % Calculate chemical potential
muchem = real(conj(psi(:))' * J(:)) / norm(psi(:))^2; muchem = real(psi(:)' * J(:)) / sum(abs(psi(:)).^2);
% Calculate residual % Calculate residual
residual = J - (muchem * psi); residual = J - (muchem * psi);
@ -139,21 +140,22 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
end end
residual_old = residual; residual_old = residual;
p = d - (real(conj(d(:))' * psi(:)) .* psi); p = d - (real(d(:)' * psi(:) * Transf.dx * Transf.dy * Transf.dz) .* psi);
p_old = p; p_old = p;
% Compute optimal theta to generate psi in the direction of minimum in the energy landscape % Compute optimal theta to generate psi in the direction of minimum in the energy landscape
theta = compute_optimal_theta(p, muchem, psi, Params, Transf, VDk, V); theta = compute_optimal_theta(p, muchem, psi, Params, Transf, VDk, V);
% Update solution % Update solution
psi = (cos(theta).*psi) + (sin(theta).*(p / norm(p(:)))); gamma = 1 / sqrt(sum(abs(p(:)).^2));
psi = (cos(theta).*psi) + (gamma*sin(theta).*p);
% Normalize psi % Normalize psi
Norm = sum(abs(psi(:)).^2) * Transf.dx * Transf.dy * Transf.dz; Norm = sum(abs(psi(:)).^2) * Transf.dx * Transf.dy * Transf.dz;
psi = sqrt(Params.N) * psi / sqrt(Norm); psi = sqrt(Params.N) * psi / sqrt(Norm);
i = i + 1; i = i + 1;
% Calculate chemical potential with new psi % Calculate chemical potential with new psi
muchem = real(conj(psi(:))' * J(:)) / norm(psi(:))^2; muchem = real(psi(:)' * J(:)) / sum(abs(psi(:)).^2);
if mod(i,100) == 0 if mod(i,100) == 0
@ -172,7 +174,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Live plotter % Live plotter
if this.PlotLive if this.PlotLive
Plotter.plotLive(psi,Params,Transf,Observ) Plotter.plotLive(psi,Params,Transf,Observ,this.SimulationMode)
drawnow drawnow
end end
@ -206,7 +208,7 @@ function J = compute_gradient(psi, Params, Transf, VDk, V)
% Kinetic energy % Kinetic energy
KEop = 0.5 * (Transf.KX.^2+Transf.KY.^2+Transf.KZ.^2); KEop = 0.5 * (Transf.KX.^2+Transf.KY.^2+Transf.KZ.^2);
HKin = @(w) real(ifft(KEop.*fft(w))); HKin = @(w) ifft(KEop.*fft(w));
% Trap Potential % Trap Potential
HV = @(w) V.*w; HV = @(w) V.*w;
@ -216,7 +218,7 @@ function J = compute_gradient(psi, Params, Transf, VDk, V)
% DDIs % DDIs
frho = fftn(abs(psi).^2); frho = fftn(abs(psi).^2);
Phi = real(ifftn(frho.*VDk)); Phi = ifftn(frho.*VDk);
Hddi = @(w) (Params.gdd*Phi).*w; Hddi = @(w) (Params.gdd*Phi).*w;
% Quantum fluctuations % Quantum fluctuations
@ -239,7 +241,7 @@ function g = compute_g(psi, p, Params, VDk)
% DDIs % DDIs
D = Params.gdd*Params.N; D = Params.gdd*Params.N;
rhotilde = fftn(rho); rhotilde = fftn(rho);
Phi = real(ifftn(rhotilde.*VDk)); Phi = ifftn(rhotilde.*VDk);
gaddi = @(w)(D.*Phi).*w; gaddi = @(w)(D.*Phi).*w;
% Quantum fluctuations % Quantum fluctuations
@ -259,13 +261,15 @@ function theta = compute_optimal_theta(p, muchem, psi, Params, Transf, VDk, V)
Hpsi = compute_gradient(psi, Params, Transf, VDk, V); Hpsi = compute_gradient(psi, Params, Transf, VDk, V);
Hp = compute_gradient(p, Params, Transf, VDk, V); Hp = compute_gradient(p, Params, Transf, VDk, V);
g = compute_g(psi, p, Params, VDk); g = compute_g(psi, p, Params, VDk);
numerator = real(conj(p(:))' * Hpsi(:))/norm(p(:)); gamma = 1 / sqrt(sum(abs(p(:)).^2));
denominator = muchem - ((conj(p(:))' * Hp(:)) + real(conj(g(:))' * p(:)))/norm(p(:))^2; numerator = real(p(:)' * Hpsi(:) * Transf.dx * Transf.dy * Transf.dz)/gamma;
denominator = muchem - (gamma^2 * (real(p(:)' * Hp(:) * Transf.dx * Transf.dy * Transf.dz) + real(g(:)' * p(:) * Transf.dx * Transf.dy * Transf.dz)));
theta = numerator / denominator; theta = numerator / denominator;
end end
% Optimal step size via Polak-Ribiere % Optimal step size via Polak-Ribiere
function beta = compute_beta(residual_new, residual_old) function beta = compute_beta(residual_new, residual_old)
beta = max(0, (residual_new(:)' * (residual_new(:) - residual_old(:))) / (norm(residual_old(:))^2)); beta = (residual_new(:)' * (residual_new(:) - residual_old(:))) / sum(abs(residual_old(:)).^2);
beta = max(0,beta);
end end