Skip to content

chainconsumer.Chain

Bases: ChainConfig

The numerical chain with its configuration.

Source code in src/chainconsumer/chain.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
class Chain(ChainConfig):
    """The numerical chain with its configuration.

    Holds the samples as a pandas DataFrame together with the names of the weight
    and log posterior columns and a handful of chain specific options. The samples
    are copied and validated when the model is constructed.
    """

    samples: pd.DataFrame = Field(
        default=...,
        description="The chain data as a pandas DataFrame",
    )
    name: ChainName = Field(
        default=...,
        description="The name of the chain",
    )

    weight_column: ColumnName = Field(
        default="weight",
        description="The name of the weight column, if it exists",
    )
    posterior_column: ColumnName = Field(
        default="log_posterior",
        description="The name of the log posterior column, if it exists",
    )
    walkers: int = Field(
        default=1,
        ge=1,
        description="The number of walkers in the chain",
    )
    grid: bool = Field(
        default=False,
        description="Whether the chain is a sampled grid or not",
    )
    num_free_params: int | None = Field(
        default=None,
        description="The number of free parameters in the chain",
        ge=0,
    )
    num_eff_data_points: float | None = Field(
        default=None,
        description="The number of effective data points",
        ge=0,
    )
    power: float = Field(
        default=1.0,
        description="Raise the posterior surface to this. Useful for inflating or deflating uncertainty for debugging.",
    )
    show_label_in_legend: bool = Field(
        default=True,
        description="Whether to show the label in the legend",
    )
    histogram_relative_height: float = Field(
        default=1.0,
        description="The relative height to plot the marginalised histogram. 1.0 will ensure a normalised histogram.",
        ge=0.0,
    )

    @property
    def data_columns(self) -> list[str]:
        """The columns in the dataframe which are not weights or posteriors.

        A column is excluded if it is the configured weight or posterior column,
        or if its lower-cased name is one of the reserved names below.
        """
        configured = {self.weight_column, self.posterior_column}
        reserved = {
            "weight",
            "weights",
            "posterior",
            "posteriors",
            "log_weights",
            "log_posterior",
            "log_posteriors",
        }
        return [c for c in self.samples.columns if c not in configured and c.lower() not in reserved]

    @property
    def data_samples(self) -> pd.DataFrame:
        """The subsection of the dataframe with data points (ie excluding weights and posterior)"""
        return self.samples[self.data_columns]

    @property
    def plotting_columns(self) -> list[str]:
        """The columns to be plotted, which are the dataframe columns
        with the weights, posterior and colour columns removed.

        The colour column is only removed when `plot_cloud` is set."""
        cols = self.data_columns
        if not self.plot_cloud:
            return cols
        return [c for c in cols if c != self.color_param]

    @property
    def skip(self) -> bool:
        """If the chain will be skipped in plotting because it has nothing to plot."""
        return self.samples.empty or not (self.plot_contour or self.plot_cloud or self.plot_point)

    @property
    def max_posterior_row(self) -> pd.Series | None:
        """The row of samples which correspond to the maximum posterior value.
        None if the posterior is not supplied."""
        if self.posterior_column not in self.samples.columns:
            logging.warning("No posterior column found, cannot find max posterior")
            return None
        # Series.argmax returns a position, not an index label, so the row has to be
        # fetched positionally. Using .loc here breaks for any non-default index.
        argmax = self.samples[self.posterior_column].argmax()
        return self.samples.iloc[argmax]  # type: ignore

    @property
    def weights(self) -> np.ndarray:
        """The column of weights in the samples."""
        return self.samples[self.weight_column].to_numpy()

    @property
    def log_posterior(self) -> np.ndarray | None:
        """The column of log posteriors in the samples. None if not set."""
        if self.posterior_column not in self.samples.columns:
            return None
        return self.samples[self.posterior_column].to_numpy()

    @property
    def color_data(self) -> np.ndarray | None:
        """The data from the color column. None if not set."""
        if self.color_param is None:
            return None
        return self.samples[self.color_param].to_numpy()

    @property
    def smooth_value(self) -> int:
        """The smoothing value to use for the histogram. If smooth is set, this is the value.
        If not, it is 0 for gridded data, and 3 for non-gridded data."""
        if self.smooth is not None:
            return self.smooth
        if self.grid:
            return 0
        return 3

    @field_validator("color")
    @classmethod
    def _validate_color(cls, v: str | np.ndarray | list[float] | None) -> str | None:
        """Pass the colour through `colors.format`, leaving None untouched."""
        if v is None:
            return None
        return colors.format(v)

    @field_validator("samples")
    @classmethod
    def _copy_df(cls, v: pd.DataFrame) -> pd.DataFrame:
        """Copy the dataframe so later in-place edits do not touch the caller's data."""
        return v.copy()

    @model_validator(mode="after")  # type: ignore
    def _validate_model(self) -> Chain:
        """Validate the samples against the configuration and fill in defaults.

        Adds a unit weight column when none is present, checks every column is
        string-named and finite, applies the optional mean shift, checks the walker
        count divides the number of steps, and resolves the summary statistic.
        """
        assert not self.samples.empty, "Your chain is empty. This is not ideal."

        # If weights aren't set, add them all as one
        if self.weight_column not in self.samples:
            assert self.weight_column == "weight", (
                f"weight column has been changed to {self.weight_column}, but its not in the dataframe"
            )

            self.samples[self.weight_column] = 1.0
        else:
            assert np.all(self.weights > 0), "Weights must be positive and non-zero"

        for column in self.samples.columns:
            assert isinstance(column, str), f"Column {column} is not a string"
            assert np.all(np.isfinite(self.samples[column])), f"Column {column} has NaN or inf in it"

        # Apply the mean shift if it is set to true
        if self.shift_params:
            # Only the data columns are shifted, using weights captured up front.
            # Shifting the weight column by its own weighted mean would corrupt the
            # weights (unit weights would all become zero).
            weights = self.weights
            for param in self.data_columns:
                self.samples[param] -= np.average(self.samples[param], weights=weights)  # type: ignore

        # Check the walkers
        assert self.samples.shape[0] % self.walkers == 0, (
            f"Chain {self.name} has {self.samples.shape[0]} steps, "
            f"which is not divisible by {self.walkers} walkers. This is not good."
        )

        # And if the color_params are set, ensure they're in the dataframe
        if self.color_param is not None:
            assert self.color_param in self.samples.columns, (
                f"Chain {self.name} does not have color parameter {self.color_param}"
            )

        # more nan checks
        if self.num_eff_data_points is not None:
            assert np.isfinite(self.num_eff_data_points), "num_eff_data_points is not finite"

        if self.num_free_params is not None:
            assert np.isfinite(self.num_free_params), "num_free_params is not finite"

        # Default the statistic from the multimodal flag when it was not given
        if self.statistics is None:
            if self.multimodal:
                self.statistics = SummaryStatistic.HDI
            else:
                self.statistics = SummaryStatistic.MAX

        if self.multimodal and self.statistics is not SummaryStatistic.HDI:
            raise ValueError(
                f"Chain {self.name} is marked as multimodal but uses {self.statistics.value}; "
                "set statistics=SummaryStatistic.HDI."
            )

        return self

    def get_data(self, column: str) -> pd.Series[float]:
        """Extracts a single column from the samples dataframe."""
        return self.samples[column]

    @classmethod
    def from_covariance(
        cls,
        mean: np.ndarray | list[float],
        covariance: np.ndarray | list[list[float]],
        columns: list[ColumnName],
        name: ChainName,
        **kwargs: Any,
    ) -> Chain:
        """Generate samples as per mean and covariance supplied. Useful for Fisher matrix forecasts.

        One million samples are drawn from a multivariate normal distribution using
        a freshly created, unseeded random generator.

        Args:
            mean: An array of mean values.
            covariance: The 2D array describing the covariance.
                Dimensions should agree with the `mean` input.
            columns: A list of parameter names, one for each column (dimension) in the mean array.
            name: The name of the chain.
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            The generated chain.
        """
        rng = np.random.default_rng()
        samples = rng.multivariate_normal(mean, covariance, size=1000000)  # type: ignore
        df = pd.DataFrame(samples, columns=columns)
        return cls(samples=df, name=name, **kwargs)  # type: ignore

    def divide(self) -> list[Chain]:
        """Returns all the walks of a given chain as individual chains themselves.

        This method might be useful if, for example, your chain was made using
        MCMC with 4 walkers. To check the sampling of all 4 walkers agree, you could
        call this to get one chain for each of the four walks. If you then plot them,
        hopefully all four contours you would see agree.

        Returns:
            One chain per walker, split evenly

        Raises:
            ValueError: If the number of steps is not divisible by the number of walkers.
        """
        assert self.walkers > 1, "Cannot divide a chain with only one walker"
        assert not self.grid, "Cannot divide a grid chain"

        n_steps = self.samples.shape[0]
        steps_per_walker, leftover = divmod(n_steps, self.walkers)
        if leftover:
            raise ValueError(
                f"Chain {self.name} has {n_steps} steps, which is not divisible by {self.walkers} walkers."
            )

        # The options are identical for every walker, so dump them once. The colour
        # is removed so it is not copied onto the walker chains.
        options = self.model_dump(exclude={"samples", "name", "walkers"})
        options.pop("color", None)

        chains = []
        for i in range(self.walkers):
            # Positional slicing takes consecutive, equally sized runs of rows
            df = self.samples.iloc[i * steps_per_walker : (i + 1) * steps_per_walker]
            chains.append(Chain(samples=df, name=f"{self.name} Walker {i}", **options))

        return chains

    def get_max_posterior_point(self) -> MaxPosterior | None:
        """Returns the maximum posterior point in the chain.

        If the posterior column is not in the samples, there is no maximum to
        report and None is returned.

        Returns:
            MaxPosterior: The maximum posterior point
        """
        # Read the property once, as each access repeats the argmax search
        best_row = self.max_posterior_row
        if best_row is None:
            return None
        row = best_row.to_dict()
        log_posterior = row.pop(self.posterior_column)
        # Likewise resolve the plotting columns once rather than once per key
        plotting_columns = set(self.plotting_columns)
        coordinate = {k: v for k, v in row.items() if k in plotting_columns}
        return MaxPosterior(log_posterior=log_posterior, coordinate=coordinate)

    def get_covariance(self, columns: list[str] | None = None) -> Named2DMatrix:
        """Returns the covariance matrix of the chain.

        Args:
            columns: The columns to use. None means all data columns.

        Returns:
            Named2DMatrix: The covariance matrix
        """
        if columns is None:
            columns = self.data_columns
        cov = np.cov(self.samples[columns], rowvar=False, aweights=self.weights)
        # np.cov collapses a single column to a 0-d array; keep it as a 1x1 matrix
        return Named2DMatrix(columns=columns, matrix=np.atleast_2d(cov))

    def get_correlation(self, columns: list[str] | None = None) -> Named2DMatrix:
        """Returns the correlation matrix of the chain.

        Each covariance entry is divided by the product of the two matching
        standard deviations.

        Args:
            columns: The columns to use. None means all data columns.

        Returns:
            Named2DMatrix: The correlation matrix
        """
        cov = self.get_covariance(columns)
        diag = np.sqrt(np.diag(cov.matrix))
        divisor = diag[None, :] * diag[:, None]  # type: ignore
        correlations = cov.matrix / divisor
        return Named2DMatrix(columns=cov.columns, matrix=correlations)

    @classmethod
    def from_emcee(
        cls,
        sampler: emcee.EnsembleSampler,
        columns: list[str],
        name: str,
        thin: int = 1,
        discard: int = 0,
        **kwargs: Any,
    ) -> Chain:
        """Constructor from an emcee sampler

        Args:
            sampler: The emcee sampler
            columns: The names of the parameters
            name: The name of the chain
            thin: The thinning to apply to the chain
            discard: The number of steps to discard from the start of the chain
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            A ChainConsumer Chain made from the emcee samples
        """
        chain: np.ndarray = sampler.get_chain(flat=True, thin=thin, discard=discard)  # type: ignore
        # Names are paired with the columns of the flat chain; with strict=False any
        # surplus names or columns are silently ignored.
        df = pd.DataFrame.from_dict({col: val for col, val in zip(columns, chain.T, strict=False)})

        return cls(samples=df, name=name, **kwargs)

    @classmethod
    def from_numpyro(
        cls,
        mcmc: numpyro.infer.MCMC,
        name: str,
        var_names: list[str] | None = None,
        **kwargs: Any,
    ) -> Chain:
        """Constructor from numpyro samples

        Args:
            mcmc: The numpyro sampler
            name: The name of the chain
            var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
                they are excluded from the variables. If empty, all parameters are included.
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            A ChainConsumer Chain made from numpyro samples
        """
        # Fetch the samples once and reuse them for both the names and the values
        posterior_samples = mcmc.get_samples()
        var_names = _filter_var_names(var_names, list(posterior_samples.keys()))
        df = pd.DataFrame.from_dict(
            {key: np.ravel(value) for key, value in posterior_samples.items() if key in var_names}
        )
        return cls(samples=df, name=name, **kwargs)

    @classmethod
    def from_arviz(
        cls,
        arviz_id: arviz.InferenceData,
        name: str,
        var_names: list[str] | None = None,
        **kwargs: Any,
    ) -> Chain:
        """Constructor from an arviz InferenceData object

        Args:
            arviz_id: The arviz inference data
            name: The name of the chain
            var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
                they are excluded from the variables. If empty, all parameters are included.
            kwargs: Any other arguments to pass to the Chain constructor.

        Returns:
            A ChainConsumer Chain made from the arviz chain
        """

        # Imported here so arviz is only needed when this constructor is used
        import arviz as az

        var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys()))  # type: ignore
        reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior")
        df = reduced_id.to_dataframe().drop(columns=["chain", "draw"], errors="ignore")

        return cls(samples=df, name=name, **kwargs)

