Commit b3865308 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch '541-use-tqdm-auto-for-better-notebook-performance' into 'master'

Resolve "Use tqdm.auto for better notebook performance"

Closes #541

See merge request lscsoft/bilby!895
parents 4a979f4c 1f02e8dd
...@@ -6,7 +6,7 @@ import pickle ...@@ -6,7 +6,7 @@ import pickle
import signal import signal
import time import time
import tqdm from tqdm.auto import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from pandas import DataFrame from pandas import DataFrame
...@@ -224,7 +224,7 @@ class Dynesty(NestedSampler): ...@@ -224,7 +224,7 @@ class Dynesty(NestedSampler):
self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive']) self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive'])
if self.kwargs['print_func'] is None: if self.kwargs['print_func'] is None:
self.kwargs['print_func'] = self._print_func self.kwargs['print_func'] = self._print_func
self.pbar = tqdm.tqdm(file=sys.stdout) self.pbar = tqdm(file=sys.stdout)
Sampler._verify_kwargs_against_default_kwargs(self) Sampler._verify_kwargs_against_default_kwargs(self)
def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs): def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs):
......
from __future__ import absolute_import, print_function
from collections import namedtuple from collections import namedtuple
import os import os
import signal import signal
...@@ -12,8 +10,7 @@ from pandas import DataFrame ...@@ -12,8 +10,7 @@ from pandas import DataFrame
from distutils.version import LooseVersion from distutils.version import LooseVersion
import dill as pickle import dill as pickle
from ..utils import ( from ..utils import logger, check_directory_exists_and_if_not_mkdir
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
from .base_sampler import MCMCSampler, SamplerError from .base_sampler import MCMCSampler, SamplerError
...@@ -353,7 +350,7 @@ class Emcee(MCMCSampler): ...@@ -353,7 +350,7 @@ class Emcee(MCMCSampler):
self.pos0 = self.sampler.chain[:, -1, :] self.pos0 = self.sampler.chain[:, -1, :]
def run_sampler(self): def run_sampler(self):
tqdm = get_progress_bar() from tqdm.auto import tqdm
sampler_function_kwargs = self.sampler_function_kwargs sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations') iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations iterations -= self._previous_iterations
......
from __future__ import absolute_import, print_function from ..utils import logger
from ..utils import logger, get_progress_bar
import numpy as np import numpy as np
import os import os
from .emcee import Emcee from .emcee import Emcee
...@@ -141,7 +140,7 @@ class Kombine(Emcee): ...@@ -141,7 +140,7 @@ class Kombine(Emcee):
logger.info("Kombine auto-burnin complete. Removing {} samples from chains".format(self.nburn)) logger.info("Kombine auto-burnin complete. Removing {} samples from chains".format(self.nburn))
self._set_pos0_for_resume() self._set_pos0_for_resume()
tqdm = get_progress_bar() from tqdm.auto import tqdm
sampler_function_kwargs = self.sampler_function_kwargs sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations') iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations iterations -= self._previous_iterations
......
...@@ -508,26 +508,6 @@ def get_version_information(): ...@@ -508,26 +508,6 @@ def get_version_information():
print("No version information file '.version' found") print("No version information file '.version' found")
def get_progress_bar(module='tqdm'):
"""
TODO: Write proper docstring
"""
if module in ['tqdm']:
try:
from tqdm import tqdm
except ImportError:
def tqdm(x, *args, **kwargs):
return x
return tqdm
elif module in ['tqdm_notebook']:
try:
from tqdm import tqdm_notebook as tqdm
except ImportError:
def tqdm(x, *args, **kwargs):
return x
return tqdm
def spherical_to_cartesian(radius, theta, phi): def spherical_to_cartesian(radius, theta, phi):
""" Convert from spherical coordinates to cartesian. """ Convert from spherical coordinates to cartesian.
......
from __future__ import division
import sys import sys
import multiprocessing import multiprocessing
from tqdm import tqdm from tqdm.auto import tqdm
import numpy as np import numpy as np
from pandas import DataFrame from pandas import DataFrame
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment