Renamed script of gradient descent.

This commit is contained in:
Karthik 2025-04-01 01:09:34 +02:00
parent 86b74b30da
commit f1b781f8b8
3 changed files with 7 additions and 6 deletions

View File

@ -23,7 +23,7 @@ OptionsStruct.UseApproximationForLHY = true;
OptionsStruct.IncludeDDICutOff = true;
OptionsStruct.CutoffType = 'Cylindrical';
OptionsStruct.SimulationMode = 'EnergyMinimization'; % 'ImaginaryTimeEvolution' | 'RealTimeEvolution' | 'EnergyMinimization'
OptionsStruct.MaxIterationsForCGD = 1E3;
OptionsStruct.MaxIterationsForGD = 1E5;
OptionsStruct.TimeStepSize = 1E-4; % in s
OptionsStruct.MinimumTimeStepSize = 2E-10; % in s
OptionsStruct.TimeCutOff = 2E6; % in s

View File

@ -20,7 +20,7 @@ function [Params, Transf, psi,V,VDk] = run(this)
mkdir(sprintf(this.SaveDirectory))
mkdir(sprintf(strcat(this.SaveDirectory, '/Run_%03i'),Params.njob))
if strcmp(this.SimulationMode, 'EnergyMinimization')
[psi] = this.minimizeEnergyFunctional(psi,Params,Transf,VDk,V,Observ);
[psi] = this.runGradientDescent(psi,Params,Transf,VDk,V,Observ);
else
[psi] = this.propagateWavefunction(psi,Params,Transf,VDk,V,t_idx,Observ);
end

View File

@ -8,6 +8,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
beta = 0.9;
Observ.residual = 1;
Observ.res = 1;
psi_old = psi; % Previous psi value (for heavy-ball method)
if this.PlotLive
@ -22,10 +23,10 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential and norm
muchem = sum(real(conj(psi) .* J)) / sum(abs(psi).^2);
muchem = sum(real(conj(psi(:)) .* J(:))) / sum(abs(psi(:)).^2);
% Calculate residual and check convergence
residual = sum(abs(J - muchem * psi).^2) * Transf.dx * Transf.dy * Transf.dz;
residual = sum(abs(J(:) - muchem * psi(:)).^2) * Transf.dx * Transf.dy * Transf.dz;
if residual < epsilon
fprintf('Convergence reached at iteration %d\n', idx);
break;
@ -39,7 +40,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
Norm = sum(abs(psi_new(:)).^2) * Transf.dx * Transf.dy * Transf.dz;
psi = sqrt(Params.N) * psi_new / sqrt(Norm);
if mod(idx,500) == 0
if mod(idx,100) == 0
% Collect change in energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V);
@ -50,7 +51,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
Observ.mucVec = [Observ.mucVec muchem];
% Collect residuals
Observ.residual = [Observ.residual res];
Observ.residual = [Observ.residual residual];
Observ.res_idx = Observ.res_idx + 1;
if this.PlotLive