samples class-attribute instance-attribute

samples: DataFrame = Field(
    default=...,
    description="The chain data as a pandas DataFrame",
)

name class-attribute instance-attribute

name: ChainName = Field(
    default=..., description="The name of the chain"
)

weight_column class-attribute instance-attribute

weight_column: ColumnName = Field(
    default="weight",
    description="The name of the weight column, if it exists",
)

posterior_column class-attribute instance-attribute

posterior_column: ColumnName = Field(
    default="log_posterior",
    description="The name of the log posterior column, if it exists",
)

walkers class-attribute instance-attribute

walkers: int = Field(
    default=1,
    ge=1,
    description="The number of walkers in the chain",
)

grid class-attribute instance-attribute

grid: bool = Field(
    default=False,
    description="Whether the chain is a sampled grid or not",
)

num_free_params class-attribute instance-attribute

num_free_params: int | None = Field(
    default=None,
    description="The number of free parameters in the chain",
    ge=0,
)

num_eff_data_points class-attribute instance-attribute

num_eff_data_points: float | None = Field(
    default=None,
    description="The number of effective data points",
    ge=0,
)

power class-attribute instance-attribute

power: float = Field(
    default=1.0,
    description="Raise the posterior surface to this. Useful for inflating or deflating uncertainty for debugging.",
)

