Source code for iccas.queries

from itertools import product
from typing import Dict, Iterable, List, Sequence, Set, Union

import pandas as pd

import iccas as ic
from iccas.charts.common import resample_if_needed
from iccas.types import PandasObj

# One-letter sex code -> column-name prefix ("T" selects totals, which carry
# no prefix). Insertion order matters: it is the order used when all prefixes
# are requested.
PREFIX_MAP = dict(T="", M="male_", F="female_")
PREFIX_CODES = set(PREFIX_MAP.keys())

# Field names that can follow a prefix to form a column name.
FIELDS = [
    "cases",
    "cases_percentage",
    "deaths",
    "deaths_percentage",
    "fatality_rate",
]
FIELD_SET = set(FIELDS)

# Age-group labels: nine ten-year bands ("0-9" ... "80-89"), then the
# open-ended ">=90" group and the "unknown" group.
AGE_GROUPS = tuple(
    [f"{lower}-{lower + 9}" for lower in range(0, 90, 10)] + [">=90", "unknown"]
)


def product_join(*string_iterables, sep: str = "") -> Iterable[str]:
    """
    Joins every combination of items taken from the given string iterables.

    Args:
        *string_iterables: iterables of strings; one item is picked from each
            of them to form a combination (cartesian product, in
            ``itertools.product`` order)
        sep: separator placed between the items of each combination

    Returns:
        a lazy iterator over the joined strings
    """
    combinations = product(*string_iterables)
    return map(sep.join, combinations)
def _check_column_name_part(
    values: Sequence[str], allowed_values: Set[str], kind: str
):
    """
    Validates one part (prefixes or fields) of a column-name selection.

    Args:
        values: the selected values to validate
        allowed_values: the set of values that may be selected
        kind: either "prefixes" or "fields"; only used to word error messages

    Raises:
        ValueError: if ``values`` is empty, contains duplicates or contains
            values that are not in ``allowed_values``
    """
    # Internal sanity check: only this module's wrappers call this helper.
    assert kind in {"prefixes", "fields"}
    if not values:
        # Name the kind being validated: the hint used to say "prefixes"
        # even when the empty selection was a field list.
        raise ValueError(f'empty {kind} not allowed; use "*" to select all {kind}')
    if len(values) != len(set(values)):
        raise ValueError(f"invalid {kind}: duplicates not allowed")
    invalid_values = set(values) - allowed_values
    if invalid_values:
        raise ValueError(f"invalid {kind}: {invalid_values}")


def _check_prefix_codes(prefixes: Sequence[str]):
    """Validates a selection of prefix codes against ``PREFIX_CODES``."""
    _check_column_name_part(prefixes, PREFIX_CODES, "prefixes")


def _check_fields(fields: Sequence[str]):
    """Validates a selection of field names against ``FIELD_SET``."""
    _check_column_name_part(fields, FIELD_SET, "fields")
def cols(prefixes: str, fields: Union[str, Sequence[str]] = "*") -> List[str]:
    """
    Builds a list of column names by combining sex prefixes with field names.

    Args:
        prefixes: string made of one or more of the following characters
            (case-insensitive):

            - 'm' for males
            - 'f' for females
            - 't' for totals (no prefix)

            or the single character '*' to select all of them
        fields: either '*' (all fields), a whitespace-separated string of
            field names, or a sequence of field names. Valid names are
            'cases', 'cases_percentage', 'deaths', 'deaths_percentage' and
            'fatality_rate'

    Returns:
        a list of column names; all fields of the first prefix come first,
        then all fields of the second prefix, and so on

    Raises:
        ValueError: if prefixes or fields are empty, duplicated or unknown
    """
    if isinstance(fields, str) and fields == "*":
        selected_fields = list(FIELDS)
    else:
        selected_fields = fields.split() if isinstance(fields, str) else fields
        _check_fields(selected_fields)

    if prefixes == "*":
        selected_prefixes = list(PREFIX_MAP.values())
    else:
        codes = prefixes.upper()
        _check_prefix_codes(codes)
        selected_prefixes = [PREFIX_MAP[code] for code in codes]

    return list(product_join(selected_prefixes, selected_fields))
def only_counts(data: pd.DataFrame) -> pd.DataFrame:
    """
    Keeps only the cases and deaths columns (totals and sex-specific ones),
    discarding every other column, since those are computable from these.
    """
    count_columns = [
        f"{prefix}{variable}"
        for variable in ("cases", "deaths")
        for prefix in ("", "female_", "male_")
    ]
    return data[count_columns]
def only_cases(data: pd.DataFrame) -> pd.DataFrame:
    """
    Selects the columns 'cases', 'female_cases' and 'male_cases', in this
    order.
    """
    case_columns = [f"{prefix}cases" for prefix in ("", "female_", "male_")]
    return data[case_columns]
