diff --git a/dpdata/abacus/md.py b/dpdata/abacus/md.py index fa184177..9e0a416a 100644 --- a/dpdata/abacus/md.py +++ b/dpdata/abacus/md.py @@ -5,6 +5,8 @@ import numpy as np +from dpdata.utils import open_file + from .scf import ( bohr2ang, get_cell, @@ -156,12 +158,12 @@ def get_frame(fname): path_in = os.path.join(fname, "INPUT") else: raise RuntimeError("invalid input") - with open(path_in) as fp: + with open_file(path_in) as fp: inlines = fp.read().split("\n") geometry_path_in = get_geometry_in(fname, inlines) # base dir of STRU path_out = get_path_out(fname, inlines) - with open(geometry_path_in) as fp: + with open_file(geometry_path_in) as fp: geometry_inlines = fp.read().split("\n") celldm, cell = get_cell(geometry_inlines) atom_names, natoms, types, coords = get_coords( @@ -172,11 +174,11 @@ def get_frame(fname): # ndump = int(os.popen("ls -l %s | grep 'md_pos_' | wc -l" %path_out).readlines()[0]) # number of dumped geometry files # coords = get_coords_from_cif(ndump, dump_freq, atom_names, natoms, types, path_out, cell) - with open(os.path.join(path_out, "MD_dump")) as fp: + with open_file(os.path.join(path_out, "MD_dump")) as fp: dumplines = fp.read().split("\n") coords, cells, force, stress = get_coords_from_dump(dumplines, natoms) ndump = np.shape(coords)[0] - with open(os.path.join(path_out, "running_md.log")) as fp: + with open_file(os.path.join(path_out, "running_md.log")) as fp: outlines = fp.read().split("\n") energy = get_energy(outlines, ndump, dump_freq) diff --git a/dpdata/abacus/relax.py b/dpdata/abacus/relax.py index 976243b8..7a4abf35 100644 --- a/dpdata/abacus/relax.py +++ b/dpdata/abacus/relax.py @@ -4,6 +4,8 @@ import numpy as np +from dpdata.utils import open_file + from .scf import ( bohr2ang, collect_force, @@ -174,10 +176,10 @@ def get_frame(fname): path_in = os.path.join(fname, "INPUT") else: raise RuntimeError("invalid input") - with open(path_in) as fp: + with open_file(path_in) as fp: inlines = fp.read().split("\n") geometry_path_in = get_geometry_in(fname, inlines) # base dir of STRU - with open(geometry_path_in) as fp: + with open_file(geometry_path_in) as fp: geometry_inlines = fp.read().split("\n") celldm, cell = get_cell(geometry_inlines) atom_names, natoms, types, coord_tmp = get_coords( @@ -186,7 +188,7 @@ def get_frame(fname): logf = get_log_file(fname, inlines) assert os.path.isfile(logf), f"Error: can not find {logf}" - with open(logf) as f1: + with open_file(logf) as f1: lines = f1.readlines() atomnumber = 0 diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index c06f8cd3..93e3d6e1 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -6,6 +6,8 @@ import numpy as np +from dpdata.utils import open_file + from ..unit import EnergyConversion, LengthConversion, PressureConversion bohr2ang = LengthConversion("bohr", "angstrom").value() @@ -253,7 +255,7 @@ def get_frame(fname): if not CheckFile(path_in): return data - with open(path_in) as fp: + with open_file(path_in) as fp: inlines = fp.read().split("\n") geometry_path_in = get_geometry_in(fname, inlines) @@ -261,9 +263,9 @@ def get_frame(fname): if not (CheckFile(geometry_path_in) and CheckFile(path_out)): return data - with open(geometry_path_in) as fp: + with open_file(geometry_path_in) as fp: geometry_inlines = fp.read().split("\n") - with open(path_out) as fp: + with open_file(path_out) as fp: outlines = fp.read().split("\n") celldm, cell = get_cell(geometry_inlines) @@ -338,7 +340,7 @@ def get_nele_from_stru(geometry_inlines): def get_frame_from_stru(fname): assert isinstance(fname, str) - with open(fname) as fp: + with open_file(fname) as fp: geometry_inlines = fp.read().split("\n") nele = get_nele_from_stru(geometry_inlines) inlines = ["ntype %d" % nele] diff --git a/dpdata/amber/md.py b/dpdata/amber/md.py index f3217fbd..cb4f2d25 100644 --- a/dpdata/amber/md.py +++ b/dpdata/amber/md.py @@ -7,6 +7,7 @@ from dpdata.amber.mask import pick_by_amber_mask from dpdata.unit import EnergyConversion +from dpdata.utils import open_file from ..periodic_table import ELEMENTS @@ -51,7 +52,7 @@ def read_amber_traj( flag_atom_numb = False amber_types = [] atomic_number = [] - with open(parm7_file) as f: + with open_file(parm7_file) as f: for line in f: if line.startswith("%FLAG"): flag_atom_type = line.startswith("%FLAG AMBER_ATOM_TYPE") @@ -101,14 +102,14 @@ def read_amber_traj( # load energy from mden_file or mdout_file energies = [] if mden_file is not None and os.path.isfile(mden_file): - with open(mden_file) as f: + with open_file(mden_file) as f: for line in f: if line.startswith("L6"): s = line.split() if s[2] != "E_pot": energies.append(float(s[2])) elif mdout_file is not None and os.path.isfile(mdout_file): - with open(mdout_file) as f: + with open_file(mdout_file) as f: for line in f: if "EPtot" in line: s = line.split() diff --git a/dpdata/amber/sqm.py b/dpdata/amber/sqm.py index 1be3802a..93e41f9a 100644 --- a/dpdata/amber/sqm.py +++ b/dpdata/amber/sqm.py @@ -1,9 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from dpdata.periodic_table import ELEMENTS from dpdata.unit import EnergyConversion +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType kcal2ev = EnergyConversion("kcal_mol", "eV").value() @@ -14,7 +20,7 @@ READ_FORCES = 7 -def parse_sqm_out(fname): +def parse_sqm_out(fname: FileType): """Read atom symbols, charges and coordinates from ambertools sqm.out file.""" atom_symbols = [] coords = [] @@ -22,7 +28,7 @@ def parse_sqm_out(fname): forces = [] energies = [] - with open(fname) as f: + with open_file(fname) as f: flag = START for line in f: if line.startswith(" Total SCF energy"): @@ -81,7 +87,7 @@ def parse_sqm_out(fname): return data -def make_sqm_in(data, fname=None, frame_idx=0, **kwargs): +def make_sqm_in(data, fname: FileType | None = None, frame_idx=0, **kwargs): symbols = [data["atom_names"][ii] for ii in data["atom_types"]] atomic_numbers = [ELEMENTS.index(ss) + 1 for ss in symbols] charge = kwargs.get("charge", 0) @@ -109,6 +115,6 @@ def make_sqm_in(data, fname=None, frame_idx=0, **kwargs): f"{data['coords'][frame_idx][ii, 2]:.6f}", ) if fname is not None: - with open(fname, "w") as fp: + with open_file(fname, "w") as fp: fp.write(ret) return ret diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index ad638408..5ba3914e 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -8,6 +8,7 @@ import numpy as np import dpdata +from dpdata.utils import open_file from .raw import load_type @@ -172,7 +173,7 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True): except OSError: pass if data.get("nopbc", False): - with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc: + with open_file(os.path.join(folder, "nopbc"), "w") as fw_nopbc: pass # allow custom dtypes labels = "energies" in data diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py index 717c73f9..f8ddcaf3 100644 --- a/dpdata/deepmd/raw.py +++ b/dpdata/deepmd/raw.py @@ -6,6 +6,7 @@ import numpy as np import dpdata +from dpdata.utils import open_file def load_type(folder, type_map=None): @@ -17,7 +18,7 @@ def load_type(folder, type_map=None): data["atom_names"] = [] # if find type_map.raw, use it if os.path.isfile(os.path.join(folder, "type_map.raw")): - with open(os.path.join(folder, "type_map.raw")) as fp: + with open_file(os.path.join(folder, "type_map.raw")) as fp: my_type_map = fp.read().split() # else try to use arg type_map elif type_map is not None: @@ -140,7 +141,7 @@ def dump(folder, data): except OSError: pass if data.get("nopbc", False): - with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc: + with open_file(os.path.join(folder, "nopbc"), "w") as fw_nopbc: pass # allow custom dtypes labels = "energies" in data diff --git a/dpdata/dftbplus/output.py b/dpdata/dftbplus/output.py index 0f10c3ac..49fdd2b1 100644 --- a/dpdata/dftbplus/output.py +++ b/dpdata/dftbplus/output.py @@ -1,9 +1,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType + -def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.ndarray]: +def read_dftb_plus( + fn_1: FileType, fn_2: FileType +) -> tuple[str, np.ndarray, float, np.ndarray]: """Read from DFTB+ input and output. Parameters @@ -29,7 +38,7 @@ def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.nda symbols = None forces = None energy = None - with open(fn_1) as f: + with open_file(fn_1) as f: flag = 0 for line in f: if flag == 1: @@ -49,7 +58,7 @@ def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.nda flag += 1 if flag == 7: flag = 0 - with open(fn_2) as f: + with open_file(fn_2) as f: flag = 0 for line in f: if line.startswith("Total Forces"): diff --git a/dpdata/gaussian/log.py b/dpdata/gaussian/log.py index 204cf464..08a65b9d 100644 --- a/dpdata/gaussian/log.py +++ b/dpdata/gaussian/log.py @@ -1,7 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType + from ..periodic_table import ELEMENTS from ..unit import EnergyConversion, ForceConversion, LengthConversion @@ -12,7 +19,7 @@ symbols = ["X"] + ELEMENTS -def to_system_data(file_name, md=False): +def to_system_data(file_name: FileType, md=False): """Read Gaussian log file. Parameters @@ -43,7 +50,7 @@ def to_system_data(file_name, md=False): nopbc = True coords = None - with open(file_name) as fp: + with open_file(file_name) as fp: for line in fp: if line.startswith(" SCF Done"): # energies diff --git a/dpdata/gromacs/gro.py b/dpdata/gromacs/gro.py index aca2443b..fe83e0c5 100644 --- a/dpdata/gromacs/gro.py +++ b/dpdata/gromacs/gro.py @@ -2,9 +2,15 @@ from __future__ import annotations import re +from typing import TYPE_CHECKING import numpy as np +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType + from ..unit import LengthConversion nm2ang = LengthConversion("nm", "angstrom").value() @@ -48,9 +54,9 @@ def _get_cell(line): return cell -def file_to_system_data(fname, format_atom_name=True, **kwargs): +def file_to_system_data(fname: FileType, format_atom_name=True, **kwargs): system = {"coords": [], "cells": []} - with open(fname) as fp: + with open_file(fname) as fp: frame = 0 while True: flag = fp.readline() diff --git a/dpdata/lammps/dump.py b/dpdata/lammps/dump.py index f0ade2b0..72fee27d 100644 --- a/dpdata/lammps/dump.py +++ b/dpdata/lammps/dump.py @@ -3,9 +3,15 @@ import os import sys +from typing import TYPE_CHECKING import numpy as np +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType + lib_path = os.path.dirname(os.path.realpath(__file__)) sys.path.append(lib_path) import warnings @@ -169,11 +175,11 @@ def box2dumpbox(orig, box): return bounds, tilt -def load_file(fname, begin=0, step=1): +def load_file(fname: FileType, begin=0, step=1): lines = [] buff = [] cc = -1 - with open(fname) as fp: + with open_file(fname) as fp: while True: line = fp.readline().rstrip("\n") if not line: diff --git a/dpdata/openmx/omx.py b/dpdata/openmx/omx.py index d3afff00..aae9b578 100644 --- a/dpdata/openmx/omx.py +++ b/dpdata/openmx/omx.py @@ -1,8 +1,15 @@ #!/usr/bin/python3 from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType + from ..unit import ( EnergyConversion, ForceConversion, @@ -98,12 +105,12 @@ def load_cells(lines): # load atom_names, atom_numbs, atom_types, cells -def load_param_file(fname, mdname): - with open(fname) as dat_file: +def load_param_file(fname: FileType, mdname: FileType): + with open_file(fname) as dat_file: lines = dat_file.readlines() atom_names, atom_types, atom_numbs = load_atom(lines) - with open(mdname) as md_file: + with open_file(mdname) as md_file: lines = md_file.readlines() cells = load_cells(lines) return atom_names, atom_numbs, atom_types, cells @@ -133,15 +140,15 @@ def load_coords(lines, atom_names, natoms): return coords -def load_data(mdname, atom_names, natoms): - with open(mdname) as md_file: +def load_data(mdname: FileType, atom_names, natoms): + with open_file(mdname) as md_file: lines = md_file.readlines() coords = load_coords(lines, atom_names, natoms) steps = [str(i) for i in range(1, coords.shape[0] + 1)] return coords, steps -def to_system_data(fname, mdname): +def to_system_data(fname: FileType, mdname: FileType): data = {} ( data["atom_names"], @@ -194,7 +201,7 @@ def load_force(lines, atom_names, atom_numbs): # load energy, force def to_system_label(fname, mdname): atom_names, atom_numbs, atom_types, cells = load_param_file(fname, mdname) - with open(mdname) as md_file: + with open_file(mdname) as md_file: lines = md_file.readlines() energy = load_energy(lines) force = load_force(lines, atom_names, atom_numbs) diff --git a/dpdata/orca/output.py b/dpdata/orca/output.py index a23013fd..a0915162 100644 --- a/dpdata/orca/output.py +++ b/dpdata/orca/output.py @@ -1,9 +1,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType + -def read_orca_sp_output(fn: str) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]: +def read_orca_sp_output( + fn: FileType, +) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]: """Read from ORCA output. Note that both the energy and the gradient should be printed. @@ -28,7 +37,7 @@ def read_orca_sp_output(fn: str) -> tuple[np.ndarray, np.ndarray, float, np.ndar symbols = None forces = None energy = None - with open(fn) as f: + with open_file(fn) as f: flag = 0 for line in f: if flag in (1, 3, 4): diff --git a/dpdata/plugins/abacus.py b/dpdata/plugins/abacus.py index eb2d7786..e3367b35 100644 --- a/dpdata/plugins/abacus.py +++ b/dpdata/plugins/abacus.py @@ -1,9 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import dpdata.abacus.md import dpdata.abacus.relax import dpdata.abacus.scf from dpdata.format import Format +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType @Format.register("abacus/stru") @@ -12,7 +18,7 @@ class AbacusSTRUFormat(Format): def from_system(self, file_name, **kwargs): return dpdata.abacus.scf.get_frame_from_stru(file_name) - def to_system(self, data, file_name, frame_idx=0, **kwargs): + def to_system(self, data, file_name: FileType, frame_idx=0, **kwargs): """Dump the system into ABACUS STRU format file. Parameters @@ -46,7 +52,7 @@ def to_system(self, data, file_name, frame_idx=0, **kwargs): numerical_descriptor=numerical_descriptor, mass=mass, ) - with open(file_name, "w") as fp: + with open_file(file_name, "w") as fp: fp.write(stru_string) diff --git a/dpdata/plugins/amber.py b/dpdata/plugins/amber.py index 42fce552..361e0d8a 100644 --- a/dpdata/plugins/amber.py +++ b/dpdata/plugins/amber.py @@ -8,6 +8,7 @@ import dpdata.amber.sqm from dpdata.driver import Driver, Minimizer from dpdata.format import Format +from dpdata.utils import open_file @Format.register("amber/md") @@ -143,7 +144,7 @@ def label(self, data: dict) -> dict: [*self.sqm_exec.split(), "-O", "-i", inp_fn, "-o", out_fn] ) except sp.CalledProcessError as e: - with open(out_fn) as f: + with open_file(out_fn) as f: raise RuntimeError( "Run sqm failed! Output:\n" + f.read() ) from e diff --git a/dpdata/plugins/gaussian.py b/dpdata/plugins/gaussian.py index b55447b9..55bee5a4 100644 --- a/dpdata/plugins/gaussian.py +++ b/dpdata/plugins/gaussian.py @@ -3,16 +3,21 @@ import os import subprocess as sp import tempfile +from typing import TYPE_CHECKING import dpdata.gaussian.gjf import dpdata.gaussian.log from dpdata.driver import Driver from dpdata.format import Format +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType @Format.register("gaussian/log") class GaussianLogFormat(Format): - def from_labeled_system(self, file_name, md=False, **kwargs): + def from_labeled_system(self, file_name: FileType, md=False, **kwargs): try: return dpdata.gaussian.log.to_system_data(file_name, md=md) except AssertionError: @@ -21,7 +26,7 @@ def from_labeled_system(self, file_name, md=False, **kwargs): @Format.register("gaussian/md") class GaussianMDFormat(Format): - def from_labeled_system(self, file_name, **kwargs): + def from_labeled_system(self, file_name: FileType, **kwargs): return GaussianLogFormat().from_labeled_system(file_name, md=True) @@ -29,7 +34,7 @@ def from_labeled_system(self, file_name, **kwargs): class GaussiaGJFFormat(Format): """Gaussian input file.""" - def from_system(self, file_name: str, **kwargs): + def from_system(self, file_name: FileType, **kwargs): """Read Gaussian input file. Parameters @@ -39,11 +44,11 @@ def from_system(self, file_name: str, **kwargs): **kwargs : dict keyword arguments """ - with open(file_name) as fp: + with open_file(file_name) as fp: text = fp.read() return dpdata.gaussian.gjf.read_gaussian_input(text) - def to_system(self, data: dict, file_name: str, **kwargs): + def to_system(self, data: dict, file_name: FileType, **kwargs): """Generate Gaussian input file. Parameters @@ -56,7 +61,7 @@ def to_system(self, data: dict, file_name: str, **kwargs): Other parameters to make input files. See :meth:`dpdata.gaussian.gjf.make_gaussian_input` """ text = dpdata.gaussian.gjf.make_gaussian_input(data, **kwargs) - with open(file_name, "w") as fp: + with open_file(file_name, "w") as fp: fp.write(text) @@ -110,7 +115,7 @@ def label(self, data: dict) -> dict: try: sp.check_output([*self.gaussian_exec.split(), inp_fn]) except sp.CalledProcessError as e: - with open(out_fn) as f: + with open_file(out_fn) as f: out = f.read() raise RuntimeError("Run gaussian failed! Output:\n" + out) from e labeled_system.append(dpdata.LabeledSystem(out_fn, fmt="gaussian/log")) diff --git a/dpdata/plugins/gromacs.py b/dpdata/plugins/gromacs.py index 12dece71..a7066bbc 100644 --- a/dpdata/plugins/gromacs.py +++ b/dpdata/plugins/gromacs.py @@ -1,7 +1,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import dpdata.gromacs.gro from dpdata.format import Format +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType @Format.register("gro") @@ -23,7 +29,9 @@ def from_system(self, file_name, format_atom_name=True, **kwargs): file_name, format_atom_name=format_atom_name, **kwargs ) - def to_system(self, data, file_name=None, frame_idx=-1, **kwargs): + def to_system( + self, data, file_name: FileType | None = None, frame_idx=-1, **kwargs + ): """Dump the system in gromacs .gro format. Parameters @@ -52,5 +60,5 @@ def to_system(self, data, file_name=None, frame_idx=-1, **kwargs): if file_name is None: return gro_str else: - with open(file_name, "w+") as fp: + with open_file(file_name, "w+") as fp: fp.write(gro_str) diff --git a/dpdata/plugins/lammps.py b/dpdata/plugins/lammps.py index 65e7f570..0327d444 100644 --- a/dpdata/plugins/lammps.py +++ b/dpdata/plugins/lammps.py @@ -1,20 +1,26 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import dpdata.lammps.dump import dpdata.lammps.lmp from dpdata.format import Format +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType @Format.register("lmp") @Format.register("lammps/lmp") class LAMMPSLmpFormat(Format): @Format.post("shift_orig_zero") - def from_system(self, file_name, type_map=None, **kwargs): - with open(file_name) as fp: + def from_system(self, file_name: FileType, type_map=None, **kwargs): + with open_file(file_name) as fp: lines = [line.rstrip("\n") for line in fp] return dpdata.lammps.lmp.to_system_data(lines, type_map) - def to_system(self, data, file_name, frame_idx=0, **kwargs): + def to_system(self, data, file_name: FileType, frame_idx=0, **kwargs): """Dump the system in lammps data format. Parameters @@ -30,7 +36,7 @@ def to_system(self, data, file_name, frame_idx=0, **kwargs): """ assert frame_idx < len(data["coords"]) w_str = dpdata.lammps.lmp.from_system_data(data, frame_idx) - with open(file_name, "w") as fp: + with open_file(file_name, "w") as fp: fp.write(w_str) diff --git a/dpdata/plugins/n2p2.py b/dpdata/plugins/n2p2.py index b70d6e6f..28942ff5 100644 --- a/dpdata/plugins/n2p2.py +++ b/dpdata/plugins/n2p2.py @@ -1,8 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from dpdata.format import Format +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType from ..unit import EnergyConversion, ForceConversion, LengthConversion @@ -44,7 +50,7 @@ class N2P2Format(Format): For more information about the n2p2 format, please refer to https://compphysvienna.github.io/n2p2/topics/cfg_file.html """ - def from_labeled_system(self, file_name, **kwargs): + def from_labeled_system(self, file_name: FileType, **kwargs): """Read from n2p2 format. Parameters @@ -67,7 +73,7 @@ def from_labeled_system(self, file_name, **kwargs): natom0 = None natoms0 = None atom_types0 = None - with open(file_name) as file: + with open_file(file_name) as file: for line in file: line = line.strip() # Remove leading/trailing whitespace if line.lower() == "begin": @@ -155,7 +161,7 @@ def from_labeled_system(self, file_name, **kwargs): "forces": forces, } - def to_labeled_system(self, data, file_name, **kwargs): + def to_labeled_system(self, data, file_name: FileType, **kwargs): """Write n2p2 format. By default, LabeledSystem.to will fallback to System.to. @@ -193,5 +199,5 @@ def to_labeled_system(self, data, file_name, **kwargs): buff.append(f"energy {energy:15.6f}") buff.append(f"charge {0:15.6f}") buff.append("end") - with open(file_name, "w") as fp: + with open_file(file_name, "w") as fp: fp.write("\n".join(buff)) diff --git a/dpdata/plugins/orca.py b/dpdata/plugins/orca.py index 9dc32c32..7a0b806c 100644 --- a/dpdata/plugins/orca.py +++ b/dpdata/plugins/orca.py @@ -1,11 +1,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from dpdata.format import Format from dpdata.orca.output import read_orca_sp_output from dpdata.unit import EnergyConversion, ForceConversion +if TYPE_CHECKING: + from dpdata.utils import FileType + energy_convert = EnergyConversion("hartree", "eV").value() force_convert = ForceConversion("hartree/bohr", "eV/angstrom").value() @@ -18,12 +23,12 @@ class ORCASPOutFormat(Format): printed into the output file. """ - def from_labeled_system(self, file_name: str, **kwargs) -> dict: + def from_labeled_system(self, file_name: FileType, **kwargs) -> dict: """Read from ORCA single point energy output. Parameters ---------- - file_name : str + file_name : FileType file name **kwargs keyword arguments diff --git a/dpdata/plugins/psi4.py b/dpdata/plugins/psi4.py index a0cf00e4..2bbfc232 100644 --- a/dpdata/plugins/psi4.py +++ b/dpdata/plugins/psi4.py @@ -1,11 +1,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from dpdata.format import Format from dpdata.psi4.input import write_psi4_input from dpdata.psi4.output import read_psi4_output from dpdata.unit import EnergyConversion, ForceConversion +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType energy_convert = EnergyConversion("hartree", "eV").value() force_convert = ForceConversion("hartree/bohr", "eV/angstrom").value() @@ -19,12 +25,12 @@ class PSI4OutFormat(Format): printed into the output file. """ - def from_labeled_system(self, file_name: str, **kwargs) -> dict: + def from_labeled_system(self, file_name: FileType, **kwargs) -> dict: """Read from Psi4 output. Parameters ---------- - file_name : str + file_name : FileType file name **kwargs keyword arguments @@ -61,7 +67,7 @@ class PSI4InputFormat(Format): def to_system( self, data: dict, - file_name: str, + file_name: FileType, method: str, basis: str, charge: int = 0, @@ -91,7 +97,7 @@ def to_system( keyword arguments """ types = np.array(data["atom_names"])[data["atom_types"]] - with open(file_name, "w") as fout: + with open_file(file_name, "w") as fout: fout.write( write_psi4_input( types=types, diff --git a/dpdata/plugins/pwmat.py b/dpdata/plugins/pwmat.py index 80f219b6..ba3dab16 100644 --- a/dpdata/plugins/pwmat.py +++ b/dpdata/plugins/pwmat.py @@ -1,10 +1,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import dpdata.pwmat.atomconfig import dpdata.pwmat.movement from dpdata.format import Format +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType @Format.register("movement") @@ -47,12 +53,12 @@ def from_labeled_system( @Format.register("pwmat/final.config") class PwmatAtomconfigFormat(Format): @Format.post("rot_lower_triangular") - def from_system(self, file_name, **kwargs): - with open(file_name) as fp: + def from_system(self, file_name: FileType, **kwargs): + with open_file(file_name) as fp: lines = [line.rstrip("\n") for line in fp] return dpdata.pwmat.atomconfig.to_system_data(lines) - def to_system(self, data, file_name, frame_idx=0, *args, **kwargs): + def to_system(self, data, file_name: FileType, frame_idx=0, *args, **kwargs): """Dump the system in pwmat atom.config format. Parameters @@ -70,5 +76,5 @@ def to_system(self, data, file_name, frame_idx=0, *args, **kwargs): """ assert frame_idx < len(data["coords"]) w_str = dpdata.pwmat.atomconfig.from_system_data(data, frame_idx) - with open(file_name, "w") as fp: + with open_file(file_name, "w") as fp: fp.write(w_str) diff --git a/dpdata/plugins/vasp.py b/dpdata/plugins/vasp.py index d0681ceb..0160bde2 100644 --- a/dpdata/plugins/vasp.py +++ b/dpdata/plugins/vasp.py @@ -1,12 +1,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import dpdata.vasp.outcar import dpdata.vasp.poscar import dpdata.vasp.xml from dpdata.format import Format -from dpdata.utils import uniq_atom_names +from dpdata.utils import open_file, uniq_atom_names + +if TYPE_CHECKING: + from dpdata.utils import FileType @Format.register("poscar") @@ -15,14 +20,14 @@ @Format.register("vasp/contcar") class VASPPoscarFormat(Format): @Format.post("rot_lower_triangular") - def from_system(self, file_name, **kwargs): - with open(file_name) as fp: + def from_system(self, file_name: FileType, **kwargs): + with open_file(file_name) as fp: lines = [line.rstrip("\n") for line in fp] data = dpdata.vasp.poscar.to_system_data(lines) data = uniq_atom_names(data) return data - def to_system(self, data, file_name, frame_idx=0, **kwargs): + def to_system(self, data, file_name: FileType, frame_idx=0, **kwargs): """Dump the system in vasp POSCAR format. Parameters @@ -37,7 +42,7 @@ def to_system(self, data, file_name, frame_idx=0, **kwargs): other parameters """ w_str = VASPStringFormat().to_system(data, frame_idx=frame_idx) - with open(file_name, "w") as fp: + with open_file(file_name, "w") as fp: fp.write(w_str) diff --git a/dpdata/plugins/xyz.py b/dpdata/plugins/xyz.py index 322bf77c..d56a8618 100644 --- a/dpdata/plugins/xyz.py +++ b/dpdata/plugins/xyz.py @@ -1,8 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from dpdata.format import Format +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType from dpdata.xyz.quip_gap_xyz import QuipGapxyzSystems from dpdata.xyz.xyz import coord_to_xyz, xyz_to_coord @@ -16,16 +22,16 @@ class XYZFormat(Format): >>> s.to("xyz", "a.xyz") """ - def to_system(self, data, file_name, **kwargs): + def to_system(self, data, file_name: FileType, **kwargs): buff = [] types = np.array(data["atom_names"])[data["atom_types"]] for cc in data["coords"]: buff.append(coord_to_xyz(cc, types)) - with open(file_name, "w") as fp: + with open_file(file_name, "w") as fp: fp.write("\n".join(buff)) - def from_system(self, file_name, **kwargs): - with open(file_name) as fp: + def from_system(self, file_name: FileType, **kwargs): + with open_file(file_name) as fp: coords, types = xyz_to_coord(fp.read()) atom_names, atom_types, atom_numbs = np.unique( types, return_inverse=True, return_counts=True diff --git a/dpdata/psi4/output.py b/dpdata/psi4/output.py index c06eb182..c3594ffb 100644 --- a/dpdata/psi4/output.py +++ b/dpdata/psi4/output.py @@ -1,11 +1,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from dpdata.unit import LengthConversion +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType -def read_psi4_output(fn: str) -> tuple[str, np.ndarray, float, np.ndarray]: +def read_psi4_output(fn: FileType) -> tuple[str, np.ndarray, float, np.ndarray]: """Read from Psi4 output. Note that both the energy and the gradient should be printed. @@ -31,7 +37,7 @@ def read_psi4_output(fn: str) -> tuple[str, np.ndarray, float, np.ndarray]: forces = None energy = None length_unit = None - with open(fn) as f: + with open_file(fn) as f: flag = 0 for line in f: if flag in (1, 3, 4, 5, 6): diff --git a/dpdata/qe/scf.py b/dpdata/qe/scf.py index 37e5fbab..f8670860 100755 --- a/dpdata/qe/scf.py +++ b/dpdata/qe/scf.py @@ -5,6 +5,8 @@ import numpy as np +from dpdata.utils import open_file + ry2ev = 13.605693009 bohr2ang = 0.52917721067 kbar2evperang3 = 1e3 / 1.602176621e6 @@ -142,9 +144,9 @@ def get_frame(fname): path_out = fname[1] else: raise RuntimeError("invalid input") - with open(path_out) as fp: + with open_file(path_out) as fp: outlines = fp.read().split("\n") - with open(path_in) as fp: + with open_file(path_in) as fp: inlines = fp.read().split("\n") cell = get_cell(inlines) atom_names, natoms, types, coords = get_coords(inlines, cell) diff --git a/dpdata/qe/traj.py b/dpdata/qe/traj.py index 1fbf0f71..b4be303a 100644 --- a/dpdata/qe/traj.py +++ b/dpdata/qe/traj.py @@ -2,9 +2,15 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING import numpy as np +from dpdata.utils import open_file + +if TYPE_CHECKING: + from dpdata.utils import FileType + from ..unit import ( EnergyConversion, ForceConversion, @@ -87,8 +93,8 @@ def load_atom_types(lines, natoms, atom_names): return np.array(ret, dtype=int) -def load_param_file(fname): - with open(fname) as fp: +def load_param_file(fname: FileType): + with open_file(fname) as fp: lines = fp.read().split("\n") natoms = int(load_key(lines, "nat")) ntypes = int(load_key(lines, "ntyp")) @@ -127,11 +133,11 @@ def _load_pos_block(fp, natoms): return blk, ss -def load_data(fname, natoms, begin=0, step=1, convert=1.0): +def load_data(fname: FileType, natoms, begin=0, step=1, convert=1.0): coords = [] steps = [] cc = 0 - with open(fname) as fp: + with open_file(fname) as fp: while True: blk, ss = _load_pos_block(fp, natoms) if blk is None: @@ -147,7 +153,7 @@ def load_data(fname, natoms, begin=0, step=1, convert=1.0): # def load_pos(fname, natoms) : # coords = [] -# with open(fname) as fp: +# with open_file(fname) as fp: # while True: # blk = _load_pos_block(fp, natoms) # # print(blk) @@ -164,7 +170,7 @@ def load_energy(fname, begin=0, step=1): steps = [] for ii in data[begin::step, 0]: steps.append("%d" % ii) - with open(fname) as fp: + with open_file(fname) as fp: while True: line = fp.readline() if not line: @@ -178,7 +184,7 @@ def load_energy(fname, begin=0, step=1): # def load_force(fname, natoms) : # coords = [] -# with open(fname) as fp: +# with open_file(fname) as fp: # while True: # blk = _load_pos_block(fp, natoms) # # print(blk) diff --git a/dpdata/utils.py b/dpdata/utils.py index e008120e..8942bd54 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -1,7 +1,10 @@ from __future__ import annotations +import io +import os import sys -from typing import overload +from contextlib import contextmanager +from typing import TYPE_CHECKING, Generator, overload if sys.version_info >= (3, 8): from typing import Literal @@ -129,3 +132,44 @@ def uniq_atom_names(data): def utf8len(s: str) -> int: """Return the byte length of a string.""" return len(s.encode("utf-8")) + + +if TYPE_CHECKING: + FileType = io.IOBase | str | os.PathLike + + +@contextmanager +def open_file(file: FileType, *args, **kwargs) -> Generator[io.IOBase, None, None]: + """A context manager that yields a file object. + + Parameters + ---------- + file : file object or file path + A file object or a file path. + + Yields + ------ + file : io.IOBase + A file object. + *args + parameters to open + **kwargs + other parameters + + Raises + ------ + ValueError + If file is not a file object or a file + + Examples + -------- + >>> with open_file("file.txt") as file: + ... print(file.read()) + """ + if isinstance(file, io.IOBase): + yield file + elif isinstance(file, (str, os.PathLike)): + with open(file, *args, **kwargs) as f: + yield f + else: + raise ValueError("file must be a file object or a file path.") diff --git a/tests/test_read_file.py b/tests/test_read_file.py new file mode 100644 index 00000000..e7dca54e --- /dev/null +++ b/tests/test_read_file.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import io +import unittest +from pathlib import Path + +from dpdata.utils import open_file + + +class TestReadFile(unittest.TestCase): + def test_open_file_from_string_io(self): + string_io = io.StringIO("Hello, world!") + with open_file(string_io) as file: + self.assertEqual(file.read(), "Hello, world!") + + def test_open_file_from_file_str(self): + with open_file("/dev/null") as file: + self.assertEqual(file.read(), Path("/dev/null").read_text()) + + def test_open_file_from_file_path(self): + with open_file(Path("/dev/null")) as file: + self.assertEqual(file.read(), Path("/dev/null").read_text())