show_label_in_legend class-attribute instance-attribute

show_label_in_legend: bool = Field(
    default=True,
    description="Whether to show the label in the legend",
)

histogram_relative_height class-attribute instance-attribute

histogram_relative_height: float = Field(
    default=1.0,
    description="The relative height to plot the marginalised histogram. 1.0 will ensure a normalised histogram.",
    ge=0.0,
)

data_columns property

data_columns: list[str]

The columns in the dataframe which are not weights or posteriors.

data_samples property

data_samples: DataFrame

The subsection of the dataframe with data points (ie excluding weights and posterior)

plotting_columns property

plotting_columns: list[str]

The columns to be plotted, which are the dataframe columns with the weights, posterior and colour columns removed.

skip property

skip: bool

If the chain will be skipped in plotting because it has nothing to plot.

max_posterior_row property

max_posterior_row: Series | None

The row of samples which correspond to the maximum posterior value. None if the posterior is not supplied.

weights property

weights: ndarray

The column of weights in the samples.

log_posterior property

log_posterior: ndarray | None

The column of log posteriors in the samples. None if not set.

color_data property

color_data: ndarray | None

The data from the color column. None if not set.

smooth_value property

smooth_value: int

The smoothing value to use for the histogram. If smooth is set, this is the value. If not, it is 0 for gridded data, and 3 for non-gridded data.

