Useful Quantum ESPRESSO Scripts

November 10, 2024

As a computational materials scientist, I frequently use Quantum ESPRESSO (QE) for density functional theory (DFT) calculations. When I first started, I found the learning curve steep, especially with input file formats and output parsing. "What if there was a way to automate some of these tasks?" I thought.

That's when I began writing my own Python scripts to simplify and automate common QE workflows.

Now I've developed a collection of Python scripts that have significantly streamlined my Quantum ESPRESSO workflows. From generating input files for high-throughput calculations to parsing complex output data, these scripts have saved me countless hours. I'm sharing them here in hopes they'll be useful for the community.

Note: Click on any script title to expand/collapse the code. The scripts require Python 3.7+ (the output parser uses dataclasses) and depend only on NumPy and Matplotlib, plus ASE for the lattice-scan input generator.


1. Input File Generation

Batch Input Generator for Different Lattice Parameters

This script generates multiple QE input files for lattice parameter optimization studies.

batch_lattice_scan.py
Click to view script - Generates a series of QE SCF inputs with scaled lattice parameters, plus a SLURM submission script
#!/usr/bin/env python3
"""
Batch Quantum ESPRESSO input generator for lattice parameter scanning
Author: Yi Cao
Usage: python batch_lattice_scan.py
"""

import os
import numpy as np
from ase import Atoms
from ase.io import write

def generate_qe_input(atoms, lattice_scale, prefix, ecutwfc=50, ecutrho=400,
                      kpoints=(4, 4, 4),
                      pseudo_suffix='.pbe-n-kjpaw_psl.1.0.0.UPF'):
    """Generate QE input file for given lattice parameter scale.

    The cell of ``atoms`` is multiplied by ``lattice_scale`` with
    ``scale_atoms=True``, so fractional coordinates are preserved; ``atoms``
    itself is not modified.

    Args:
        atoms: ASE ``Atoms`` reference structure.
        lattice_scale: Factor applied to every lattice vector.
        prefix: Base name; the QE prefix becomes ``{prefix}_scale_{scale:.3f}``.
        ecutwfc: Wavefunction cutoff written to &SYSTEM.
        ecutrho: Charge-density cutoff written to &SYSTEM.
        kpoints: Unshifted automatic k-point mesh ``(nk1, nk2, nk3)``.
        pseudo_suffix: Appended to each element symbol to form the
            pseudopotential filename listed in ATOMIC_SPECIES.

    Returns:
        The complete input file as a string.
    """
    nk1, nk2, nk3 = kpoints

    # Scale the lattice (np.array keeps this independent of ASE's Cell operators).
    scaled_atoms = atoms.copy()
    scaled_atoms.set_cell(np.array(atoms.cell) * lattice_scale, scale_atoms=True)

    symbols = scaled_atoms.get_chemical_symbols()
    # Species in order of first appearance: deterministic across runs, whereas
    # iterating a set() of strings depends on hash randomization.
    species = list(dict.fromkeys(symbols))

    input_text = f"""&CONTROL
    calculation = 'scf'
    prefix = '{prefix}_scale_{lattice_scale:.3f}'
    pseudo_dir = './pseudo/'
    outdir = './tmp/'
    tprnfor = .true.
    tstress = .true.
/

&SYSTEM
    ibrav = 0
    nat = {len(symbols)}
    ntyp = {len(species)}
    ecutwfc = {ecutwfc}
    ecutrho = {ecutrho}
    occupations = 'smearing'
    smearing = 'mv'
    degauss = 0.02
/

&ELECTRONS
    conv_thr = 1.0e-8
    mixing_beta = 0.7
/

ATOMIC_SPECIES
"""
    parts = [input_text]

    # Atomic species: keep the explicit table for these elements, and fall back
    # to the mass ASE stores on the atoms instead of a meaningless 1.0.
    masses = {'Si': 28.0855, 'C': 12.0107, 'O': 15.9994, 'H': 1.00794}
    atom_masses = scaled_atoms.get_masses()
    for spec in species:
        mass = masses.get(spec, float(atom_masses[symbols.index(spec)]))
        parts.append(f"    {spec} {mass} {spec}{pseudo_suffix}\n")

    # Cell parameters
    parts.append("\nCELL_PARAMETERS angstrom\n")
    for vec in np.asarray(scaled_atoms.cell):
        parts.append(f"    {vec[0]:12.8f} {vec[1]:12.8f} {vec[2]:12.8f}\n")

    # Atomic positions (Cartesian, angstrom)
    parts.append("\nATOMIC_POSITIONS angstrom\n")
    for symbol, pos in zip(symbols, scaled_atoms.get_positions()):
        parts.append(f"    {symbol} {pos[0]:12.8f} {pos[1]:12.8f} {pos[2]:12.8f}\n")

    # K-points
    parts.append(f"\nK_POINTS automatic\n    {nk1} {nk2} {nk3} 0 0 0\n")

    return ''.join(parts)

