From 3201dc1e10c6ad85eb7e420ca154ef995e2bfb2b Mon Sep 17 00:00:00 2001
From: Tony Lee
Date: Tue, 17 Sep 2024 21:13:34 -0700
Subject: [PATCH] more accurate computation

---
 scripts/estimate_cost.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/scripts/estimate_cost.py b/scripts/estimate_cost.py
index 8442c0c197..7163510087 100644
--- a/scripts/estimate_cost.py
+++ b/scripts/estimate_cost.py
@@ -57,13 +57,17 @@ def aggregate(self) -> Dict[str, ModelCost]:
             with open(run_spec_path) as f:
                 run_spec = json.load(f)
                 model: str = run_spec["adapter_spec"]["model"]
+                cost: ModelCost = models_to_costs[model]
 
             metrics_path: str = os.path.join(run_path, "stats.json")
             with open(metrics_path) as f:
                 metrics = json.load(f)
 
+                num_prompt_tokens: int = -1
+                num_completion_tokens: int = -1
+                num_instances: int = -1
+
                 for metric in metrics:
-                    cost: ModelCost = models_to_costs[model]
                     metric_name: str = metric["name"]["name"]
 
                     # Don't count perturbations
@@ -71,11 +75,16 @@ def aggregate(self) -> Dict[str, ModelCost]:
                         continue
 
                     if metric_name == "num_prompt_tokens":
-                        cost.add_prompt_tokens(metric["sum"])
+                        num_prompt_tokens = metric["sum"]
                     elif metric_name == "num_completion_tokens":
-                        cost.add_num_completion_tokens(metric["sum"])
+                        num_completion_tokens = metric["sum"]
                     elif metric_name == "num_instances":
-                        cost.add_num_instances(metric["sum"])
+                        num_instances = metric["sum"]
+
+                assert num_prompt_tokens >= 0 and num_completion_tokens >= 0 and num_instances >= 0
+                cost.add_prompt_tokens(num_prompt_tokens * num_instances)
+                cost.add_num_completion_tokens(num_completion_tokens * num_instances)
+                cost.add_num_instances(num_instances)
 
         return models_to_costs
 
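For readers skimming the diff: the hunks above change aggregate() from adding each metric's "sum" to the running totals inline to buffering the three per-run values, asserting that all three were found, and committing them once per run with token counts scaled by the run's instance count. The following standalone sketch mirrors that post-patch flow; it is an illustration, not the actual scripts/estimate_cost.py. The ModelCost stand-in, the aggregate_run helper, and the perturbation membership test are assumptions made for this sketch, since only fragments of them appear in the diff.

# Standalone sketch of the aggregation flow after this patch (not the
# real scripts/estimate_cost.py). The ModelCost stand-in and the
# perturbation check are assumptions; only the three add_* calls and
# the metric names are taken from the diff.
from dataclasses import dataclass
from typing import List


@dataclass
class ModelCost:  # hypothetical stand-in for the real ModelCost
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_num_instances: int = 0

    def add_prompt_tokens(self, n: int) -> None:
        self.total_prompt_tokens += n

    def add_num_completion_tokens(self, n: int) -> None:
        self.total_completion_tokens += n

    def add_num_instances(self, n: int) -> None:
        self.total_num_instances += n


def aggregate_run(cost: ModelCost, metrics: List[dict]) -> None:
    """Fold one run's stats.json records into `cost`, as the patch does."""
    # Buffer the per-run values instead of mutating `cost` inside the loop.
    num_prompt_tokens = -1
    num_completion_tokens = -1
    num_instances = -1
    for metric in metrics:
        metric_name = metric["name"]["name"]

        # Don't count perturbations (the exact condition sits between the
        # two hunks, so this membership test is an assumption).
        if "perturbation" in metric["name"]:
            continue

        if metric_name == "num_prompt_tokens":
            num_prompt_tokens = metric["sum"]
        elif metric_name == "num_completion_tokens":
            num_completion_tokens = metric["sum"]
        elif metric_name == "num_instances":
            num_instances = metric["sum"]

    # Fail fast if any of the three metrics is missing from the run.
    assert num_prompt_tokens >= 0 and num_completion_tokens >= 0 and num_instances >= 0
    # Token counts are treated as per-instance figures and scaled by the
    # run's instance count before being added to the model's totals.
    cost.add_prompt_tokens(num_prompt_tokens * num_instances)
    cost.add_num_completion_tokens(num_completion_tokens * num_instances)
    cost.add_num_instances(num_instances)

Compared with the pre-patch code, which called cost.add_*(metric["sum"]) as each record was seen, this version also stops re-binding cost on every loop iteration (the removed line inside the for loop) and surfaces incomplete stats files via the assert instead of silently undercounting.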