get_data

get_data(column: str) -> pd.Series[float]

Extracts a single column from the samples dataframe.

Source code in src/chainconsumer/chain.py
def get_data(self, column: str) -> pd.Series[float]:
    """Return one column of the samples dataframe.

    Args:
        column: The label used to index `self.samples`.

    Returns:
        Whatever indexing the samples dataframe with `column` yields.
    """
    frame = self.samples
    return frame[column]

from_covariance classmethod

from_covariance(
    mean: ndarray | list[float],
    covariance: ndarray | list[list[float]],
    columns: list[ColumnName],
    name: ChainName,
    **kwargs: Any,
) -> Chain

Generate samples as per mean and covariance supplied. Useful for Fisher matrix forecasts.

Parameters:

Name Type Description Default
mean ndarray | list[float]

An array of mean values.

required
covariance ndarray | list[list[float]]

The 2D array describing the covariance. Dimensions should agree with the mean input.

required
columns list[ColumnName]

A list of parameter names, one for each column (dimension) in the mean array.

required
name ChainName

The name of the chain.

required
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

The generated chain.

Source code in src/chainconsumer/chain.py
@classmethod
def from_covariance(
    cls,
    mean: np.ndarray | list[float],
    covariance: np.ndarray | list[list[float]],
    columns: list[ColumnName],
    name: ChainName,
    **kwargs: Any,
) -> Chain:
    """Build a chain by sampling a multivariate normal distribution.

    One million samples are drawn with a freshly created, unseeded random
    generator. Useful for Fisher matrix forecasts.

    Args:
        mean: An array of mean values.
        covariance: The 2D array describing the covariance.
            Dimensions should agree with the `mean` input.
        columns: A list of parameter names, one for each column (dimension) in the mean array.
        name: The name of the chain.
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        The generated chain.
    """
    generator = np.random.default_rng()
    draws = generator.multivariate_normal(mean, covariance, size=1_000_000)  # type: ignore
    frame = pd.DataFrame(draws, columns=columns)
    return cls(samples=frame, name=name, **kwargs)  # type: ignore

