Add PSIS k-hat diagnostic for variational inference #2139

Open
michaelellis003 wants to merge 1 commit into pyro-ppl:master from michaelellis003:add-psis-diagnostic

@michaelellis003

Summary

Implements the Pareto Smoothed Importance Sampling (PSIS) k-hat diagnostic for evaluating variational approximation quality, as requested in #1804.

  • psis_diagnostic(rng_key, param_map, model, guide, *args) computes the k-hat statistic by fitting a Generalized Pareto Distribution to the upper tail of importance weights
  • k < 0.5 indicates a good guide (finite variance), 0.5 <= k < 0.7 is marginal (finite mean), and k >= 0.7 is unreliable
  • GPD fitting uses Zhang & Stephens (2009) with prior regularization from Vehtari et al. (2024), and agrees with both Pyro's psis_diagnostic and Vehtari's reference implementation to within ~1e-15
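
For readers unfamiliar with the fitting step, here is a rough pure-Python sketch of a Zhang & Stephens style GPD fit with the weakly informative shrinkage of k toward 0.5. It is not the code in this PR: `fit_gpd_sketch` is a hypothetical name, the grid size, the quartile-based scale, and the prior strength `a = 10` are assumptions taken from my reading of Vehtari's reference `gpdfitnew`, and the pruning of negligible grid weights is omitted for brevity.

```python
import math
import random


def fit_gpd_sketch(x):
    """Estimate (k, sigma) of a Generalized Pareto from exceedances x >= 0.

    Hypothetical helper for illustration; not the PR's _fit_generalized_pareto.
    """
    xs = sorted(x)
    n = len(xs)
    m = 30 + int(math.sqrt(n))        # number of grid points (assumed)
    x_max = xs[-1]
    x_q1 = xs[int(n / 4 + 0.5) - 1]   # first-quartile order statistic

    # Grid of candidate values for b = -k / sigma; every b is below 1 / x_max,
    # so 1 - b * x stays positive for all observations.
    bs = [1.0 / x_max + (1.0 - math.sqrt(m / (j - 0.5))) / (3.0 * x_q1)
          for j in range(1, m + 1)]

    def shape_given(b):
        # Shape estimate implied by a fixed b.
        return sum(math.log1p(-b * v) for v in xs) / n

    # Profile log-likelihood of each grid point.
    log_liks = []
    for b in bs:
        k = shape_given(b)
        log_liks.append(n * (math.log(-b / k) - k - 1.0))

    # Normalized grid weights, computed as a softmax for numerical stability.
    ll_max = max(log_liks)
    ws = [math.exp(ll - ll_max) for ll in log_liks]
    total = sum(ws)

    b_hat = sum(b * w for b, w in zip(bs, ws)) / total  # weighted mean of b
    k_hat = shape_given(b_hat)
    sigma_hat = -k_hat / b_hat

    # Weakly informative prior: shrink k toward 0.5 with strength a (assumed).
    a = 10.0
    k_hat = (k_hat * n + a * 0.5) / (n + a)
    return k_hat, sigma_hat


# Usage: recover the shape from synthetic GPD draws via the inverse CDF.
random.seed(0)
k_true, sigma_true = 0.3, 1.0
draws = [sigma_true * ((1.0 - random.random()) ** (-k_true) - 1.0) / k_true
         for _ in range(2000)]
k_hat, sigma_hat = fit_gpd_sketch(draws)
```

The sign convention here returns a positive k for heavy tails, which is the parameterization the thresholds above refer to.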

Changes

  • numpyro/infer/importance.py (new): GPD fitting (_fit_generalized_pareto), PSIS tail extraction (_psis_khat), and public API (psis_diagnostic) with batched evaluation via chunk_size
  • test/infer/test_importance.py (new): 36 tests covering GPD parameter recovery, cross-implementation reference values (precomputed from Vehtari's gpdfitnew and Pyro 1.9.1), regime classification against paper thresholds, batching correctness, edge cases, and SVI integration
  • numpyro/infer/__init__.py: Export psis_diagnostic
  • docs/source/utilities.rst: Add API documentation entry
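
As a rough illustration of the tail-extraction step, the sketch below selects the largest importance weights and returns their exceedances over the cutoff, which is what a GPD fit would then consume. `psis_tail_exceedances` is a hypothetical name; the tail length min(0.2 n, 3 sqrt(n)) and the 1e-15 floor on the cutoff follow my understanding of Pyro's `psis_diagnostic`, and the PR's `_psis_khat` may differ in detail.

```python
import math


def psis_tail_exceedances(log_weights, min_log_weight=math.log(1e-15)):
    """Return tail exceedances of normalized importance weights.

    Hypothetical helper for illustration; not the PR's _psis_khat.
    """
    n = len(log_weights)
    if n < 2:
        raise ValueError("need at least two log weights")

    # Shift so the largest log weight is 0, then sort ascending.
    lw_max = max(log_weights)
    lw = sorted(v - lw_max for v in log_weights)

    # Tail length rule (assumed): min(20% of the draws, 3 * sqrt(n)).
    tail_len = int(math.ceil(min(0.2 * n, 3.0 * math.sqrt(n))))

    # The cutoff is the largest log weight outside the tail, floored for safety.
    lw_cutoff = max(min_log_weight, lw[-tail_len - 1])
    cutoff = math.exp(lw_cutoff)

    # Exceedances over the cutoff on the weight (not log-weight) scale.
    return [math.exp(v) - cutoff for v in lw if v > lw_cutoff]


# Usage: 100 weights proportional to 1..100 give a 20-element tail.
tail = psis_tail_exceedances([math.log(i) for i in range(1, 101)])
```

Working on shifted log weights keeps the exponentials in [0, 1], so no overflow can occur even when the raw log weights are large.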

References

  • Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. ICML.
  • Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2024). Pareto smoothed importance sampling. JMLR, 25(72):1-58.

Fixes #1804

@Qazalbash Qazalbash requested a review from fehiepsi February 14, 2026 07:02
@Qazalbash Qazalbash added the enhancement label Feb 14, 2026
Implement Pareto Smoothed Importance Sampling (PSIS) diagnostic to
evaluate variational approximation quality, as requested in pyro-ppl#1804.

The k-hat statistic is the shape parameter of a Generalized Pareto
Distribution fitted to the upper tail of importance weights. It
indicates whether the guide is a reliable approximation:
  k < 0.5: good (finite variance)
  0.5 <= k < 0.7: marginal (finite mean)
  k >= 0.7: unreliable

GPD fitting uses Zhang & Stephens (2009) with prior regularization
from Vehtari et al. (2024), matching Pyro's implementation and
Vehtari's reference code to ~1e-15.
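
The three regimes listed in the commit message can be expressed as a small helper. This is only a sketch: `classify_khat` and its string labels are hypothetical and are not part of the API added by this PR.

```python
import math


def classify_khat(k):
    """Map a k-hat value to the regime named in the commit message.

    Hypothetical helper for illustration only.
    """
    if math.isnan(k):
        raise ValueError("k-hat is NaN")
    if k < 0.5:
        return "good"        # finite variance
    if k < 0.7:
        return "marginal"    # finite mean
    return "unreliable"      # includes k = inf
```

Treating an infinite k-hat as "unreliable" is a design choice of this sketch, so that a failed fit is never reported as a usable guide.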


Successfully merging this pull request may close these issues.

Add Pareto Smoothed Importance Sampling (PSIS) diagnostic method