# Example usage
if __name__ == "__main__":
    # Create example structure (Silicon, two-atom cell with FCC lattice vectors)
    a = 5.431  # Angstrom
    silicon = Atoms('Si2',
                    scaled_positions=[(0, 0, 0), (0.25, 0.25, 0.25)],
                    cell=[[0, a/2, a/2], [a/2, 0, a/2], [a/2, a/2, 0]],
                    pbc=True)

    # Generate inputs for different scales
    scales = np.linspace(0.95, 1.05, 11)

    os.makedirs('lattice_scan', exist_ok=True)

    scale_labels = []
    for scale in scales:
        input_content = generate_qe_input(silicon, scale, 'si', ecutwfc=30, ecutrho=240)

        label = f'{scale:.3f}'
        filename = f'lattice_scan/si_scale_{label}.in'
        with open(filename, 'w') as f:
            f.write(input_content)
        scale_labels.append(label)

        print(f"Generated: {filename}")

    # Generate submission script. The loop lists exactly the labels used in the
    # filenames above, so it cannot drift from `scales` and does not depend on
    # how the local `seq` formats decimals.
    scale_list = ' '.join(scale_labels)
    submit_script = f"""#!/bin/bash
#SBATCH --job-name=lattice_scan
#SBATCH --nodes=1
#SBATCH --ntasks=16
#SBATCH --time=12:00:00
#SBATCH --partition=regular

module load quantum-espresso

for scale in {scale_list}; do
    echo "Running scale $scale"
    mpirun -np 16 pw.x < si_scale_${{scale}}.in > si_scale_${{scale}}.out
done
"""

    with open('lattice_scan/submit.sh', 'w') as f:
        f.write(submit_script)

    print("\nDone! Check the 'lattice_scan' directory for input files.")
    # The job script uses paths relative to lattice_scan/, so submit from there.
    print("Submit with: cd lattice_scan && sbatch submit.sh")
▲ Collapse

Convergence Test Generator

Automatically generate inputs for ecutwfc and k-point convergence tests.

convergence_test_generator.py
Click to view script - Generates ecutwfc and k-point convergence inputs from a template, plus an analysis script
#!/usr/bin/env python3
"""
Generate QE inputs for convergence testing (ecutwfc and k-points)
Author: Yi Cao
"""

import os
import itertools

