Skip to content

chainconsumer.Chain

Bases: ChainConfig

The numerical chain with its configuration.

Source code in src/chainconsumer/chain.py
class Chain(ChainConfig):
    """The numerical chain with its configuration.

    Holds a pandas DataFrame of samples together with the options declared
    below. On construction the dataframe is copied and validated: a weight
    column is added when missing, every column must be finite and
    string-named, and the number of rows must be divisible by ``walkers``.
    """

    samples: pd.DataFrame = Field(
        default=...,
        description="The chain data as a pandas DataFrame",
    )
    name: ChainName = Field(
        default=...,
        description="The name of the chain",
    )

    weight_column: ColumnName = Field(
        default="weight",
        description="The name of the weight column, if it exists",
    )
    posterior_column: ColumnName = Field(
        default="log_posterior",
        description="The name of the log posterior column, if it exists",
    )
    walkers: int = Field(
        default=1,
        ge=1,
        description="The number of walkers in the chain",
    )
    grid: bool = Field(
        default=False,
        description="Whether the chain is a sampled grid or not",
    )
    num_free_params: int | None = Field(
        default=None,
        description="The number of free parameters in the chain",
        ge=0,
    )
    num_eff_data_points: float | None = Field(
        default=None,
        description="The number of effective data points",
        ge=0,
    )
    power: float = Field(
        default=1.0,
        description="Raise the posterior surface to this. Useful for inflating or deflating uncertainty for debugging.",
    )
    show_label_in_legend: bool = Field(
        default=True,
        description="Whether to show the label in the legend",
    )
    histogram_relative_height: float = Field(
        default=1.0,
        description="The relative height to plot the marginalised histogram. 1.0 will ensure a normalised histogram.",
        ge=0.0,
    )

    @property
    def data_columns(self) -> list[str]:
        """The columns in the dataframe which are not weights or posteriors.

        A column is skipped when it is the configured weight or posterior
        column, or when its lower-cased name is one of the conventional
        weight/posterior names listed below.
        """
        configured = {self.weight_column, self.posterior_column}
        reserved = {
            "weight",
            "weights",
            "posterior",
            "posteriors",
            "log_weights",
            "log_posterior",
            "log_posteriors",
        }
        return [c for c in self.samples.columns if c not in configured and c.lower() not in reserved]

    @property
    def data_samples(self) -> pd.DataFrame:
        """The subsection of the dataframe with data points (i.e. excluding weights and posterior)."""
        return self.samples[self.data_columns]

    @property
    def plotting_columns(self) -> list[str]:
        """The columns to be plotted.

        These are the data columns; when ``plot_cloud`` is set, the colour
        parameter column is removed as well.
        """
        cols = self.data_columns
        if not self.plot_cloud:
            return cols
        return [c for c in cols if c != self.color_param]

    @property
    def skip(self) -> bool:
        """If the chain will be skipped in plotting because it has nothing to plot."""
        return self.samples.empty or not (self.plot_contour or self.plot_cloud or self.plot_point)

    @property
    def max_posterior_row(self) -> pd.Series | None:
        """The row of samples which corresponds to the maximum posterior value.

        Returns:
            The row with the largest posterior, or None (after logging a
            warning) if the posterior column is not in the samples.
        """
        if self.posterior_column not in self.samples.columns:
            logging.warning("No posterior column found, cannot find max posterior")
            return None
        # Series.argmax() returns a *position*, so it must be paired with
        # iloc. Using loc would look the position up as an index label, which
        # is wrong (or raises KeyError) whenever the index is not 0..n-1.
        argmax = self.samples[self.posterior_column].argmax()
        return self.samples.iloc[argmax]  # type: ignore

    @property
    def weights(self) -> np.ndarray:
        """The column of weights in the samples."""
        return self.samples[self.weight_column].to_numpy()

    @property
    def log_posterior(self) -> np.ndarray | None:
        """The column of log posteriors in the samples. None if not set."""
        if self.posterior_column not in self.samples.columns:
            return None
        return self.samples[self.posterior_column].to_numpy()

    @property
    def color_data(self) -> np.ndarray | None:
        """The data from the color column. None if not set."""
        if self.color_param is None:
            return None
        return self.samples[self.color_param].to_numpy()

    @field_validator("color")
    @classmethod
    def _validate_color(cls, v: str | np.ndarray | list[float] | None) -> str | None:
        """Pass a supplied colour through ``colors.format``; None is left as None."""
        if v is None:
            return None
        return colors.format(v)

    @field_validator("samples")
    @classmethod
    def _copy_df(cls, v: pd.DataFrame) -> pd.DataFrame:
        """Copy the incoming dataframe so later in-place edits do not touch the caller's data."""
        return v.copy()

    @model_validator(mode="after")  # type: ignore
    def _validate_model(self) -> Chain:
        """Validate the samples against the configuration and normalise them in place.

        Adds a unit weight column when none is present, checks every column
        is string-named and finite, optionally mean-shifts the data columns,
        and checks walkers, colour parameter and the optional counts.

        Returns:
            The validated chain itself.
        """
        assert not self.samples.empty, "Your chain is empty. This is not ideal."

        # If weights aren't set, add them all as one
        if self.weight_column not in self.samples:
            assert (
                self.weight_column == "weight"
            ), f"weight column has been changed to {self.weight_column}, but its not in the dataframe"

            self.samples[self.weight_column] = 1.0
        else:
            assert np.all(self.weights > 0), "Weights must be positive and non-zero"

        for column in self.samples.columns:
            assert isinstance(column, str), f"Column {column} is not a string"
            assert np.all(np.isfinite(self.samples[column])), f"Column {column} has NaN or inf in it"

        # Apply the mean shift if it is set to true. Only the data columns are
        # shifted: subtracting the weighted mean from the weight column itself
        # would make the weights non-positive (violating the check above) and
        # corrupt the averages of every column shifted after it.
        if self.shift_params:
            weights = self.weights
            for param in self.data_columns:
                self.samples[param] -= np.average(self.samples[param], weights=weights)  # type: ignore

        # Check the walkers
        assert self.samples.shape[0] % self.walkers == 0, (
            f"Chain {self.name} has {self.samples.shape[0]} steps, "
            f"which is not divisible by {self.walkers} walkers. This is not good."
        )

        # And if the color_params are set, ensure they're in the dataframe
        if self.color_param is not None:
            assert (
                self.color_param in self.samples.columns
            ), f"Chain {self.name} does not have color parameter {self.color_param}"

        # more nan checks
        if self.num_eff_data_points is not None:
            assert np.isfinite(self.num_eff_data_points), "num_eff_data_points is not finite"

        if self.num_free_params is not None:
            assert np.isfinite(self.num_free_params), "num_free_params is not finite"

        return self

    def get_data(self, column: str) -> pd.Series[float]:
        """Extracts a single column from the samples dataframe."""
        return self.samples[column]

    @classmethod
    def from_covariance(
        cls,
        mean: np.ndarray | list[float],
        covariance: np.ndarray | list[list[float]],
        columns: list[ColumnName],
        name: ChainName,
        **kwargs: Any,
    ) -> Chain:
        """Generate samples as per mean and covariance supplied. Useful for Fisher matrix forecasts.

        One million samples are drawn from a multivariate normal using a
        freshly created, unseeded random generator.

        Args:
            mean: An array of mean values.
            covariance: The 2D array describing the covariance.
                Dimensions should agree with the `mean` input.
            columns: A list of parameter names, one for each column (dimension) in the mean array.
            name: The name of the chain.
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            The generated chain.
        """
        rng = np.random.default_rng()
        samples = rng.multivariate_normal(mean, covariance, size=1000000)  # type: ignore
        df = pd.DataFrame(samples, columns=columns)
        return cls(samples=df, name=name, **kwargs)  # type: ignore

    def divide(self) -> list[Chain]:
        """Returns a list containing all the walks of a given chain
        as individual chains themselves.

        This method might be useful if, for example, your chain was made using
        MCMC with 4 walkers. To check the sampling of all 4 walkers agree, you could
        call this to get one chain for each of the four walks. If you then plot
        them, hopefully all four contours you would see agree.

        Returns:
            One chain per walker, split evenly

        Raises:
            AssertionError: If the chain has a single walker or is a grid chain.
            ValueError: If the samples cannot be split evenly between the walkers.
        """
        assert self.walkers > 1, "Cannot divide a chain with only one walker"
        assert not self.grid, "Cannot divide a grid chain"

        num_rows = self.samples.shape[0]
        if num_rows % self.walkers != 0:
            # Same failure mode np.split had for an uneven division.
            raise ValueError("array split does not result in an equal division")
        steps_per_walker = num_rows // self.walkers

        chains = []
        for i in range(self.walkers):
            start = i * steps_per_walker
            # Positional slice per walker; the index is reset so every walker
            # chain starts from a clean 0..n-1 index.
            df = self.samples.iloc[start : start + steps_per_walker].reset_index(drop=True)
            options = self.model_dump(exclude={"samples", "name", "walkers"})
            # The colour is dropped rather than inherited by the walker chains.
            options.pop("color", None)
            chain = Chain(samples=df, name=f"{self.name} Walker {i}", **options)
            chains.append(chain)

        return chains

    def get_max_posterior_point(self) -> MaxPosterior | None:
        """Returns the maximum posterior point in the chain.

        If the posterior column is not present in the samples, None is returned.

        Returns:
            MaxPosterior: The maximum posterior point, with its coordinate
            restricted to the plotting columns, or None.
        """
        # Read each property once: both recompute on every access, and
        # max_posterior_row logs a warning each time the posterior is missing.
        best_row = self.max_posterior_row
        if best_row is None:
            return None
        row = best_row.to_dict()
        log_posterior = row.pop(self.posterior_column)
        plotting_columns = self.plotting_columns
        row = {k: v for k, v in row.items() if k in plotting_columns}
        return MaxPosterior(log_posterior=log_posterior, coordinate=row)

    def get_covariance(self, columns: list[str] | None = None) -> Named2DMatrix:
        """Returns the covariance matrix of the chain.

        Args:
            columns: The columns to use. None means all data columns.

        Returns:
            Named2DMatrix: The weighted covariance matrix
        """
        if columns is None:
            columns = self.data_columns
        cov = np.cov(self.samples[columns], rowvar=False, aweights=self.weights)
        # np.cov collapses a single variable to a 0-d array; keep the result a
        # (1, 1) matrix so np.diag in get_correlation keeps working.
        cov = np.atleast_2d(cov)
        return Named2DMatrix(columns=columns, matrix=cov)

    def get_correlation(self, columns: list[str] | None = None) -> Named2DMatrix:
        """Returns the correlation matrix of the chain.

        Args:
            columns: The columns to use. None means all data columns.

        Returns:
            Named2DMatrix: The correlation matrix
        """
        cov = self.get_covariance(columns)
        # Standard deviations; their outer product normalises the covariance.
        diag = np.sqrt(np.diag(cov.matrix))
        divisor = diag[None, :] * diag[:, None]  # type: ignore
        correlations = cov.matrix / divisor
        return Named2DMatrix(columns=cov.columns, matrix=correlations)

    @classmethod
    def from_emcee(
        cls,
        sampler: emcee.EnsembleSampler,
        columns: list[str],
        name: str,
        thin: int = 1,
        discard: int = 0,
        **kwargs: Any,
    ) -> Chain:
        """Constructor from an emcee sampler

        Args:
            sampler: The emcee sampler
            columns: The names of the parameters
            name: The name of the chain
            thin: The thinning to apply to the chain
            discard: The number of steps to discard from the start of the chain
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            A ChainConsumer Chain made from the emcee samples

        Raises:
            ValueError: If the number of column names differs from the number
                of columns in the flattened chain.
        """
        chain: np.ndarray = sampler.get_chain(flat=True, thin=thin, discard=discard)  # type: ignore
        # strict=True: a plain zip would silently drop parameters or names
        # when the two lengths disagree.
        df = pd.DataFrame.from_dict(dict(zip(columns, chain.T, strict=True)))

        return cls(samples=df, name=name, **kwargs)

    @classmethod
    def from_numpyro(
        cls,
        mcmc: numpyro.infer.MCMC,
        name: str,
        var_names: list[str] | None = None,
        **kwargs: Any,
    ) -> Chain:
        """Constructor from numpyro samples

        Args:
            mcmc: The numpyro sampler
            name: The name of the chain
            var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
                they are excluded from the variables. If empty, all parameters are included.
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            A ChainConsumer Chain made from numpyro samples
        """
        # Fetch the samples once and reuse them for both the name filter and the data.
        posterior_samples = mcmc.get_samples()
        var_names = _filter_var_names(var_names, list(posterior_samples.keys()))
        df = pd.DataFrame.from_dict(
            {key: np.ravel(value) for key, value in posterior_samples.items() if key in var_names}
        )
        return cls(samples=df, name=name, **kwargs)

    @classmethod
    def from_arviz(
        cls,
        arviz_id: arviz.InferenceData,
        name: str,
        var_names: list[str] | None = None,
        **kwargs: Any,
    ) -> Chain:
        """Constructor from an arviz InferenceData object

        Args:
            arviz_id: The arviz inference data
            name: The name of the chain
            var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
                they are excluded from the variables. If empty, all parameters are included.
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            A ChainConsumer Chain made from the arviz chain
        """

        # Imported here so arviz is only needed when this constructor is used.
        import arviz as az

        var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys()))
        reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior")
        # The extracted frame carries "chain" and "draw" columns; drop them so only parameters remain.
        df = reduced_id.to_dataframe().drop(columns=["chain", "draw"])

        return cls(samples=df, name=name, **kwargs)

