Skip to content

Commit

Permalink
feat: file object passed to open (#709)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a new utility function `open_file` for improved file
handling across various modules.
- Enhanced type annotations for multiple functions to specify
`FileType`, improving code clarity and type safety.
  
- **Bug Fixes**
- Improved file handling robustness by replacing the built-in `open`
function with the custom `open_file` function in several modules.

- **Tests**
- Added unit tests for the new `open_file` utility function to ensure
reliable functionality.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] committed Sep 3, 2024
1 parent 5df6acd commit 6bf41e3
Show file tree
Hide file tree
Showing 29 changed files with 289 additions and 90 deletions.
10 changes: 6 additions & 4 deletions dpdata/abacus/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import numpy as np

from dpdata.utils import open_file

from .scf import (
bohr2ang,
get_cell,
Expand Down Expand Up @@ -156,12 +158,12 @@ def get_frame(fname):
path_in = os.path.join(fname, "INPUT")
else:
raise RuntimeError("invalid input")
with open(path_in) as fp:
with open_file(path_in) as fp:
inlines = fp.read().split("\n")
geometry_path_in = get_geometry_in(fname, inlines) # base dir of STRU
path_out = get_path_out(fname, inlines)

with open(geometry_path_in) as fp:
with open_file(geometry_path_in) as fp:
geometry_inlines = fp.read().split("\n")
celldm, cell = get_cell(geometry_inlines)
atom_names, natoms, types, coords = get_coords(
Expand All @@ -172,11 +174,11 @@ def get_frame(fname):
# ndump = int(os.popen("ls -l %s | grep 'md_pos_' | wc -l" %path_out).readlines()[0])
# number of dumped geometry files
# coords = get_coords_from_cif(ndump, dump_freq, atom_names, natoms, types, path_out, cell)
with open(os.path.join(path_out, "MD_dump")) as fp:
with open_file(os.path.join(path_out, "MD_dump")) as fp:
dumplines = fp.read().split("\n")
coords, cells, force, stress = get_coords_from_dump(dumplines, natoms)
ndump = np.shape(coords)[0]
with open(os.path.join(path_out, "running_md.log")) as fp:
with open_file(os.path.join(path_out, "running_md.log")) as fp:
outlines = fp.read().split("\n")
energy = get_energy(outlines, ndump, dump_freq)

Expand Down
8 changes: 5 additions & 3 deletions dpdata/abacus/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np

from dpdata.utils import open_file

from .scf import (
bohr2ang,
collect_force,
Expand Down Expand Up @@ -174,10 +176,10 @@ def get_frame(fname):
path_in = os.path.join(fname, "INPUT")
else:
raise RuntimeError("invalid input")
with open(path_in) as fp:
with open_file(path_in) as fp:
inlines = fp.read().split("\n")
geometry_path_in = get_geometry_in(fname, inlines) # base dir of STRU
with open(geometry_path_in) as fp:
with open_file(geometry_path_in) as fp:
geometry_inlines = fp.read().split("\n")
celldm, cell = get_cell(geometry_inlines)
atom_names, natoms, types, coord_tmp = get_coords(
Expand All @@ -186,7 +188,7 @@ def get_frame(fname):

logf = get_log_file(fname, inlines)
assert os.path.isfile(logf), f"Error: can not find {logf}"
with open(logf) as f1:
with open_file(logf) as f1:
lines = f1.readlines()

atomnumber = 0
Expand Down
10 changes: 6 additions & 4 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np

from dpdata.utils import open_file

from ..unit import EnergyConversion, LengthConversion, PressureConversion

bohr2ang = LengthConversion("bohr", "angstrom").value()
Expand Down Expand Up @@ -253,17 +255,17 @@ def get_frame(fname):
if not CheckFile(path_in):
return data

with open(path_in) as fp:
with open_file(path_in) as fp:
inlines = fp.read().split("\n")

geometry_path_in = get_geometry_in(fname, inlines)
path_out = get_path_out(fname, inlines)
if not (CheckFile(geometry_path_in) and CheckFile(path_out)):
return data

with open(geometry_path_in) as fp:
with open_file(geometry_path_in) as fp:
geometry_inlines = fp.read().split("\n")
with open(path_out) as fp:
with open_file(path_out) as fp:
outlines = fp.read().split("\n")

celldm, cell = get_cell(geometry_inlines)
Expand Down Expand Up @@ -338,7 +340,7 @@ def get_nele_from_stru(geometry_inlines):

def get_frame_from_stru(fname):
assert isinstance(fname, str)
with open(fname) as fp:
with open_file(fname) as fp:
geometry_inlines = fp.read().split("\n")
nele = get_nele_from_stru(geometry_inlines)
inlines = ["ntype %d" % nele]
Expand Down
7 changes: 4 additions & 3 deletions dpdata/amber/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dpdata.amber.mask import pick_by_amber_mask
from dpdata.unit import EnergyConversion
from dpdata.utils import open_file

from ..periodic_table import ELEMENTS

Expand Down Expand Up @@ -51,7 +52,7 @@ def read_amber_traj(
flag_atom_numb = False
amber_types = []
atomic_number = []
with open(parm7_file) as f:
with open_file(parm7_file) as f:
for line in f:
if line.startswith("%FLAG"):
flag_atom_type = line.startswith("%FLAG AMBER_ATOM_TYPE")
Expand Down Expand Up @@ -101,14 +102,14 @@ def read_amber_traj(
# load energy from mden_file or mdout_file
energies = []
if mden_file is not None and os.path.isfile(mden_file):
with open(mden_file) as f:
with open_file(mden_file) as f:
for line in f:
if line.startswith("L6"):
s = line.split()
if s[2] != "E_pot":
energies.append(float(s[2]))
elif mdout_file is not None and os.path.isfile(mdout_file):
with open(mdout_file) as f:
with open_file(mdout_file) as f:
for line in f:
if "EPtot" in line:
s = line.split()
Expand Down
14 changes: 10 additions & 4 deletions dpdata/amber/sqm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from dpdata.periodic_table import ELEMENTS
from dpdata.unit import EnergyConversion
from dpdata.utils import open_file

if TYPE_CHECKING:
from dpdata.utils import FileType

kcal2ev = EnergyConversion("kcal_mol", "eV").value()

Expand All @@ -14,15 +20,15 @@
READ_FORCES = 7


def parse_sqm_out(fname):
def parse_sqm_out(fname: FileType):
"""Read atom symbols, charges and coordinates from ambertools sqm.out file."""
atom_symbols = []
coords = []
charges = []
forces = []
energies = []

with open(fname) as f:
with open_file(fname) as f:
flag = START
for line in f:
if line.startswith(" Total SCF energy"):
Expand Down Expand Up @@ -81,7 +87,7 @@ def parse_sqm_out(fname):
return data


def make_sqm_in(data, fname=None, frame_idx=0, **kwargs):
def make_sqm_in(data, fname: FileType | None = None, frame_idx=0, **kwargs):
symbols = [data["atom_names"][ii] for ii in data["atom_types"]]
atomic_numbers = [ELEMENTS.index(ss) + 1 for ss in symbols]
charge = kwargs.get("charge", 0)
Expand Down Expand Up @@ -109,6 +115,6 @@ def make_sqm_in(data, fname=None, frame_idx=0, **kwargs):
f"{data['coords'][frame_idx][ii, 2]:.6f}",
)
if fname is not None:
with open(fname, "w") as fp:
with open_file(fname, "w") as fp:
fp.write(ret)
return ret
3 changes: 2 additions & 1 deletion dpdata/deepmd/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

import dpdata
from dpdata.utils import open_file

from .raw import load_type

Expand Down Expand Up @@ -172,7 +173,7 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
except OSError:
pass
if data.get("nopbc", False):
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
with open_file(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
pass
# allow custom dtypes
labels = "energies" in data
Expand Down
5 changes: 3 additions & 2 deletions dpdata/deepmd/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

import dpdata
from dpdata.utils import open_file


def load_type(folder, type_map=None):
Expand All @@ -17,7 +18,7 @@ def load_type(folder, type_map=None):
data["atom_names"] = []
# if find type_map.raw, use it
if os.path.isfile(os.path.join(folder, "type_map.raw")):
with open(os.path.join(folder, "type_map.raw")) as fp:
with open_file(os.path.join(folder, "type_map.raw")) as fp:
my_type_map = fp.read().split()
# else try to use arg type_map
elif type_map is not None:
Expand Down Expand Up @@ -140,7 +141,7 @@ def dump(folder, data):
except OSError:
pass
if data.get("nopbc", False):
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
with open_file(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
pass
# allow custom dtypes
labels = "energies" in data
Expand Down
15 changes: 12 additions & 3 deletions dpdata/dftbplus/output.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from dpdata.utils import open_file

if TYPE_CHECKING:
from dpdata.utils import FileType


def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.ndarray]:
def read_dftb_plus(
fn_1: FileType, fn_2: FileType
) -> tuple[str, np.ndarray, float, np.ndarray]:
"""Read from DFTB+ input and output.
Parameters
Expand All @@ -29,7 +38,7 @@ def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.nda
symbols = None
forces = None
energy = None
with open(fn_1) as f:
with open_file(fn_1) as f:
flag = 0
for line in f:
if flag == 1:
Expand All @@ -49,7 +58,7 @@ def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.nda
flag += 1
if flag == 7:
flag = 0
with open(fn_2) as f:
with open_file(fn_2) as f:
flag = 0
for line in f:
if line.startswith("Total Forces"):
Expand Down
11 changes: 9 additions & 2 deletions dpdata/gaussian/log.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from dpdata.utils import open_file

if TYPE_CHECKING:
from dpdata.utils import FileType

from ..periodic_table import ELEMENTS
from ..unit import EnergyConversion, ForceConversion, LengthConversion

Expand All @@ -12,7 +19,7 @@
symbols = ["X"] + ELEMENTS


def to_system_data(file_name, md=False):
def to_system_data(file_name: FileType, md=False):
"""Read Gaussian log file.
Parameters
Expand Down Expand Up @@ -43,7 +50,7 @@ def to_system_data(file_name, md=False):
nopbc = True
coords = None

with open(file_name) as fp:
with open_file(file_name) as fp:
for line in fp:
if line.startswith(" SCF Done"):
# energies
Expand Down
10 changes: 8 additions & 2 deletions dpdata/gromacs/gro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING

import numpy as np

from dpdata.utils import open_file

if TYPE_CHECKING:
from dpdata.utils import FileType

from ..unit import LengthConversion

nm2ang = LengthConversion("nm", "angstrom").value()
Expand Down Expand Up @@ -48,9 +54,9 @@ def _get_cell(line):
return cell


def file_to_system_data(fname, format_atom_name=True, **kwargs):
def file_to_system_data(fname: FileType, format_atom_name=True, **kwargs):
system = {"coords": [], "cells": []}
with open(fname) as fp:
with open_file(fname) as fp:
frame = 0
while True:
flag = fp.readline()
Expand Down
10 changes: 8 additions & 2 deletions dpdata/lammps/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@

import os
import sys
from typing import TYPE_CHECKING

import numpy as np

from dpdata.utils import open_file

if TYPE_CHECKING:
from dpdata.utils import FileType

lib_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(lib_path)
import warnings
Expand Down Expand Up @@ -169,11 +175,11 @@ def box2dumpbox(orig, box):
return bounds, tilt


def load_file(fname, begin=0, step=1):
def load_file(fname: FileType, begin=0, step=1):
lines = []
buff = []
cc = -1
with open(fname) as fp:
with open_file(fname) as fp:
while True:
line = fp.readline().rstrip("\n")
if not line:
Expand Down
Loading

0 comments on commit 6bf41e3

Please sign in to comment.