divide

divide() -> list[Chain]

Returns all the walks of a given chain as individual chains themselves.

This method might be useful if, for example, your chain was made using MCMC with 4 walkers. To check the sampling of all 4 walkers agree, you could call this to get one chain for each of the four walks. If you then plot them, hopefully all four contours you would see agree.

Returns:

Type Description
list[Chain]

One chain per walker, split evenly

Source code in src/chainconsumer/chain.py
def divide(self) -> list[Chain]:
    """Returns all the walks of a given chain as individual chains themselves.

    This method might be useful if, for example, your chain was made using
    MCMC with 4 walkers. To check the sampling of all 4 walkers agree, you could
    call this to get one chain for each of the four walks. If you then plot them,
    hopefully all four contours you would see agree.

    Returns:
        One chain per walker, split evenly

    Raises:
        ValueError: If the number of steps is not divisible by the number of walkers.
    """
    assert self.walkers > 1, "Cannot divide a chain with only one walker"
    assert not self.grid, "Cannot divide a grid chain"

    n_steps = self.samples.shape[0]
    steps_per_walker, leftover = divmod(n_steps, self.walkers)
    if leftover:
        raise ValueError(
            f"Chain {self.name} has {n_steps} steps, which is not divisible by {self.walkers} walkers."
        )

    # The options are identical for every walker, so dump them once. The colour
    # is removed so it is not copied onto the walker chains.
    options = self.model_dump(exclude={"samples", "name", "walkers"})
    options.pop("color", None)

    chains = []
    for i in range(self.walkers):
        # Positional slicing takes consecutive, equally sized runs of rows
        df = self.samples.iloc[i * steps_per_walker : (i + 1) * steps_per_walker]
        chains.append(Chain(samples=df, name=f"{self.name} Walker {i}", **options))

    return chains