samples class-attribute instance-attribute

samples: DataFrame = Field(default=..., description='The chain data as a pandas DataFrame')

name class-attribute instance-attribute

name: ChainName = Field(default=..., description='The name of the chain')

weight_column class-attribute instance-attribute

weight_column: ColumnName = Field(default='weight', description='The name of the weight column, if it exists')

posterior_column class-attribute instance-attribute

posterior_column: ColumnName = Field(default='log_posterior', description='The name of the log posterior column, if it exists')

walkers class-attribute instance-attribute

walkers: int = Field(default=1, ge=1, description='The number of walkers in the chain')

grid class-attribute instance-attribute

grid: bool = Field(default=False, description='Whether the chain is a sampled grid or not')

num_free_params class-attribute instance-attribute

num_free_params: int | None = Field(default=None, description='The number of free parameters in the chain', ge=0)

num_eff_data_points class-attribute instance-attribute

num_eff_data_points: float | None = Field(default=None, description='The number of effective data points', ge=0)

power class-attribute instance-attribute

power: float = Field(default=1.0, description='Raise the posterior surface to this. Useful for inflating or deflating uncertainty for debugging.')

show_label_in_legend class-attribute instance-attribute

show_label_in_legend: bool = Field(default=True, description='Whether to show the label in the legend')