def generate_convergence_inputs(template_file, test_type='ecutwfc'):
    """
    Generate input files for convergence testing

    Reads a pw.x template and writes one modified copy per test point into
    ``{test_type}_convergence/``. It also writes an ``analyze.py`` there that
    plots the total energy against the scanned parameter.

    Args:
        template_file: Path to template QE input file. For ``'ecutwfc'`` it
            must contain an ``ecutwfc = <value>`` assignment; for
            ``'kpoints'`` a ``K_POINTS ... automatic`` card followed by its
            mesh line. ``ecutrho`` and ``prefix`` are rewritten when present.
        test_type: 'ecutwfc' or 'kpoints'

    Raises:
        ValueError: if ``test_type`` is unknown, or the template lacks the
            setting being scanned (which would otherwise silently produce
            identical inputs).
    """
    import re  # local: this script's top-level imports do not include re

    if test_type not in ('ecutwfc', 'kpoints'):
        raise ValueError(f"test_type must be 'ecutwfc' or 'kpoints', got {test_type!r}")

    # Read template
    with open(template_file, 'r') as f:
        template = f.read()

    # Match the value rather than an exact "key = value" string, so any
    # template value / spacing works ("ecutwfc=60", "ecutwfc = 5.0d1", ...).
    prefix_re = re.compile(r"(prefix\s*=\s*)(['\"]).*?\2")
    ecutwfc_re = re.compile(r'(ecutwfc\s*=\s*)[-+\d.eEdD]+')
    ecutrho_re = re.compile(r'(ecutrho\s*=\s*)[-+\d.eEdD]+')
    # Replace only the three mesh integers; any shift flags on the line are kept.
    kmesh_re = re.compile(r'(K_POINTS[^\n]*automatic[^\n]*\n\s*)\d+\s+\d+\s+\d+')

    scanned_re = ecutwfc_re if test_type == 'ecutwfc' else kmesh_re
    if not scanned_re.search(template):
        raise ValueError(f"{template_file} has no {test_type} setting to scan")

    outdir = f'{test_type}_convergence'
    os.makedirs(outdir, exist_ok=True)

    def set_prefix(text, new_prefix):
        # A template without a prefix is left unchanged.
        return prefix_re.sub(rf"\g<1>'{new_prefix}'", text, count=1)

    def write_input(name, text):
        filename = f'{outdir}/{name}.in'
        with open(filename, 'w') as f:
            f.write(text)
        print(f"Generated: {filename}")

    if test_type == 'ecutwfc':
        # Test ecutwfc values
        ecutwfc_values = [20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100]

        for ecutwfc in ecutwfc_values:
            input_text = ecutwfc_re.sub(rf'\g<1>{ecutwfc}', template, count=1)
            # Density cutoff kept at 8 x ecutwfc
            input_text = ecutrho_re.sub(rf'\g<1>{ecutwfc * 8}', input_text, count=1)
            input_text = set_prefix(input_text, f'ecutwfc_{ecutwfc}')
            write_input(f'ecutwfc_{ecutwfc}', input_text)

    else:
        # Test k-point meshes
        k_meshes = [(2, 2, 2), (3, 3, 3), (4, 4, 4), (5, 5, 5), (6, 6, 6),
                    (7, 7, 7), (8, 8, 8), (9, 9, 9), (10, 10, 10)]

        for kmesh in k_meshes:
            kstring = f"{kmesh[0]}x{kmesh[1]}x{kmesh[2]}"
            mesh = f"{kmesh[0]} {kmesh[1]} {kmesh[2]}"
            input_text = kmesh_re.sub(rf'\g<1>{mesh}', template, count=1)
            input_text = set_prefix(input_text, f'kpoints_{kstring}')
            write_input(f'kpoints_{kstring}', input_text)

    # Generate analysis script. Braces that belong to the generated file's own
    # f-string are doubled so they are NOT evaluated here (previously
    # {params[i]} raised NameError on every call).
    x_label = "ecutwfc (Ry)" if test_type == "ecutwfc" else "k-point mesh (NxNxN)"
    analysis_script = f"""#!/usr/bin/env python3
import glob
import matplotlib.pyplot as plt

TEST_TYPE = '{test_type}'

# Parse output files
files = glob.glob('*.out')
data = []

for file in files:
    with open(file, 'r') as f:
        content = f.read()

    # Extract total energy
    for line in content.split('\\n'):
        if '!    total energy' in line:
            energy = float(line.split()[-2])

            if TEST_TYPE == 'ecutwfc':
                param = int(file.split('_')[1].split('.')[0])
            else:  # kpoints
                param = int(file.split('_')[1].split('x')[0])

            data.append((param, energy))
            break

if not data:
    raise SystemExit("No '!    total energy' lines found in *.out files.")

# Sort and plot
data.sort()
params, energies = zip(*data)

plt.figure(figsize=(10, 6))
plt.plot(params, energies, 'o-')
plt.xlabel('{x_label}')
plt.ylabel('Total Energy (Ry)')
plt.title('Convergence Test Results')
plt.grid(True)
plt.savefig('convergence.png', dpi=150)
plt.show()

# Print convergence info
print("\\nConvergence Analysis:")
for i in range(1, len(energies)):
    diff = abs(energies[i] - energies[i-1]) * 13.6057  # Convert to eV
    print(f"{{params[i]}}: ΔE = {{diff:.4f}} eV")
"""

    # The script contains non-ASCII text (ΔE), so pin the encoding.
    with open(f'{outdir}/analyze.py', 'w', encoding='utf-8') as f:
        f.write(analysis_script)

    print(f"\nAnalysis script created: {test_type}_convergence/analyze.py")

# Example usage
if __name__ == "__main__":
    # Minimal two-atom Si template (ibrav = 2). The generator rewrites its
    # ecutwfc / ecutrho / prefix / K_POINTS values for each test point.
    template = """&CONTROL
    calculation = 'scf'
    prefix = 'test'
    pseudo_dir = './pseudo/'
    outdir = './tmp/'
/

&SYSTEM
    ibrav = 2
    celldm(1) = 10.26
    nat = 2
    ntyp = 1
    ecutwfc = 50
    ecutrho = 400
/

&ELECTRONS
    conv_thr = 1.0e-8
/

ATOMIC_SPECIES
    Si 28.0855 Si.pbe-n-kjpaw_psl.1.0.0.UPF

ATOMIC_POSITIONS alat
    Si 0.00 0.00 0.00
    Si 0.25 0.25 0.25

K_POINTS automatic
    4 4 4 0 0 0
"""

    with open('template.in', 'w') as template_handle:
        template_handle.write(template)

    # One convergence series per scanned parameter, cutoff first.
    for scan in ('ecutwfc', 'kpoints'):
        generate_convergence_inputs('template.in', scan)
