Skip to content

Commit

Permalink
Combine obs & ds
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen authored and yngve-sk committed May 28, 2024
1 parent 0366947 commit a5273b5
Show file tree
Hide file tree
Showing 73 changed files with 3,080 additions and 1,515 deletions.
97 changes: 15 additions & 82 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def _save_param_ensemble_array_to_disk(
ensemble, param_group, realization, param_ensemble_array[:, i]
)

ensemble.unify_parameters()


def _load_param_ensemble_array(
ensemble: Ensemble,
Expand All @@ -132,83 +134,6 @@ def _load_param_ensemble_array(
return config_node.load_parameters(ensemble, param_group, iens_active_index)


def _get_observations_and_responses(
    ensemble: Ensemble,
    selected_observations: Iterable[str],
    iens_active_index: npt.NDArray[np.int_],
) -> Tuple[
    npt.NDArray[np.float_],
    npt.NDArray[np.float_],
    npt.NDArray[np.float_],
    npt.NDArray[np.str_],
    npt.NDArray[np.str_],
]:
    """Fetches and aligns selected observations with their corresponding simulated responses from an ensemble.

    Returns a tuple of
    (response matrix with one row per observation point and one column per
    active realization, observation values, observation errors, observation
    keys, per-point index labels), each concatenated over all selected
    observations in iteration order.
    """
    filtered_responses = []
    observation_keys = []
    observation_values = []
    observation_errors = []
    indexes = []
    observations = ensemble.experiment.observations
    for obs in selected_observations:
        observation = observations[obs]
        # Each observation dataset names the response group it attaches to.
        group = observation.attrs["response"]
        all_responses = ensemble.load_responses(group, tuple(iens_active_index))
        if "time" in observation.coords:
            # Align the response time axis with the observation times,
            # snapping each to the nearest simulated time within one second.
            all_responses = all_responses.reindex(
                time=observation.time,
                method="nearest",
                tolerance="1s",  # type: ignore
            )
        try:
            # Left join keeps only coordinates present in the observation;
            # a coordinate mismatch surfaces as KeyError from xarray.
            observations_and_responses = observation.merge(all_responses, join="left")
        except KeyError as e:
            raise ErtAnalysisError(
                f"Mismatched index for: "
                f"Observation: {obs} attached to response: {group}"
            ) from e

        # Repeat the key once per observation point so the concatenated
        # arrays below stay row-aligned.
        observation_keys.append([obs] * observations_and_responses["observations"].size)

        if group == "summary":
            # Summary observations are labeled by their timestamp
            # (second resolution, ISO format).
            indexes.append(
                [
                    np.datetime_as_string(e, unit="s")
                    for e in observations_and_responses["time"].data
                ]
            )
        else:
            # Other responses are labeled "report_step, index".
            # NOTE(review): the report_step list is repeated len(index) times
            # before zipping — presumably to broadcast a single report_step
            # across all indices (zip truncates to the shorter input); confirm
            # this is intended when report_step has more than one entry.
            indexes.append(
                [
                    f"{e[0]}, {e[1]}"
                    for e in zip(
                        list(observations_and_responses["report_step"].data)
                        * len(observations_and_responses["index"].data),
                        observations_and_responses["index"].data,
                    )
                ]
            )

        observation_values.append(
            observations_and_responses["observations"].data.ravel()
        )
        observation_errors.append(observations_and_responses["std"].data.ravel())

        # Flatten every dimension except "realization" so each row is one
        # observation point and each column one realization.
        filtered_responses.append(
            observations_and_responses["values"]
            .transpose(..., "realization")
            .values.reshape((-1, len(observations_and_responses.realization)))
        )
    # Drop memoized response datasets now that all groups have been read.
    ensemble.load_responses.cache_clear()
    return (
        np.concatenate(filtered_responses),
        np.concatenate(observation_values),
        np.concatenate(observation_errors),
        np.concatenate(observation_keys),
        np.concatenate(indexes),
    )


def _expand_wildcards(
input_list: npt.NDArray[np.str_], patterns: List[str]
) -> List[str]:
Expand All @@ -234,11 +159,19 @@ def _load_observations_and_responses(
List[ObservationAndResponseSnapshot],
],
]:
S, observations, errors, obs_keys, indexes = _get_observations_and_responses(
ensemble,
selected_observations,
iens_active_index,
)
try:
observed_responses_data = ensemble.get_measured_data(
[*selected_observations], iens_active_index
)
except KeyError as e:
# Exit early if some observations are pointing to non-existing responses
raise ErtAnalysisError("No active observations for update step") from e

S = observed_responses_data.vec_of_realization_values()
observations = observed_responses_data.vec_of_obs_values()
errors = observed_responses_data.vec_of_errors()
obs_keys = observed_responses_data.vec_of_obs_names()
indexes = observed_responses_data.vec_of_obs_indexes()

# Inflating measurement errors by a factor sqrt(global_std_scaling) as shown
# in for example evensen2018 - Analysis of iterative ensemble smoothers for
Expand Down
1 change: 1 addition & 0 deletions src/ert/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _read_parameters(
error_msg += str(err)
result = LoadResult(LoadStatus.LOAD_FAILURE, error_msg)
logger.warning(f"Failed to load: {run_arg.iens}", exc_info=err)

return result


Expand Down
2 changes: 1 addition & 1 deletion src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run_cli(args: Namespace, _: Any = None) -> None:
logger.info("Config contains forward model step %s", fm_step.name)

ert = EnKFMain(ert_config)
if not ert_config.observations and args.mode not in [
if not ert_config.observation_keys and args.mode not in [
ENSEMBLE_EXPERIMENT_MODE,
TEST_RUN_MODE,
WORKFLOW_MODE,
Expand Down
6 changes: 3 additions & 3 deletions src/ert/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from ert.config.gen_data_config import GenDataConfig
from ert.config.summary_config import SummaryConfig

from .analysis_config import AnalysisConfig
from .analysis_module import AnalysisModule, ESSettings, IESSettings
from .ensemble_config import EnsembleConfig
Expand Down Expand Up @@ -32,6 +29,8 @@
queue_string_options,
)
from .response_config import ResponseConfig
from .response_properties import ResponseTypes
from .summary_config import SummaryConfig
from .summary_observation import SummaryObservation
from .surface_config import SurfaceConfig
from .workflow import Workflow
Expand Down Expand Up @@ -64,6 +63,7 @@
"QueueConfig",
"QueueSystem",
"ResponseConfig",
"ResponseTypes",
"SummaryConfig",
"SummaryObservation",
"SurfaceConfig",
Expand Down
6 changes: 0 additions & 6 deletions src/ert/config/enkf_observation_implementation_type.py

This file was deleted.

5 changes: 3 additions & 2 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
parse,
)
from .queue_config import QueueConfig
from .response_properties import ResponseTypes
from .workflow import Workflow
from .workflow_job import ErtScriptLoadFailure, WorkflowJob

Expand Down Expand Up @@ -111,8 +112,8 @@ def __post_init__(self) -> None:

self.observations = self._create_observations_and_find_summary_keys()

if "summary" in self.observations:
summary_ds = self.observations["summary"]
if ResponseTypes.summary in self.observations:
summary_ds = self.observations[ResponseTypes.summary]
names_in_ds = summary_ds["name"].data.tolist()
self.summary_keys.extend(names_in_ds)

Expand Down
12 changes: 6 additions & 6 deletions src/ert/config/ext_param_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ def write_to_runpath(
Path.mkdir(file_path.parent, exist_ok=True, parents=True)

data: MutableDataType = {}
for da in ensemble.load_parameters(self.name, real_nr)["values"]:
assert isinstance(da, xr.DataArray)
name = str(da.names.values)
df = ensemble.load_parameters(self.name, real_nr)["values"].to_dataframe()
as_dict = df.to_dict()["values"]
for k, v in as_dict.items():
try:
outer, inner = name.split("\0")
outer, inner = k.split("\0")

if outer not in data:
data[outer] = {}
data[outer][inner] = float(da) # type: ignore
data[outer][inner] = float(v) # type: ignore
except ValueError:
data[name] = float(da)
data[k] = float(v)

with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/gen_data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def parse_observation(args: ObsArgs) -> Dict[str, ObsVector]:
try:
return {
obs_key: ObsVector(
ResponseTypes.GEN_DATA,
ResponseTypes.gen_data,
obs_key,
config_node.name,
{
Expand Down
7 changes: 6 additions & 1 deletion src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,12 @@ def write_to_runpath(
f" is of size {len(self.transform_functions)}, expected {array.size}"
)

data = dict(zip(array["names"].values.tolist(), array.values.tolist()))
data = dict(
zip(
array["names"].values.tolist(),
array.values.flatten().tolist(),
)
)

log10_data = {
tf.name: math.log(data[tf.name], 10)
Expand Down
137 changes: 137 additions & 0 deletions src/ert/config/obs_commons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from datetime import datetime, timedelta
from typing import List, Tuple

import numpy as np
import numpy.typing as npt

from ert.config.parsing import ConfigWarning
from ert.config.parsing.observations_parser import (
DateValues,
ErrorValues,
ObservationConfigError,
SummaryValues,
)

DEFAULT_TIME_DELTA = timedelta(seconds=30)


def get_time(date_dict: DateValues, start_time: datetime) -> Tuple[datetime, str]:
if date_dict.date is not None:
date_str = date_dict.date
try:
return datetime.fromisoformat(date_str), f"DATE={date_str}"
except ValueError:
try:
date = datetime.strptime(date_str, "%d/%m/%Y")
ConfigWarning.ert_context_warn(
f"Deprecated time format {date_str}."
" Please use ISO date format YYYY-MM-DD",
date_str,
)
return date, f"DATE={date_str}"
except ValueError as err:
raise ObservationConfigError.with_context(
f"Unsupported date format {date_str}."
" Please use ISO date format",
date_str,
) from err

if date_dict.days is not None:
days = date_dict.days
return start_time + timedelta(days=days), f"DAYS={days}"
if date_dict.hours is not None:
hours = date_dict.hours
return start_time + timedelta(hours=hours), f"HOURS={hours}"
raise ValueError("Missing time specifier")


def _find_nearest(
time_map: List[datetime],
time: datetime,
threshold: timedelta = DEFAULT_TIME_DELTA,
) -> int:
nearest_index = -1
nearest_diff = None
for i, t in enumerate(time_map):
diff = abs(time - t)
if diff < threshold and (nearest_diff is None or nearest_diff > diff):
nearest_diff = diff
nearest_index = i
if nearest_diff is None:
raise IndexError(f"{time} is not in the time map")
return nearest_index


def get_restart(
    date_dict: DateValues,
    obs_name: str,
    time_map: List[datetime],
    has_refcase: bool,
) -> int:
    """Map an observation's time specification to a report-step index.

    An explicit RESTART number is returned directly; otherwise the
    DATE/DAYS/HOURS specification is resolved to a datetime and matched
    (within the default tolerance) against *time_map*.

    Raises ObservationConfigError when the time map is missing, the date
    cannot be parsed, or no time-map entry is close enough.
    """
    if date_dict.restart is not None:
        return date_dict.restart
    if not time_map:
        raise ObservationConfigError.with_context(
            f"Missing REFCASE or TIME_MAP for observations: {obs_name}",
            obs_name,
        )

    try:
        time, date_str = get_time(date_dict, time_map[0])
    except ObservationConfigError:
        # Already carries user-facing context; do not re-wrap.
        raise
    except ValueError as err:
        raise ObservationConfigError.with_context(
            f"Failed to parse date of {obs_name}", obs_name
        ) from err

    try:
        return _find_nearest(time_map, time)
    except IndexError as err:
        # Fix: the two hint branches previously had mismatched parentheses
        # (the REFCASE branch ended with a stray ")" and lacked the opening
        # " (", while the TIME_MAP branch never closed its "(").  Both hints
        # now read as a single, properly parenthesized sentence.
        raise ObservationConfigError.with_context(
            f"Could not find {time} ({date_str}) in "
            f"the time map for observations {obs_name}"
            + (
                " (The time map is set from the REFCASE keyword. Either "
                "the REFCASE has an incorrect/missing date, or the observation "
                "is given an incorrect date.)"
                if has_refcase
                else " (The time map is set from the TIME_MAP "
                "keyword. Either the time map file has an "
                "incorrect/missing date, or the observation is given an "
                "incorrect date.)"
            ),
            obs_name,
        ) from err


def make_value_and_std_dev(
    observation_dict: SummaryValues,
) -> Tuple[float, float]:
    """Return the observed value together with its derived standard deviation.

    The standard deviation is computed from the observation's ERROR_MODE
    settings via handle_error_mode.
    """
    obs_value = observation_dict.value
    std_dev = float(
        handle_error_mode(
            np.array(obs_value),
            observation_dict,
        )
    )
    return (obs_value, std_dev)


def handle_error_mode(
    values: "npt.ArrayLike",
    error_dict: ErrorValues,
) -> "npt.NDArray[np.double]":
    """Compute per-value observation errors for the configured ERROR_MODE.

    ABS    -> the constant ``error`` for every value.
    REL    -> ``error`` scaled by each value's magnitude.
    RELMIN -> as REL, but never below ``error_min``.

    Raises ObservationConfigError for an unrecognized mode.
    """
    arr = np.asarray(values)
    mode = error_dict.error_mode
    err = error_dict.error
    floor = error_dict.error_min

    if mode == "ABS":
        return np.full(arr.shape, err)
    if mode == "REL":
        return err * np.abs(arr)
    if mode == "RELMIN":
        return np.maximum(err * np.abs(arr), np.full(arr.shape, floor))
    raise ObservationConfigError(f"Unknown error mode {mode}", mode)
Loading

0 comments on commit a5273b5

Please sign in to comment.