import threading
from functools import wraps
from collections.abc import Callable
tls = threading.local()


def per_thread_result(func: Callable) -> Callable:
    """
    Caches the function result in thread-local storage. Can be used as a decorator::

        >>> import time, threading
        >>> from concurrent.futures.thread import ThreadPoolExecutor
        >>> @per_thread_result
        ... def get_init_data():
        ...     print('get_init_data called')
        ...     return 42
        >>> def usage(*args):
        ...     for n in range(3):
        ...         print(f'{threading.get_ident()}: {get_init_data()}')
        ...         time.sleep(0.1)
        >>> pool = ThreadPoolExecutor(max_workers=2)
        >>> for _ in range(2):  # doctest: +SKIP
        ...     pool.submit(usage)
        get_init_data called
        11032: 42
        get_init_data called
        2092: 42
        11032: 42
        2092: 42
        11032: 42
        2092: 42

    (Thread identifiers and the interleaving of lines vary from run to run.)

    In the example above, each thread calls the function multiple times. The underlying function runs only on
    the first call from each thread, and its result is saved. On subsequent calls from the same thread, the
    underlying function is not called and the previously saved result is returned.

    Notes:

    * Arguments do not take part in caching: only the first call in each thread executes ``func``; arguments
      passed on later calls from that thread are ignored.
    * If ``func`` raises, nothing is cached, so the next call from that thread calls ``func`` again.
    * Results are stored as attributes of the module-level ``tls`` object under the key
      ``'<module>:<qualified name>'``. Distinct functions that share both a module and a qualified name (for
      example a closure re-created by every call of its enclosing function) share a single slot.

    :param func: Function (or other callable) to wrap.
    :return: New function carrying ``func``'s metadata (via :func:`functools.wraps`).
    """
    # __qualname__ (rather than __name__) keeps same-named methods of different classes in one module, such
    # as ``A.load`` and ``B.load``, from sharing a slot. For top-level functions it equals __name__, so the
    # keys of such functions are unchanged.
    func_name = getattr(func, '__qualname__', None) or getattr(func, '__name__', None)
    if func_name is None:
        # Callables without a name (e.g. functools.partial objects): use an identity-based key. The wrapper's
        # closure keeps ``func`` alive, so this id is not shared with any other live object.
        func_name = f'{type(func).__qualname__}@{id(func):#x}'
    mod_name = getattr(func, '__module__', None)
    full_name = func_name if mod_name is None else f'{mod_name}:{func_name}'

    @wraps(func)
    def get_result(*args, **kwargs):
        # Single lookup: a missing attribute means this thread has not cached a result yet.
        try:
            return getattr(tls, full_name)
        except AttributeError:
            pass
        result = func(*args, **kwargs)
        setattr(tls, full_name, result)
        return result

    return get_result