▲ Collapse

2. Output Parsing and Analysis

Comprehensive Output Parser

Extract all relevant data from QE output files including energies, forces, stress, and convergence information.

qe_output_parser.py
Click to view script - Extracts energies, forces, stress, timing and other results from pw.x output into a text summary and a JSON file
#!/usr/bin/env python3
"""
Comprehensive Quantum ESPRESSO output parser
Author: Yi Cao
Usage: python qe_output_parser.py output.out
"""

import re
import sys
import json
import numpy as np
from dataclasses import dataclass, asdict
from typing import List, Dict, Optional, Tuple

@dataclass
class QEOutput:
    """Everything extracted from a single pw.x output file.

    Only ``filename`` and ``calculation_type`` are required. The other
    fields default to ``None`` (``convergence_achieved`` to ``False``) and
    keep that default when the parser does not find the quantity.
    """

    # --- identification ---------------------------------------------------
    filename: str
    calculation_type: str

    # --- scalar results ---------------------------------------------------
    total_energy: Optional[float] = None    # Ry
    fermi_energy: Optional[float] = None    # eV
    total_force: Optional[float] = None     # Ry/au
    pressure: Optional[float] = None        # kbar
    volume: Optional[float] = None          # (a.u.)^3
    magnetization: Optional[float] = None   # Bohr mag/cell
    cpu_time: Optional[float] = None        # seconds
    wall_time: Optional[float] = None       # seconds
    scf_cycles: Optional[int] = None        # count of "iteration #" lines

    # --- per-atom / per-k-point data --------------------------------------
    forces: Optional[List[List[float]]] = None                        # one [fx, fy, fz] per atom
    stress_tensor: Optional[List[List[float]]] = None                 # 3 rows of 3 values
    atomic_positions: Optional[List[Tuple[str, List[float]]]] = None  # (symbol, [x, y, z])
    cell_parameters: Optional[List[List[float]]] = None               # 3 lattice vectors
    eigenvalues: Optional[Dict] = None                                # k-point -> eigenvalue list

    # True when the text "convergence has been achieved" occurs in the output
    convergence_achieved: bool = False