histogram_relative_height class-attribute instance-attribute

histogram_relative_height: float = Field(default=1.0, description='The relative height to plot the marginalised histogram. 1.0 will ensure a normalised histogram.', ge=0.0)

data_columns property

data_columns: list[str]

The columns in the dataframe which are not weights or posteriors.

data_samples property

data_samples: DataFrame

The subsection of the dataframe with data points (i.e. excluding the weight and posterior columns).

plotting_columns property

plotting_columns: list[str]

The columns to be plotted, which are the dataframe columns with the weights, posterior and colour columns removed.

skip property

skip: bool

If the chain will be skipped in plotting because it has nothing to plot.

max_posterior_row property

max_posterior_row: Series | None

The row of samples which correspond to the maximum posterior value. None if the posterior is not supplied.

weights property

weights: ndarray

The column of weights in the samples.

log_posterior property

log_posterior: ndarray | None

The column of log posteriors in the samples. None if not set.

color_data property

color_data: ndarray | None

The data from the color column. None if not set.

get_data

get_data(column: str) -> pd.Series[float]

Extracts a single column from the samples dataframe.

Source code in src/chainconsumer/chain.py
def get_data(self, column: str) -> pd.Series[float]:
    """Return one column of the samples dataframe, looked up by name."""
    selected = self.samples[column]
    return selected

from_covariance classmethod

from_covariance(mean: np.ndarray | list[float], covariance: np.ndarray | list[list[float]], columns: list[ColumnName], name: ChainName, **kwargs: Any) -> Chain

Generate samples as per mean and covariance supplied. Useful for Fisher matrix forecasts.

Parameters:

Name Type Description Default
mean ndarray | list[float]

An array of mean values.

required
covariance ndarray | list[list[float]]

The 2D array describing the covariance. Dimensions should agree with the mean input.

required
columns list[ColumnName]

A list of parameter names, one for each column (dimension) in the mean array.

required
name ChainName

The name of the chain.

required
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

The generated chain.

Source code in src/chainconsumer/chain.py
@classmethod
def from_covariance(
    cls,
    mean: np.ndarray | list[float],
    covariance: np.ndarray | list[list[float]],
    columns: list[ColumnName],
    name: ChainName,
    **kwargs: Any,
) -> Chain:
    """Build a chain by sampling a multivariate normal distribution.

    Handy for Fisher matrix forecasts. One million samples are drawn with a
    freshly created, unseeded random generator.

    Args:
        mean: An array of mean values.
        covariance: The 2D covariance array; its dimensions should agree
            with the `mean` input.
        columns: Parameter names, one per column (dimension) of the mean array.
        name: The name of the chain.
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        The generated chain.
    """
    generator = np.random.default_rng()
    draws = generator.multivariate_normal(mean, covariance, size=1000000)  # type: ignore
    frame = pd.DataFrame(draws, columns=columns)
    return cls(samples=frame, name=name, **kwargs)  # type: ignore

