Minor corrections - theta value now seems to decrease beyond 10 iterations but the resulting state is still incorrect.

This commit is contained in:
Karthik 2025-04-28 16:15:40 +02:00
parent b47d81e68d
commit 58566397ae

View File

@ -22,13 +22,14 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
for idx = 1:this.MaxIterationsForGD for idx = 1:this.MaxIterationsForGD
% Compute gradient % Compute gradient
J = compute_gradient(psi, Params, Transf, VDk, V); J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential % Calculate chemical potential
muchem = sum(real(psi(:)' .* J(:))) / sum(abs(psi(:)).^2); muchem = real(inner_product(J, psi, Transf)) / inner_product(psi, psi, Transf);
% Calculate residual and check convergence % Calculate residual and check convergence
residual = sum(abs(J(:) - (muchem * psi(:))).^2) * Transf.dx * Transf.dy * Transf.dz; diff = J - (muchem * psi);
residual = inner_product(diff, diff, Transf);
if residual < epsilon if residual < epsilon
fprintf('Convergence reached at iteration %d\n', idx); fprintf('Convergence reached at iteration %d\n', idx);
@ -40,22 +41,22 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
psi_old = psi; psi_old = psi;
% Normalize psi % Normalize psi
Norm = sum(abs(psi_new(:)).^2) * Transf.dx * Transf.dy * Transf.dz; Norm = inner_product(psi_new, psi_new, Transf);
psi = sqrt(Params.N) * psi_new / sqrt(Norm); psi = sqrt(Params.N) * psi_new / sqrt(Norm);
% Write output at specified intervals % Write output at specified intervals
if mod(idx,100) == 0 if mod(idx,100) == 0
% Collect change in energy % Collect change in energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V); E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
E = E/Norm;
Observ.EVec = [Observ.EVec E]; Observ.EVec = [Observ.EVec E];
% Collect chemical potentials % Collect chemical potentials
Observ.mucVec = [Observ.mucVec muchem]; Observ.mucVec = [Observ.mucVec muchem];
% Collect residuals % Collect Normalized residuals
Observ.residual = [Observ.residual residual]; res = this.Calculator.calculateNormalizedResiduals(psi,Params,Transf,VDk,V,muchem);
Observ.residual = [Observ.residual res];
Observ.res_idx = Observ.res_idx + 1; Observ.res_idx = Observ.res_idx + 1;
if this.PlotLive if this.PlotLive
@ -77,8 +78,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
end end
% Change in Energy % Change in Energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V); E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
E = E/Norm;
Observ.EVec = [Observ.EVec E]; Observ.EVec = [Observ.EVec E];
disp('Saving data...'); disp('Saving data...');
@ -110,7 +110,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Compute gradient % Compute gradient
J = compute_gradient(psi, Params, Transf, VDk, V); J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential % Calculate chemical potential
muchem = real(inner_product(psi, J, Transf)) / inner_product(psi, psi, Transf); muchem = real(inner_product(J, psi, Transf)) / inner_product(psi, psi, Transf);
% Calculate residual % Calculate residual
residual = J - (muchem * psi); residual = J - (muchem * psi);
@ -152,17 +152,18 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
psi = (cos(theta).*psi) + (sin(theta).*(p*gamma)); psi = (cos(theta).*psi) + (sin(theta).*(p*gamma));
% Normalize psi % Normalize psi
psi = sqrt(Params.N) * psi / sqrt(abs(inner_product(psi, psi, Transf))); Norm = abs(inner_product(psi, psi, Transf));
psi = sqrt(Params.N) * psi / sqrt(Norm);
i = i + 1; i = i + 1;
% Calculate chemical potential with new psi % Calculate chemical potential with new psi
J = compute_gradient(psi, Params, Transf, VDk, V); J = compute_gradient(psi, Params, Transf, VDk, V);
muchem = real(inner_product(psi, J, Transf)) / inner_product(psi, psi, Transf); muchem = real(inner_product(J, psi, Transf)) / Norm;
if mod(i,100) == 0 if mod(i,100) == 0
% Collect Energy value % Collect Energy value
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / inner_product(psi, psi, Transf); E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
Observ.EVec = [Observ.EVec E]; Observ.EVec = [Observ.EVec E];
% Collect Chemical potential value % Collect Chemical potential value
@ -193,7 +194,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
end end
% Change in Energy % Change in Energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / inner_product(psi, psi, Transf); E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
Observ.EVec = [Observ.EVec E]; Observ.EVec = [Observ.EVec E];
save(sprintf(strcat(this.SaveDirectory, '/Run_%03i/psi_gs.mat'),Params.njob),'psi','muchem','Observ','Transf','Params','VDk','V'); save(sprintf(strcat(this.SaveDirectory, '/Run_%03i/psi_gs.mat'),Params.njob),'psi','muchem','Observ','Transf','Params','VDk','V');
@ -275,6 +276,6 @@ function beta = compute_beta(residual_new, residual_old, Transf)
beta = max(0, real(beta)); beta = max(0, real(beta));
end end
function s = inner_product(u, v, Transf) function s = inner_product(u, v, Transf)
s = sum(conj(u(:)) .* v(:)) * Transf.dx * Transf.dy * Transf.dz; s = sum(conj(u(:)) .* v(:)) * Transf.dx * Transf.dy * Transf.dz;
end end