Minor corrections - theta value now seems to decrease beyond 10 iterations but the resulting state is still incorrect.

This commit is contained in:
Karthik 2025-04-28 16:15:40 +02:00
parent b47d81e68d
commit 58566397ae

View File

@ -22,13 +22,14 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
for idx = 1:this.MaxIterationsForGD
% Compute gradient
J = compute_gradient(psi, Params, Transf, VDk, V);
J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential
muchem = sum(real(psi(:)' .* J(:))) / sum(abs(psi(:)).^2);
muchem = real(inner_product(J, psi, Transf)) / inner_product(psi, psi, Transf);
% Calculate residual and check convergence
residual = sum(abs(J(:) - (muchem * psi(:))).^2) * Transf.dx * Transf.dy * Transf.dz;
diff = J - (muchem * psi);
residual = inner_product(diff, diff, Transf);
if residual < epsilon
fprintf('Convergence reached at iteration %d\n', idx);
@ -40,22 +41,22 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
psi_old = psi;
% Normalize psi
Norm = sum(abs(psi_new(:)).^2) * Transf.dx * Transf.dy * Transf.dz;
Norm = inner_product(psi_new, psi_new, Transf);
psi = sqrt(Params.N) * psi_new / sqrt(Norm);
% Write output at specified intervals
if mod(idx,100) == 0
% Collect change in energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V);
E = E/Norm;
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
Observ.EVec = [Observ.EVec E];
% Collect chemical potentials
Observ.mucVec = [Observ.mucVec muchem];
% Collect residuals
Observ.residual = [Observ.residual residual];
% Collect Normalized residuals
res = this.Calculator.calculateNormalizedResiduals(psi,Params,Transf,VDk,V,muchem);
Observ.residual = [Observ.residual res];
Observ.res_idx = Observ.res_idx + 1;
if this.PlotLive
@ -77,8 +78,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
end
% Change in Energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V);
E = E/Norm;
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
Observ.EVec = [Observ.EVec E];
disp('Saving data...');
@ -110,7 +110,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
% Compute gradient
J = compute_gradient(psi, Params, Transf, VDk, V);
% Calculate chemical potential
muchem = real(inner_product(psi, J, Transf)) / inner_product(psi, psi, Transf);
muchem = real(inner_product(J, psi, Transf)) / inner_product(psi, psi, Transf);
% Calculate residual
residual = J - (muchem * psi);
@ -152,17 +152,18 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
psi = (cos(theta).*psi) + (sin(theta).*(p*gamma));
% Normalize psi
psi = sqrt(Params.N) * psi / sqrt(abs(inner_product(psi, psi, Transf)));
Norm = abs(inner_product(psi, psi, Transf));
psi = sqrt(Params.N) * psi / sqrt(Norm);
i = i + 1;
% Calculate chemical potential with new psi
J = compute_gradient(psi, Params, Transf, VDk, V);
muchem = real(inner_product(psi, J, Transf)) / inner_product(psi, psi, Transf);
muchem = real(inner_product(J, psi, Transf)) / Norm;
if mod(i,100) == 0
% Collect Energy value
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / inner_product(psi, psi, Transf);
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
Observ.EVec = [Observ.EVec E];
% Collect Chemical potential value
@ -193,7 +194,7 @@ function [psi] = runGradientDescent(this,psi,Params,Transf,VDk,V,Observ)
end
% Change in Energy
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / inner_product(psi, psi, Transf);
E = this.Calculator.calculateTotalEnergy(psi,Params,Transf,VDk,V) / Norm;
Observ.EVec = [Observ.EVec E];
save(sprintf(strcat(this.SaveDirectory, '/Run_%03i/psi_gs.mat'),Params.njob),'psi','muchem','Observ','Transf','Params','VDk','V');
@ -275,6 +276,6 @@ function beta = compute_beta(residual_new, residual_old, Transf)
beta = max(0, real(beta));
end
function s = inner_product(u, v, Transf)
s = sum(conj(u(:)) .* v(:)) * Transf.dx * Transf.dy * Transf.dz;
function s = inner_product(u, v, Transf)
s = sum(conj(u(:)) .* v(:)) * Transf.dx * Transf.dy * Transf.dz;
end