get_max_posterior_point

get_max_posterior_point() -> MaxPosterior | None

Returns the maximum posterior point in the chain. If the posterior column is not in the samples, None is returned.

Returns:

Name Type Description
MaxPosterior MaxPosterior | None

The maximum posterior point

Source code in src/chainconsumer/chain.py
def get_max_posterior_point(self) -> MaxPosterior | None:
    """Returns the maximum posterior point in the chain.

    If `max_posterior_row` is None there is no maximum to report and None is
    returned.

    Returns:
        MaxPosterior: The maximum posterior point
    """
    # Read the property once instead of evaluating it for both the check and the use
    best_row = self.max_posterior_row
    if best_row is None:
        return None
    row = best_row.to_dict()
    log_posterior = row.pop(self.posterior_column)
    # Likewise resolve the plotting columns once rather than once per key
    plotting_columns = set(self.plotting_columns)
    coordinate = {k: v for k, v in row.items() if k in plotting_columns}
    return MaxPosterior(log_posterior=log_posterior, coordinate=coordinate)

get_covariance

get_covariance(
    columns: list[str] | None = None,
) -> Named2DMatrix

Returns the covariance matrix of the chain.

Parameters:

Name Type Description Default
columns list[str] | None

The columns to use. None means all data columns.

None

Returns:

Name Type Description
Named2DMatrix Named2DMatrix

The covariance matrix

Source code in src/chainconsumer/chain.py
def get_covariance(self, columns: list[str] | None = None) -> Named2DMatrix:
    """Returns the covariance matrix of the chain.

    The covariance is computed with `np.cov`, treating each dataframe column as a
    variable and using the chain weights as `aweights`.

    Args:
        columns: The columns to use. None means all data columns.

    Returns:
        Named2DMatrix: The covariance matrix
    """
    if columns is None:
        columns = self.data_columns
    cov = np.cov(self.samples[columns], rowvar=False, aweights=self.weights)
    # np.cov collapses a single column to a 0-d array; keep it as a 1x1 matrix
    return Named2DMatrix(columns=columns, matrix=np.atleast_2d(cov))

get_correlation

get_correlation(
    columns: list[str] | None = None,
) -> Named2DMatrix

Returns the correlation matrix of the chain.

Parameters:

Name Type Description Default
columns list[str] | None

The columns to use. None means all data columns.

None

Returns:

Name Type Description
Named2DMatrix Named2DMatrix

The correlation matrix

Source code in src/chainconsumer/chain.py
def get_correlation(self, columns: list[str] | None = None) -> Named2DMatrix:
    """Returns the correlation matrix of the chain.

    Each covariance entry is divided by the product of the two matching
    standard deviations.

    Args:
        columns: The columns to use. None means all data columns.

    Returns:
        Named2DMatrix: The correlation matrix
    """
    covariance = self.get_covariance(columns)
    variances = np.diag(covariance.matrix)
    stds = np.sqrt(variances)
    # Pairwise products of the standard deviations, one per matrix entry
    scale = stds.reshape(1, -1) * stds.reshape(-1, 1)  # type: ignore
    return Named2DMatrix(columns=covariance.columns, matrix=covariance.matrix / scale)

from_emcee classmethod

from_emcee(
    sampler: EnsembleSampler,
    columns: list[str],
    name: str,
    thin: int = 1,
    discard: int = 0,
    **kwargs: Any,
) -> Chain

Constructor from an emcee sampler

Parameters:

Name Type Description Default
sampler EnsembleSampler

The emcee sampler

required
columns list[str]

The names of the parameters

required
name str

The name of the chain

required
thin int

The thinning to apply to the chain

1
discard int

The number of steps to discard from the start of the chain

0
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

A ChainConsumer Chain made from the emcee samples