def only_deaths(data: pd.DataFrame) -> pd.DataFrame:
    """
    Selects the columns 'deaths', 'female_deaths' and 'male_deaths', in this
    order.
    """
    death_columns = [f"{prefix}deaths" for prefix in ("", "female_", "male_")]
    return data[death_columns]
def age_grouper(
    cuts: Union[int, Sequence[int]],
    fmt_last: str = ">={}",
) -> Dict[str, str]:
    """
    Builds a mapping from the original age groups (``AGE_GROUPS``) to the
    coarser age groups delimited by ``cuts``.

    Args:
        cuts: a single positive integer N (multiple of 10) means "cut each N
            years"; a non-empty, strictly increasing sequence of non-negative
            multiples of 10 gives the start ages of new age groups; 0 is
            implicitly the start age of the first group, even if not present
            in ``cuts``.
        fmt_last: format string for the last "unbounded" age group; it
            receives the last cut as its only positional argument

    Returns:
        a dictionary mapping each original age group to the label of the
        aggregated group containing it; "unknown" always maps to itself

    Raises:
        ValueError: if ``cuts`` is empty, is not made of non-negative
            multiples of 10 (strictly positive for a single integer), or is
            not strictly increasing
    """
    if isinstance(cuts, int):
        # A zero step would only fail later, inside range(), with an
        # unrelated message: reject it here together with other bad values.
        if cuts <= 0 or cuts % 10 != 0:
            raise ValueError(f"cuts must be a positive multiple of 10: {cuts}")
        cuts = list(range(0, 100, cuts))
    else:
        if len(cuts) == 0:
            # Otherwise cuts[-1] below would raise an IndexError.
            raise ValueError("cuts must contain at least one value")
        for c in cuts:
            if c < 0 or c % 10 != 0:
                raise ValueError(f"at least one cut is not multiple of 10: {c}")
        if any(prev >= cur for prev, cur in zip(cuts, cuts[1:])):
            raise ValueError(f"cuts are not in increasing order: {cuts}")

    # NOTE(review): cuts above 90 are not validated; very large values run
    # past the end of AGE_GROUPS.
    keys = AGE_GROUPS
    starts = [0, *cuts[:-1]]
    grouper: Dict[str, str] = {}
    k = 0
    # Bounded groups: each one spans (cut - start) // 10 original groups.
    for start, cut in zip(starts, cuts):
        value = f"{start}-{cut - 1}"
        for _ in range((cut - start) // 10):
            grouper[keys[k]] = value
            k += 1

    # Every remaining (non-"unknown") original group goes in the last,
    # unbounded group.
    last_value = fmt_last.format(cuts[-1])
    for _ in range(10 - len(grouper)):
        grouper[keys[k]] = last_value
        k += 1

    grouper["unknown"] = "unknown"
    return grouper
def aggregate_age_groups(
    counts: PandasObj,
    cuts: Union[int, Sequence[int]],
    fmt_last: str = ">={}",
) -> PandasObj:
    """
    Sums counts of adjacent age groups into coarser age groups.

    Args:
        counts: either a Series indexed by age group, or a DataFrame whose
            age groups are in a level named "age_group" (of the columns or
            of the index) and whose index has a "date" level
        cuts: a single integer N means "cut each N years"; a sequence of
            integers determines the start ages of new age groups; 0 is
            implicitly the start age of the first group, even if not present
            in ``cuts``.
        fmt_last: format string for the last "unbounded" age group

    Returns:
        A Series/DataFrame with the same "structure" of the input but with
        aggregated age groups.
    """
    mapping = age_grouper(cuts=cuts, fmt_last=fmt_last)

    if isinstance(counts, pd.Series):
        return counts.groupby(mapping).sum()

    # Move the age groups out of the column labels (if they are there) and
    # then out of the index, so that they end up in a regular column that
    # can be relabelled.
    if "age_group" in counts.columns.names:
        long_form = counts.stack("age_group")
    else:
        long_form = counts
    long_form = long_form.reset_index("age_group")

    relabelled = long_form.assign(
        age_group=long_form.age_group.apply(mapping.__getitem__)
    )
    summed = relabelled.groupby(["date", "age_group"]).sum()
    aggregated = summed.unstack("age_group")

    if isinstance(counts.columns, pd.MultiIndex):
        return aggregated
    # Flat input columns: drop the extra column level left by unstack().
    return aggregated.droplevel(axis=1, level=0)
def get_unknown_sex_count(counts: pd.DataFrame, variable: str) -> pd.DataFrame:
    """
    Computes the cases/deaths of unknown sex, i.e. the total count minus the
    sum of the male and female counts.

    Args:
        counts: data containing ``variable`` and its ``male_``/``female_``
            prefixed counterparts
        variable: either 'cases' or 'deaths'

    Raises:
        ValueError: if ``variable`` is neither 'cases' nor 'deaths'
    """
    allowed = {"cases", "deaths"}
    if variable not in allowed:
        raise ValueError("variable should be 'cases' or 'deaths'")
    total = counts[variable]
    males = counts[f"male_{variable}"]
    females = counts[f"female_{variable}"]
    return total - (males + females)
def running_count(
    counts: PandasObj, window: int, step: int = 1, **resample_kwargs
) -> PandasObj:
    """
    Given counts for cases and/or deaths, returns the number of new
    cases/deaths inside a temporal window of ``window`` days that moves
    forward by steps of ``step`` days.

    Args:
        counts: counts to process; they are resampled with ``ic.resample``
            using ``step`` as frequency
        window: size of the window, in days; must be a multiple of ``step``
        step: number of days the window moves forward each time
        **resample_kwargs: extra keyword arguments forwarded to
            ``ic.resample``

    Returns:
        the differences between resampled values that are ``window`` days
        apart, rounded to integers and indexed by periods of ``window`` days

    Raises:
        ValueError: if ``window`` is not a multiple of ``step``
    """
    if window % step:
        raise ValueError("'window' must be a multiple of 'step'")

    steps_per_window = window // step
    resampled = ic.resample(counts, freq=step, **resample_kwargs)
    # The first `steps_per_window` rows have no earlier value to be compared
    # with, so they are discarded.
    deltas = resampled.diff(steps_per_window).iloc[steps_per_window:]
    result = deltas.round().astype(int)
    result.index = result.index.to_period(f"{window}D")
    return result
def running_average(
    counts: PandasObj, window: int, step: int = 1, **resample_kwargs
) -> PandasObj:
    """
    Given counts for cases/deaths, returns the average daily number of new
    cases/deaths inside a temporal window of ``window`` days, moving the
    window ``step`` days at a time.

    Args:
        counts: counts to process
        window: size of the window, in days
        step: number of days the window moves forward each time
        **resample_kwargs: extra keyword arguments forwarded to
            ``running_count``

    Returns:
        the result of ``running_count`` divided by ``window``
    """
    window_totals = running_count(counts, window, step, **resample_kwargs)
    return window_totals / window
def count_by_period(counts: PandasObj, freq: Union[str, int]) -> PandasObj:
    """
    Returns a new Series/DataFrame with counts (cases/deaths) by period
    (e.g. months, weeks, ``n`` days, etc.)

    Args:
        counts: counts to process; they are resampled with ``ic.resample``
        freq: period frequency parameter (whatever accepted by ``pandas``);
            an integer ``n`` is interpreted as ``n`` days

    Returns:
        the differences between consecutive resampled values (rows with
        missing values dropped), indexed by a period index named 'period'
    """
    period_freq = f"{freq}D" if isinstance(freq, int) else freq
    resampled = ic.resample(counts, freq=period_freq)
    increments = resampled.diff().dropna()
    period_index = increments.index.to_period(freq=period_freq)
    increments.index = period_index.rename("period")
    return increments
def average_by_period(counts: PandasObj, freq: Union[str, int]) -> PandasObj:
    """
    Returns a new Series/DataFrame with average daily counts (cases/deaths)
    by period (e.g. months, weeks, ``n`` days, etc.)

    Args:
        counts: counts to process
        freq: period frequency parameter (whatever accepted by ``pandas``)

    Returns:
        the result of ``count_by_period`` with each row divided by the
        length, in days, of its period
    """
    totals = count_by_period(counts, freq)
    periods = totals.index
    # One nanosecond is added to the span before taking whole days
    # (presumably because ``end_time`` falls just before the next period
    # starts -- see the pandas documentation).
    span = periods.end_time - periods.start_time + pd.Timedelta(nanoseconds=1)
    days_per_period = span.days
    return totals.div(days_per_period, axis=0)
def fatality_rate(counts, shift):
    """
    Computes the fatality rate as the ratio between the total number of
    deaths and the total number of cases ``shift`` days before.

    ``counts`` is resampled (with interpolation, if needed) to a daily
    frequency first. The 'unknown' column, if present, is dropped from the
    result.

    Args:
        counts: data exposing ``cases`` and ``deaths``
        shift: number of days separating deaths from the cases they are
            compared to

    Returns:
        deaths divided by shifted cases, without the first ``shift`` rows
    """
    daily = resample_if_needed(counts, freq="1D")
    # The first `shift` rows have no earlier cases to be compared with.
    earlier_cases = daily.cases.shift(shift).iloc[shift:]
    deaths = daily.deaths.iloc[shift:]
    ratio = deaths / earlier_cases
    if "unknown" in ratio.columns:
        ratio = ratio.drop(columns="unknown")
    return ratio