divide

divide() -> list[Chain]

Returns a list containing all the walks of a given chain as individual chains themselves.

This method might be useful if, for example, your chain was made using MCMC with 4 walkers. To check that the sampling of all 4 walkers agrees, you could call this to get one chain for each of the four walks. If you then plot them, hopefully all four contours you see will agree.

Returns:

Type Description
list[Chain]

One chain per walker, split evenly

Source code in src/chainconsumer/chain.py
def divide(self) -> list[Chain]:
    """Returns a list containing all the walks of a given chain
    as individual chains themselves.

    This method might be useful if, for example, your chain was made using
    MCMC with 4 walkers. To check the sampling of all 4 walkers agree, you could
    call this to get one chain for each of the four walks. If you then plot
    them, hopefully all four contours you would see agree.

    Returns:
        One chain per walker, split evenly

    Raises:
        AssertionError: If the chain has a single walker or is a grid chain.
        ValueError: If the samples cannot be split evenly between the walkers.
    """
    assert self.walkers > 1, "Cannot divide a chain with only one walker"
    assert not self.grid, "Cannot divide a grid chain"

    num_rows = self.samples.shape[0]
    if num_rows % self.walkers != 0:
        # Same failure mode np.split had for an uneven division.
        raise ValueError("array split does not result in an equal division")
    steps_per_walker = num_rows // self.walkers

    chains = []
    for i in range(self.walkers):
        start = i * steps_per_walker
        # Positional slice per walker; the index is reset so every walker
        # chain starts from a clean 0..n-1 index.
        df = self.samples.iloc[start : start + steps_per_walker].reset_index(drop=True)
        options = self.model_dump(exclude={"samples", "name", "walkers"})
        # The colour is dropped rather than inherited by the walker chains.
        options.pop("color", None)
        chain = Chain(samples=df, name=f"{self.name} Walker {i}", **options)
        chains.append(chain)

    return chains