Source code in src/chainconsumer/chain.py
@classmethod
def from_emcee(
    cls,
    sampler: emcee.EnsembleSampler,
    columns: list[str],
    name: str,
    thin: int = 1,
    discard: int = 0,
    **kwargs: Any,
) -> Chain:
    """Constructor from an emcee sampler

    Args:
        sampler: The emcee sampler
        columns: The names of the parameters
        name: The name of the chain
        thin: The thinning to apply to the chain
        discard: The number of steps to discard from the start of the chain
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        A ChainConsumer Chain made from the emcee samples
    """
    chain: np.ndarray = sampler.get_chain(flat=True, thin=thin, discard=discard)  # type: ignore
    df = pd.DataFrame.from_dict({col: val for col, val in zip(columns, chain.T, strict=False)})

    return cls(samples=df, name=name, **kwargs)

from_numpyro classmethod

from_numpyro(
    mcmc: MCMC,
    name: str,
    var_names: list[str] | None = None,
    **kwargs: Any,
) -> Chain

Constructor from numpyro samples

Parameters:

Name Type Description Default
mcmc MCMC

The numpyro sampler

required
name str

The name of the chain

required
var_names list[str] | None

The names of the parameters to include in the chain. If the entries of var_names start with ~, they are excluded from the variables. If empty, all parameters are included.

None
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

A ChainConsumer Chain made from numpyro samples

Source code in src/chainconsumer/chain.py
@classmethod
def from_numpyro(
    cls,
    mcmc: numpyro.infer.MCMC,
    name: str,
    var_names: list[str] | None = None,
    **kwargs: Any,
) -> Chain:
    """Constructor from numpyro samples

    Every selected variable is flattened with `np.ravel` into one dataframe column.

    Args:
        mcmc: The numpyro sampler
        name: The name of the chain
        var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
            they are excluded from the variables. If empty, all parameters are included.
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        A ChainConsumer Chain made from numpyro samples
    """

    # Fetch the samples once and reuse them for both the names and the values
    posterior_samples = mcmc.get_samples()
    var_names = _filter_var_names(var_names, list(posterior_samples.keys()))
    df = pd.DataFrame.from_dict(
        {key: np.ravel(value) for key, value in posterior_samples.items() if key in var_names}
    )
    return cls(samples=df, name=name, **kwargs)

from_arviz classmethod

from_arviz(
    arviz_id: InferenceData,
    name: str,
    var_names: list[str] | None = None,
    **kwargs: Any,
) -> Chain

Constructor from an arviz InferenceData object

Parameters:

Name Type Description Default
arviz_id InferenceData

The arviz inference data

required
name str

The name of the chain

required
var_names list[str] | None

The names of the parameters to include in the chain. If the entries of var_names start with ~, they are excluded from the variables. If empty, all parameters are included.

None
kwargs Any

Any other arguments to pass to the Chain constructor.

{}

Returns:

Type Description
Chain

A ChainConsumer Chain made from the arviz chain

Source code in src/chainconsumer/chain.py
@classmethod
def from_arviz(
    cls,
    arviz_id: arviz.InferenceData,
    name: str,
    var_names: list[str] | None = None,
    **kwargs: Any,
) -> Chain:
    """Build a chain from the posterior group of an arviz InferenceData object.

    Args:
        arviz_id: The arviz inference data
        name: The name of the chain
        var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
            they are excluded from the variables. If empty, all parameters are included.
        kwargs: Any other arguments to pass to the Chain constructor.

    Returns:
        A ChainConsumer Chain made from the arviz chain
    """

    # Imported here so arviz is only needed when this constructor is used
    import arviz as az

    available = list(arviz_id.posterior.keys())  # type: ignore
    var_names = _filter_var_names(var_names, available)
    extracted = az.extract(arviz_id, var_names=var_names, group="posterior")
    frame = extracted.to_dataframe()
    # Drop the bookkeeping columns when present; errors="ignore" tolerates their absence
    frame = frame.drop(columns=["chain", "draw"], errors="ignore")

    return cls(samples=frame, name=name, **kwargs)