class QEOutputParser:
    """Parser for Quantum ESPRESSO pw.x output files.

    The file is read once in ``__init__``. The ``_parse_*`` helpers work on
    ``self.content`` (whole text) and ``self.lines`` (the text split into
    lines). When a quantity occurs several times in the file (e.g. once per
    ionic step), the LAST occurrence is returned.
    """

    # Rydberg -> eV conversion factor (the value used throughout these scripts).
    RY_TO_EV = 13.6057

    def __init__(self, filename: str):
        self.filename = filename
        with open(filename, 'r') as f:
            self.content = f.read()
        self.lines = self.content.split('\n')

    def parse(self) -> "QEOutput":
        """Parse the output file and return QEOutput object"""
        output = QEOutput(filename=self.filename, calculation_type=self._get_calculation_type())

        # Parse various sections
        output.total_energy = self._parse_total_energy()
        output.fermi_energy = self._parse_fermi_energy()
        output.forces = self._parse_forces()
        output.total_force = self._calculate_total_force(output.forces)
        output.stress_tensor = self._parse_stress()
        output.pressure = self._parse_pressure()
        output.volume = self._parse_volume()
        output.magnetization = self._parse_magnetization()
        output.atomic_positions = self._parse_final_positions()
        output.cell_parameters = self._parse_final_cell()
        output.cpu_time, output.wall_time = self._parse_timing()
        output.scf_cycles = self._parse_scf_cycles()
        output.convergence_achieved = self._check_convergence()
        output.eigenvalues = self._parse_eigenvalues()

        return output

    def _get_calculation_type(self) -> str:
        """Determine calculation type from the first 'calculation ... =' line."""
        for line in self.lines:
            if 'calculation' in line and '=' in line:
                return line.split('=')[1].strip().strip("'\"")
        return 'unknown'

    def _last_float(self, pattern: str, flags: int = 0) -> Optional[float]:
        """Return the single capture group of the LAST match of ``pattern`` as a float, or None."""
        matches = re.findall(pattern, self.content, flags)
        return float(matches[-1]) if matches else None

    def _parse_total_energy(self) -> Optional[float]:
        """Parse the final '!' total energy in Ry."""
        return self._last_float(r'!\s+total energy\s+=\s+([-\d.]+)\s+Ry')

    def _parse_fermi_energy(self) -> Optional[float]:
        """Parse the last reported Fermi energy in eV."""
        return self._last_float(r'the Fermi energy is\s+([-\d.]+)\s+ev', re.IGNORECASE)

    def _parse_forces(self) -> Optional[List[List[float]]]:
        """Parse atomic forces from the last 'Forces acting on atoms' block.

        A block ends at the first blank or non-force line after its force
        lines. This keeps force-decomposition sections such as "The non-local
        contrib. to forces" out of the result.
        """
        last_block = None
        current = None
        for line in self.lines:
            if 'Forces acting on atoms' in line:
                current = []
                continue
            if current is None:
                continue
            if 'atom' in line and 'type' in line and 'force' in line:
                current.append([float(x) for x in line.split('=')[-1].split()[:3]])
            elif current:
                last_block, current = current, None
        if current:
            last_block = current
        return last_block

    def _calculate_total_force(self, forces: Optional[List[List[float]]]) -> Optional[float]:
        """Return sqrt(sum of all squared force components).

        This is the same definition as pw.x's "Total force" line. The
        previous version summed per-atom norms, which is a different number.
        """
        if not forces:
            return None
        return sum(component ** 2 for force in forces for component in force) ** 0.5

    def _parse_stress(self) -> Optional[List[List[float]]]:
        """Parse the last stress tensor in kbar.

        Each row under the "total stress ... (kbar)" header holds three values
        in Ry/bohr**3 followed by three in kbar; columns 3-5 are the kbar ones.
        """
        stress = None
        for i, line in enumerate(self.lines):
            if not (re.search(r'total\s+stress', line) and 'kbar' in line):
                continue
            rows = []
            for row_line in self.lines[i + 1:i + 4]:
                parts = row_line.split()
                if len(parts) < 6:
                    break
                try:
                    rows.append([float(x) for x in parts[3:6]])
                except ValueError:
                    break
            if len(rows) == 3:
                stress = rows
        return stress

    def _parse_pressure(self) -> Optional[float]:
        """Parse the last pressure (P=) in kbar."""
        return self._last_float(r'total\s+stress.*P=\s*([-\d.]+)')

    def _parse_volume(self) -> Optional[float]:
        """Parse the last unit-cell volume in (a.u.)^3.

        Accepts both "(a.u.)^3" and "a.u.^3". The latter form is used by
        "new unit-cell volume" lines.
        """
        return self._last_float(r'unit-cell volume\s*=\s*([-\d.]+)\s*\(?a\.u\.\)?\^3')

    def _parse_magnetization(self) -> Optional[float]:
        """Parse the last total magnetization (Bohr mag/cell)."""
        return self._last_float(r'total magnetization\s+=\s+([-\d.]+)\s+Bohr mag/cell')

    def _parse_final_positions(self) -> Optional[List[Tuple[str, List[float]]]]:
        """Parse atomic positions following the last ATOMIC_POSITIONS card."""
        positions = []

        last_pos_idx = -1
        for i, line in enumerate(self.lines):
            if 'ATOMIC_POSITIONS' in line:
                last_pos_idx = i

        if last_pos_idx >= 0:
            for raw in self.lines[last_pos_idx + 1:]:
                line = raw.strip()
                if (not line or line.startswith('End')
                        or any(keyword in line for keyword in ['CELL_PARAMETERS', 'K_POINTS'])):
                    break
                parts = line.split()
                if len(parts) >= 4:
                    try:
                        coords = [float(parts[j]) for j in range(1, 4)]
                    except ValueError:
                        break
                    positions.append((parts[0], coords))

        return positions if positions else None

    def _parse_final_cell(self) -> Optional[List[List[float]]]:
        """Parse the three vectors following the last CELL_PARAMETERS card."""
        cell = []

        last_cell_idx = -1
        for i, line in enumerate(self.lines):
            if 'CELL_PARAMETERS' in line:
                last_cell_idx = i

        if last_cell_idx >= 0:
            for raw in self.lines[last_cell_idx + 1:last_cell_idx + 4]:
                line = raw.strip()
                if line:
                    try:
                        cell.append([float(x) for x in line.split()[:3]])
                    except ValueError:
                        break

        return cell if len(cell) == 3 else None

    def _parse_timing(self) -> Tuple[Optional[float], Optional[float]]:
        """Parse total CPU and wall time (seconds) from the 'PWSCF :' line.

        Handles "1.23s", "1m30.50s" and "1h 2m" forms. The old pattern
        required minutes and so missed short runs.
        """
        match = re.search(r'PWSCF\s*:\s*([^\n]+?)\s*CPU\s+([^\n]+?)\s*WALL', self.content)
        if not match:
            return None, None
        return (self._time_str_to_seconds(match.group(1)),
                self._time_str_to_seconds(match.group(2)))

    def _time_str_to_seconds(self, time_str: str) -> float:
        """Convert time strings like '1m30.5s' or '1h 2m' to seconds."""
        total = 0.0
        h_match = re.search(r'(\d+)h', time_str)
        if h_match:
            total += int(h_match.group(1)) * 3600
        m_match = re.search(r'(\d+)m', time_str)
        if m_match:
            total += int(m_match.group(1)) * 60
        s_match = re.search(r'([\d.]+)s', time_str)
        if s_match:
            total += float(s_match.group(1))
        return total

    def _parse_scf_cycles(self) -> Optional[int]:
        """Count 'iteration #' lines (all SCF iterations in the file)."""
        cycles = len(re.findall(r'iteration #\s*\d+', self.content))
        return cycles if cycles > 0 else None

    def _check_convergence(self) -> bool:
        """Check if calculation converged"""
        return 'convergence has been achieved' in self.content

    def _parse_eigenvalues(self) -> Optional[Dict]:
        """Parse eigenvalues (eV) for every 'k = ... bands (ev):' block.

        Returns a dict mapping the k-point tuple from the header to the list
        of values that follow it, or None if no such block exists. A k-point
        that is printed again later is overwritten, so the last block wins.
        """
        # Components may touch when negative, e.g. "k =-0.1250 0.1250-0.1250".
        header_re = re.compile(r'k\s*=\s*(-?\d+\.\d+)\s*(-?\d+\.\d+)\s*(-?\d+\.\d+)')
        number_re = re.compile(r'-?\d+\.\d+')
        numeric_line_re = re.compile(r'[\d.\s\-]+')

        eigenvalues = {}
        current_k = None
        for line in self.lines:
            if 'k =' in line and 'bands (ev):' in line:
                k_match = header_re.search(line)
                if k_match:
                    current_k = tuple(float(k_match.group(j)) for j in range(1, 4))
                    eigenvalues[current_k] = []
                continue
            stripped = line.strip()
            if current_k is None or not stripped:
                continue
            if numeric_line_re.fullmatch(stripped):
                # Values can also touch ("-10.1234-9.8765"), so match instead of split.
                eigenvalues[current_k].extend(float(x) for x in number_re.findall(stripped))
            else:
                current_k = None

        return eigenvalues if eigenvalues else None

    def to_dict(self) -> dict:
        """Convert parsed output to dictionary"""
        return asdict(self.parse())

    @staticmethod
    def _json_safe(obj):
        """Recursively convert tuple dict keys (eigenvalue k-points) to space-joined strings.

        ``json.dump`` rejects tuple keys and its ``default=`` hook does not
        apply to keys.
        """
        if isinstance(obj, dict):
            return {(' '.join(str(c) for c in key) if isinstance(key, tuple) else key):
                    QEOutputParser._json_safe(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [QEOutputParser._json_safe(item) for item in obj]
        return obj

    def to_json(self, filename: str):
        """Save parsed output to JSON file"""
        with open(filename, 'w') as f:
            json.dump(self._json_safe(self.to_dict()), f, indent=2, default=str)

    @staticmethod
    def _fmt(value, spec: str, suffix: str = '') -> str:
        """Format ``value`` with ``spec`` followed by ``suffix``, or 'N/A' if it is None.

        Uses ``is None`` so that legitimate zeros (e.g. 0.00 kbar) are shown.
        """
        return 'N/A' if value is None else f"{value:{spec}}{suffix}"

    def summary(self) -> str:
        """Generate a summary of the calculation.

        Missing quantities print as 'N/A'. The previous version left the
        conditionals outside the f-string braces, printed them literally, and
        raised TypeError whenever a value was None.
        """
        output = self.parse()
        fmt = self._fmt

        if output.total_energy is None:
            energy = 'N/A'
        else:
            energy = (f"{output.total_energy:.6f} Ry "
                      f"({output.total_energy * self.RY_TO_EV:.6f} eV)")

        return f"""
Quantum ESPRESSO Output Summary
===============================
File: {output.filename}
Calculation Type: {output.calculation_type}
Converged: {output.convergence_achieved}

Energy & Forces:
  Total Energy: {energy}
  Fermi Energy: {fmt(output.fermi_energy, '.6f', ' eV')}
  Total Force: {fmt(output.total_force, '.6f', ' Ry/au')}

Structure:
  Volume: {fmt(output.volume, '.2f', ' (a.u.)^3')}
  Pressure: {fmt(output.pressure, '.2f', ' kbar')}

Performance:
  SCF Cycles: {output.scf_cycles if output.scf_cycles else 'N/A'}
  CPU Time: {fmt(output.cpu_time, '.1f', ' s')}
  Wall Time: {fmt(output.wall_time, '.1f', ' s')}
"""

# Command-line interface
if __name__ == "__main__":
    import os  # only the CLI needs it; not among this script's top-level imports

    if len(sys.argv) < 2:
        print("Usage: python qe_output_parser.py <output_file>")
        sys.exit(1)

    output_path = sys.argv[1]
    parser = QEOutputParser(output_path)
    print(parser.summary())

    # Save to JSON next to the input. splitext never returns the input's own
    # name, whereas str.replace('.out', ...) is a no-op for e.g. "scf.log" and
    # would overwrite the QE output with JSON.
    json_file = os.path.splitext(output_path)[0] + '_parsed.json'
    parser.to_json(json_file)
    print(f"\nDetailed results saved to: {json_file}")
▲ Collapse

Band Structure Extractor

Extract and plot band structure data from the eigenvalue blocks ("k = ... bands (ev):") that pw.x prints, e.g. in the output of a calculation = 'bands' run.

band_structure_extractor.py
Click to view script - Creates publication-quality band structure plots from QE output
#!/usr/bin/env python3
"""
Band structure extractor and plotter for Quantum ESPRESSO
Author: Yi Cao
Usage: python band_structure_extractor.py bands.out
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import re
import sys

# Set publication-quality defaults
rcParams['font.family'] = 'sans-serif'
rcParams['font.sans-serif'] = ['Arial']
rcParams['font.size'] = 12
rcParams['axes.linewidth'] = 1.5
rcParams['xtick.major.width'] = 1.5
rcParams['ytick.major.width'] = 1.5

class BandStructure:
    """Parse and plot band energies from Quantum ESPRESSO output.

    ``parse_bands_output`` reads:

    - the ``k = ... bands (ev):`` eigenvalue blocks;
    - the number of k-points and bands;
    - the Fermi energy;
    - the reciprocal axes;
    - any bands.x-style ``high-symmetry point: ... x coordinate ...`` lines.

    Spin-polarized output is not separated by spin channel.
    """

    # Eigenvalue block header, e.g.
    #   "k = 0.0000 0.0000 0.0000 (   749 PWs)   bands (ev):"
    # Components may touch when negative ("k =-0.1250 0.1250-0.1250"). Anchoring
    # on a leading "k" keeps "... wk = ..." k-point-list lines from matching.
    _K_HEADER = re.compile(
        r'^\s*k\s*=\s*(-?\d+\.\d+)\s*(-?\d+\.\d+)\s*(-?\d+\.\d+).*bands \(ev\)')
    _DECIMAL = re.compile(r'-?\d+\.\d+')
    _NUMBER = re.compile(r'-?\d+(?:\.\d+)?')
    _NUMERIC_LINE = re.compile(r'[\d.\s\-]+')
    # "number of k points=  36  Marzari-Vanderbilt smearing, ..." -> 36
    _NUM_KPOINTS = re.compile(r'number of k points\s*=\s*(\d+)')
    _NUM_BANDS = re.compile(r'number of (?:Kohn-Sham states|bands)\s*=\s*(\d+)')
    # Matches both "the Fermi energy is X ev" and "Fermi energy = X".
    _FERMI = re.compile(r'Fermi energy\s*(?:is|=)\s*(-?\d+(?:\.\d+)?)')
    _HIGH_SYMMETRY = re.compile(
        r'high-symmetry point:\s*(?:-?\d+\.\d+\s*){3}x coordinate\s+(-?\d+\.\d+)')

    def __init__(self, filename):
        self.filename = filename
        self.kpoints = []
        self.bands = []
        self.fermi_energy = None
        self.high_symmetry_points = []
        self.reciprocal_vectors = None
        self.num_kpoints = None
        self.num_bands = None

    def parse_bands_output(self):
        """Parse the output file and (re)populate the band-structure attributes.

        After the call:

        - ``kpoints`` is an (nk, 3) array of the header coordinates;
        - ``bands`` is an (nk, nbnd) array of energies in eV;
        - ``high_symmetry_points`` is an array of ``[x, label]`` string pairs.

        Calling it again replaces, rather than appends to, previous results.

        Raises:
            ValueError: if the k-point blocks contain different numbers of bands.
        """
        with open(self.filename, 'r') as f:
            lines = f.read().splitlines()

        kpoints, blocks, hsp = [], [], []
        self.fermi_energy = None
        self.reciprocal_vectors = None
        self.num_kpoints = None
        self.num_bands = None
        current = None  # eigenvalues of the k-point block being read, if any

        for i, line in enumerate(lines):
            header = self._K_HEADER.search(line)
            if header:
                if current is not None:
                    blocks.append(current)
                kpoints.append([float(header.group(j)) for j in (1, 2, 3)])
                current = []
                continue

            stripped = line.strip()
            if current is not None:
                if not stripped:
                    continue
                if self._NUMERIC_LINE.fullmatch(stripped):
                    # QE prints several energies per line and they can touch
                    # ("-10.1234-9.8765"), so match numbers rather than split().
                    current.extend(float(x) for x in self._DECIMAL.findall(stripped))
                    continue
                # First non-numeric line closes the block; still inspect it below.
                blocks.append(current)
                current = None

            match = self._NUM_KPOINTS.search(line)
            if match:
                self.num_kpoints = int(match.group(1))
                continue
            match = self._NUM_BANDS.search(line)
            if match:
                self.num_bands = int(match.group(1))
                continue
            match = self._FERMI.search(line)
            if match:
                self.fermi_energy = float(match.group(1))
                continue
            match = self._HIGH_SYMMETRY.search(line)
            if match:
                hsp.append([match.group(1), ''])
                continue
            if 'reciprocal axes' in line or 'reciprocal lattice vectors' in line:
                # Take the last three numbers of each of the next three lines,
                # which skips the index in labels such as "b(1) = (".
                vectors = [[float(x) for x in self._NUMBER.findall(row)[-3:]]
                           for row in lines[i + 1:i + 4]]
                if len(vectors) == 3 and all(len(v) == 3 for v in vectors):
                    self.reciprocal_vectors = vectors

        if current is not None:
            blocks.append(current)
        if self.num_bands:
            blocks = [block[:self.num_bands] for block in blocks]
        if len({len(block) for block in blocks}) > 1:
            raise ValueError(
                f"Inconsistent number of bands per k-point in {self.filename}")

        self.kpoints = np.array(kpoints, dtype=float).reshape(-1, 3)
        self.bands = np.array(blocks, dtype=float)
        self.high_symmetry_points = np.array(hsp)

    def _has_data(self):
        """True when both k-points and band energies are present.

        Uses len() because the truth value of a multi-element array is ambiguous.
        """
        return len(self.kpoints) > 0 and len(self.bands) > 0

    def _k_path_distance(self):
        """Cumulative distance along the k-point list, in the units of the k coordinates."""
        kpoints = np.asarray(self.kpoints, dtype=float)
        if len(kpoints) == 0:
            return np.zeros(0)
        steps = np.linalg.norm(np.diff(kpoints, axis=0), axis=1)
        return np.concatenate(([0.0], np.cumsum(steps)))

    def _draw(self, energy_window):
        """Build the band-structure figure and return ``(fig, ax)``."""
        x = self._k_path_distance()
        bands = np.asarray(self.bands, dtype=float)
        # Shift to the Fermi level when it is known; otherwise plot raw energies.
        reference = self.fermi_energy if self.fermi_energy is not None else 0.0

        fig, ax = plt.subplots(figsize=(8, 6))
        for i in range(bands.shape[1]):
            ax.plot(x, bands[:, i] - reference, color='blue', lw=1)

        ax.set_xlabel('Wave Vector (k)')
        ax.set_ylabel('Energy (eV)')
        ax.set_title('Band Structure')
        ax.grid(True)
        # The window is relative to E_F, so only apply it when E_F is known.
        if self.fermi_energy is not None:
            ax.set_ylim(*energy_window)

        label_y = ax.get_ylim()[0]
        for point in self.high_symmetry_points:
            ax.axvline(x=float(point[0]), color='red', linestyle='--', lw=0.5)
            ax.text(float(point[0]), label_y, point[1], color='red', ha='center', va='bottom')

        fig.tight_layout()
        return fig, ax

    def plot_band_structure(self, energy_window=(-5, 5)):
        """Plot the band structure and show it interactively.

        Args:
            energy_window: y-limits in eV relative to the Fermi energy.
        """
        if not self._has_data():
            raise ValueError("No band structure data to plot. Please parse the output first.")
        self._draw(energy_window)
        plt.show()

    def save_plot(self, filename='band_structure.png', energy_window=(-5, 5)):
        """Save the band structure plot to a file (300 dpi) and close the figure."""
        if not self._has_data():
            raise ValueError("No band structure data to save. Please parse the output first.")
        fig, _ = self._draw(energy_window)
        fig.savefig(filename, dpi=300)
        plt.close(fig)
        print(f"Band structure plot saved to {filename}")
    
▲ Collapse

© 2025 Yi Cao. Licensed under CC BY-NC-SA 4.0.