Multiscale Modeling of Fluid Materials, Department of Engineering Physics and Computation, TUM School of Engineering and Design, Technical University of Munich, Germany
Paul Fuchs
Multiscale Modeling of Fluid Materials, Department of Engineering Physics and Computation, TUM School of Engineering and Design, Technical University of Munich, Germany
Ana Cukarska
Multiscale Modeling of Fluid Materials, Department of Engineering Physics and Computation, TUM School of Engineering and Design, Technical University of Munich, Germany
Julija Zavadlav
Multiscale Modeling of Fluid Materials, Department of Engineering Physics and Computation, TUM School of Engineering and Design, Technical University of Munich, Germany; Munich Data Science Institute, Technical University of Munich, Germany; Corresponding author at: Multiscale Modeling of Fluid Materials, Department of Engineering Physics and Computation, TUM School of Engineering and Design, Technical University of Munich, Germany.
We present JaxSGMC, an application-agnostic library for stochastic gradient Markov chain Monte Carlo (SG-MCMC) in JAX. SG-MCMC schemes are uncertainty quantification (UQ) methods that scale to large datasets and high-dimensional models, enabling trustworthy neural network predictions via Bayesian deep learning. JaxSGMC implements several state-of-the-art SG-MCMC samplers to promote UQ in deep learning by lowering the barrier to entry for switching from stochastic optimization to SG-MCMC sampling. Additionally, JaxSGMC allows users to build custom samplers from standard SG-MCMC building blocks. Owing to this modular structure, we anticipate that JaxSGMC will accelerate research into novel SG-MCMC schemes and facilitate their application across a broad range of domains.