get_max_posterior_point

get_max_posterior_point() -> MaxPosterior | None

Returns the maximum posterior point in the chain. If the posterior column is not present in the samples, None is returned.

Returns:

Name Type Description
MaxPosterior MaxPosterior | None

The maximum posterior point

Source code in src/chainconsumer/chain.py
def get_max_posterior_point(self) -> MaxPosterior | None:
    """Returns the maximum posterior point in the chain.

    If the posterior column is not present in the samples, None is returned.

    Returns:
        MaxPosterior: The maximum posterior point, with its coordinate
        restricted to the plotting columns, or None.
    """
    # Read each property once: both recompute on every access, and
    # max_posterior_row may log a warning each time it is evaluated.
    best_row = self.max_posterior_row
    if best_row is None:
        return None
    row = best_row.to_dict()
    log_posterior = row.pop(self.posterior_column)
    plotting_columns = self.plotting_columns
    row = {k: v for k, v in row.items() if k in plotting_columns}
    return MaxPosterior(log_posterior=log_posterior, coordinate=row)

get_covariance

get_covariance(columns: list[str] | None = None) -> Named2DMatrix

Returns the covariance matrix of the chain.

Parameters:

Name Type Description Default
columns list[str] | None

The columns to use. None means all data columns.

None

Returns:

Name Type Description
Named2DMatrix Named2DMatrix

The covariance matrix

Source code in src/chainconsumer/chain.py
def get_covariance(self, columns: list[str] | None = None) -> Named2DMatrix:
    """Returns the covariance matrix of the chain.

    Args:
        columns: The columns to use. None means all data columns.

    Returns:
        Named2DMatrix: The weighted covariance matrix
    """
    if columns is None:
        columns = self.data_columns
    cov = np.cov(self.samples[columns], rowvar=False, aweights=self.weights)
    # np.cov collapses a single variable to a 0-d array; keep the result a
    # (1, 1) matrix so it is always a genuine 2D matrix for consumers.
    cov = np.atleast_2d(cov)
    return Named2DMatrix(columns=columns, matrix=cov)

get_correlation

get_correlation(columns: list[str] | None = None) -> Named2DMatrix

Returns the correlation matrix of the chain.

Parameters:

Name Type Description Default
columns list[str] | None

The columns to use. None means all data columns.

None

Returns:

Name Type Description
Named2DMatrix Named2DMatrix

The correlation matrix

Source code in src/chainconsumer/chain.py
def get_correlation(self, columns: list[str] | None = None) -> Named2DMatrix:
    """Compute the correlation matrix of the chain.

    The covariance from ``get_covariance`` is divided element-wise by the
    outer product of the per-column standard deviations.

    Args:
        columns: The columns to use. None means all data columns.

    Returns:
        Named2DMatrix: The correlation matrix
    """
    covariance = self.get_covariance(columns)
    sigma = np.sqrt(np.diag(covariance.matrix))
    normaliser = sigma[:, None] * sigma[None, :]  # type: ignore
    return Named2DMatrix(columns=covariance.columns, matrix=covariance.matrix / normaliser)

from_emcee classmethod

from_emcee(sampler: emcee.EnsembleSampler, columns: list[str], name: str, thin: int = 1, discard: int = 0, **kwargs: Any) -> Chain

Constructor from an emcee sampler

Parameters:

Name Type Description Default
sampler EnsembleSampler

The emcee sampler

required
columns list[str]

The names of the parameters

required
name str

The name of the chain

required
thin int

The thinning to apply to the chain

1
discard int

The number of steps to discard from the start of the chain

0
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

A ChainConsumer Chain made from the emcee samples

Source code in src/chainconsumer/chain.py
@classmethod
def from_emcee(
    cls,
    sampler: emcee.EnsembleSampler,
    columns: list[str],
    name: str,
    thin: int = 1,
    discard: int = 0,
    **kwargs: Any,
) -> Chain:
    """Constructor from an emcee sampler

    Args:
        sampler: The emcee sampler
        columns: The names of the parameters
        name: The name of the chain
        thin: The thinning to apply to the chain
        discard: The number of steps to discard from the start of the chain
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        A ChainConsumer Chain made from the emcee samples
    """
    chain: np.ndarray = sampler.get_chain(flat=True, thin=thin, discard=discard)  # type: ignore
    df = pd.DataFrame.from_dict({col: val for col, val in zip(columns, chain.T)})

    return cls(samples=df, name=name, **kwargs)

from_numpyro classmethod

from_numpyro(mcmc: numpyro.infer.MCMC, name: str, var_names: list[str] | None = None, **kwargs: Any) -> Chain

Constructor from numpyro samples

Parameters:

Name Type Description Default
mcmc MCMC

The numpyro sampler

required
name str

The name of the chain

required
var_names list[str] | None

The names of the parameters to include in the chain. If the entries of var_names start with ~, they are excluded from the variables. If empty, all parameters are included.

None
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

A ChainConsumer Chain made from numpyro samples

Source code in src/chainconsumer/chain.py
@classmethod
def from_numpyro(
    cls,
    mcmc: numpyro.infer.MCMC,
    name: str,
    var_names: list[str] | None = None,
    **kwargs: Any,
) -> Chain:
    """Constructor from numpyro samples

    Args:
        mcmc: The numpyro sampler
        name: The name of the chain
        var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
            they are excluded from the variables. If empty, all parameters are included.
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        A ChainConsumer Chain made from numpyro samples
    """
    # Fetch the samples once and reuse them for both the name filter and the data.
    posterior_samples = mcmc.get_samples()
    var_names = _filter_var_names(var_names, list(posterior_samples.keys()))
    df = pd.DataFrame.from_dict(
        {key: np.ravel(value) for key, value in posterior_samples.items() if key in var_names}
    )
    return cls(samples=df, name=name, **kwargs)

from_arviz classmethod

from_arviz(arviz_id: arviz.InferenceData, name: str, var_names: list[str] | None = None, **kwargs: Any) -> Chain

Constructor from an arviz InferenceData object

Parameters:

Name Type Description Default
arviz_id InferenceData

The arviz inference data

required
name str

The name of the chain

required
var_names list[str] | None

The names of the parameters to include in the chain. If the entries of var_names start with ~, they are excluded from the variables. If empty, all parameters are included.

None
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

A ChainConsumer Chain made from the arviz chain

Source code in src/chainconsumer/chain.py
@classmethod
def from_arviz(
    cls,
    arviz_id: arviz.InferenceData,
    name: str,
    var_names: list[str] | None = None,
    **kwargs: Any,
) -> Chain:
    """Constructor from an arviz InferenceData object

    Args:
        arviz_id: The arviz inference data
        name: The name of the chain
        var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
            they are excluded from the variables. If empty, all parameters are included.
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        A ChainConsumer Chain made from the arviz chain
    """

    # Imported here so arviz is only needed when this constructor is used.
    import arviz as az

    var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys()))
    reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior")
    # The extracted frame carries "chain" and "draw" columns; drop them so only parameters remain.
    df = reduced_id.to_dataframe().drop(columns=["chain", "draw"])

    return cls(samples=df, name=name, **kwargs)