Python class decorator arguments


Solution 1

@Cache(max_hits=100, timeout=50) calls __init__(max_hits=100, timeout=50) before any function exists, so the required function argument is never supplied.

You could implement your decorator via a wrapper function that detects whether a function was passed. If it finds one, it returns the _Cache object directly. Otherwise, it returns a wrapper function that is then used as the decorator.

class _Cache(object):
    def __init__(self, function, max_hits=10, timeout=5):
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, *args):
        # Here the code returning the correct thing.
        # (Placeholder: just call the wrapped function.)
        return self.function(*args)

# wrap _Cache to allow for deferred calling
def Cache(function=None, max_hits=10, timeout=5):
    if function is not None:
        # used as @Cache: the function is passed in directly
        return _Cache(function, max_hits, timeout)
    else:
        # used as @Cache(...): return the actual decorator
        def wrapper(function):
            return _Cache(function, max_hits, timeout)

        return wrapper

@Cache
def double(x):
    return x * 2

@Cache(max_hits=100, timeout=50)
def double(x):
    return x * 2
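For a version you can actually run, here is a minimal sketch of the same pattern. It assumes the cache is a plain dict keyed on the positional arguments; what max_hits and timeout should control is not specified in the question, so they are only stored. Making max_hits and timeout keyword-only (the bare *) is an added design choice: it stops @Cache(100) from being mistaken for "a function was passed".

```python
import functools


class _Cache:
    """Memoizing wrapper. max_hits and timeout are stored but not enforced
    (their exact meaning is an assumption left open, as in the question)."""

    def __init__(self, function, max_hits=10, timeout=5):
        functools.update_wrapper(self, function)  # copy __name__, __doc__, ...
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, *args):
        # Plain memoization keyed on the positional arguments (assumption).
        if args not in self.cache:
            self.cache[args] = self.function(*args)
        return self.cache[args]


def Cache(function=None, *, max_hits=10, timeout=5):
    if function is not None:
        # Used as @Cache: Python passed the function directly.
        return _Cache(function, max_hits, timeout)

    # Used as @Cache(...): return the real decorator.
    def wrapper(function):
        return _Cache(function, max_hits, timeout)

    return wrapper


@Cache
def double(x):
    return x * 2


@Cache(max_hits=100, timeout=50)
def triple(x):
    return x * 3
```

Both forms produce a _Cache instance; only the second one carries the custom settings.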

Solution 2

@Cache
def double(...):
    ...

is equivalent to

def double(...):
    ...
double = Cache(double)

While

@Cache(max_hits=100, timeout=50)
def double(...):
    ...

is equivalent to

def double(...):
    ...
double = Cache(max_hits=100, timeout=50)(double)

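Cache(max_hits=100, timeout=50)(double) has very different semantics from Cache(double): Cache is first called with only the keyword arguments, and whatever that call returns is then called with the function.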
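To see the two steps concretely, the sketch below reuses the question's Cache signature, with a pass-through __call__ standing in for the real caching logic (an assumption). The single-step form works; the first step of the two-step form fails because no function is available yet. The question's error message is Python 2 wording; Python 3 instead reports that the positional argument 'function' is missing.

```python
class Cache:
    # The question's signature: function is a required argument.
    def __init__(self, function, max_hits=10, timeout=5):
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout

    def __call__(self, *args):
        # Placeholder body: just call the wrapped function.
        return self.function(*args)


def double(x):
    return x * 2


# @Cache is just double = Cache(double): this works.
wrapped = Cache(double)

# @Cache(max_hits=100, timeout=50) evaluates this call first,
# before the function is available, so __init__ lacks `function`.
error_message = None
try:
    Cache(max_hits=100, timeout=50)
except TypeError as exc:
    error_message = str(exc)
```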

It's unwise to try to make Cache handle both use cases.

You could instead use a decorator factory that can take optional max_hits and timeout arguments, and returns a decorator:

class Cache(object):
    def __init__(self, function, max_hits=10, timeout=5):
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, *args):
        # Here the code returning the correct thing.
        # (Placeholder: just call the wrapped function.)
        return self.function(*args)

def cache_hits(max_hits=10, timeout=5):
    def _cache(function):
        return Cache(function, max_hits, timeout)
    return _cache

@cache_hits()
def double(x):
    return x * 2

@cache_hits(max_hits=100, timeout=50)
def double(x):
    return x * 2

PS. If the class Cache has no other methods besides __init__ and __call__, you can probably move all the code inside the _cache function and eliminate Cache altogether.
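Following the PS, a minimal sketch of cache_hits with the Cache class folded into a closure might look like this. It assumes simple memoization on positional arguments; max_hits and timeout are captured but not enforced, and exposing them as attributes on the wrapper is an optional, illustrative choice.

```python
import functools


def cache_hits(max_hits=10, timeout=5):
    def _cache(function):
        cache = {}  # one dict per decorated function

        @functools.wraps(function)
        def wrapper(*args):
            # Plain memoization on positional args; max_hits and timeout are
            # captured by the closure but not enforced here (assumption).
            if args not in cache:
                cache[args] = function(*args)
            return cache[args]

        # Optional: expose the settings for inspection.
        wrapper.max_hits = max_hits
        wrapper.timeout = timeout
        return wrapper

    return _cache


calls = []


@cache_hits(max_hits=100, timeout=50)
def double(x):
    calls.append(x)  # record real executions to show the cache working
    return x * 2
```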

Solution 3

I've learned a lot from this question, thanks all. Isn't the answer just to put empty brackets on the first @Cache? Then you can move the function parameter to __call__.

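class Cache(object):
    def __init__(self, max_hits=10, timeout=5):
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, function):
        # Called once, with the function being decorated; return its replacement.
        def wrapper(*args):
            # Here the code returning the correct thing.
            return function(*args)
        return wrapper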

@Cache()
def double(x):
    return x * 2

@Cache(max_hits=100, timeout=50)
def double(x):
    return x * 2

Although I think this approach is simpler and more concise:

def cache(max_hits=10, timeout=5):
    def caching_decorator(fn):
        def decorated_fn(*args, **kwargs):
            # Here the code returning the correct thing.
            return fn(*args, **kwargs)
        return decorated_fn
    return caching_decorator

If you forget the parentheses when using the decorator, unfortunately you don't get an error when the function is defined: the function you're trying to decorate is passed to the outer function as max_hits, and its name ends up bound to the inner decorator. Only when you later call it does the inner decorator complain:

TypeError: caching_decorator() takes exactly 1 argument (0 given).
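A small sketch reproduces this, using a pass-through placeholder body (an assumption). The quoted message is Python 2 wording; Python 3 reports a missing required positional argument 'fn' instead. Note also that calling the mis-decorated function with exactly one argument raises nothing at all: it silently returns decorated_fn.

```python
def cache(max_hits=10, timeout=5):
    def caching_decorator(fn):
        def decorated_fn(*args, **kwargs):
            # Placeholder body: just call the original function.
            return fn(*args, **kwargs)
        return decorated_fn
    return caching_decorator


# Parentheses forgotten: this runs some_method = cache(some_method),
# so the function lands in max_hits and no error is raised yet.
@cache
def some_method():
    pass


# some_method now names the inner caching_decorator, not the original function.
bound_name = some_method.__name__
```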

However, you can catch this when the function is defined, if you know your decorator's parameters are never going to be callable:

def cache(max_hits=10, timeout=5):
    assert not callable(max_hits), "@cache passed a callable - did you forget to parenthesize?"
    def caching_decorator(fn):
        def decorated_fn(*args, **kwargs):
            # Here the code returning the correct thing.
            return fn(*args, **kwargs)
        return decorated_fn
    return caching_decorator

If you now try:

@cache
def some_method():
    pass

You get an AssertionError as soon as the function is defined.
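One caveat with assert is that Python skips assert statements when run with -O. A sketch of the same guard as an explicit check, raising TypeError instead (a swapped-in variant, with a pass-through placeholder body):

```python
import functools


def cache(max_hits=10, timeout=5):
    # An explicit check still runs under `python -O`, which strips asserts.
    if callable(max_hits):
        raise TypeError("@cache passed a callable - did you forget to parenthesize?")

    def caching_decorator(fn):
        @functools.wraps(fn)
        def decorated_fn(*args, **kwargs):
            # Here the code returning the correct thing (pass-through placeholder).
            return fn(*args, **kwargs)
        return decorated_fn

    return caching_decorator


@cache()
def some_method():
    return "ok"
```

Writing @cache without parentheses now fails with a TypeError at definition time, regardless of interpreter flags.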

On a total tangent, I came across this post looking for decorators that decorate classes, rather than classes that decorate. In case anyone else does too, this question is useful.

Solution 4

I'd rather include the wrapper inside the class's __call__ method:

UPDATE: I have tested this method in Python 3.6; I'm not sure about later or earlier versions.

class Cache:
    def __init__(self, max_hits=10, timeout=5):
        # Remove function from here and add it to the __call__
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, function):
        def wrapper(*args):
            value = function(*args)
            # saving to cache codes
            return value
        return wrapper

@Cache()
def double(x):
    return x * 2

@Cache(max_hits=100, timeout=50)
def double(x):
    return x * 2
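A slightly fuller sketch of this approach, assuming simple memoization (max_hits and timeout are still only stored). functools.wraps keeps the decorated function's name and docstring. Because one Cache instance could decorate several functions, the key here includes the function itself; that is an added assumption, not part of the answer above.

```python
import functools


class Cache:
    def __init__(self, max_hits=10, timeout=5):
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}  # shared by every function this instance decorates

    def __call__(self, function):
        @functools.wraps(function)  # keep double.__name__, __doc__, etc.
        def wrapper(*args):
            # Key on the function too, in case one instance decorates several.
            key = (function, args)
            if key not in self.cache:
                self.cache[key] = function(*args)
            return self.cache[key]
        return wrapper


@Cache(max_hits=100, timeout=50)
def double(x):
    """Return twice x."""
    return x * 2
```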

Solution 5

You can use a classmethod as a factory method; this should handle all the use cases (with or without parentheses).

import functools

class Cache:
    def __init__(self, function, max_hits=10, timeout=5):
        functools.update_wrapper(self, function)
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, *args):
        # Here the code returning the correct thing.
        # (Placeholder: just call the wrapped function.)
        return self.function(*args)

    @classmethod
    def Cache_dec(cls, _func=None, *, max_hits=10, timeout=5):
        if _func is not None:  # decorator used without parentheses
            return cls(_func, max_hits, timeout)
        else:  # decorator used with parentheses: keep the settings per decorated function
            return functools.partial(cls, max_hits=max_hits, timeout=timeout)
       

@Cache.Cache_dec
def double(x):
    return x * 2

@Cache.Cache_dec()
def double(x):
    return x * 2

@Cache.Cache_dec(timeout=50)
def double(x):
    return x * 2

@Cache.Cache_dec(max_hits=100)
def double(x):
    return x * 2

@Cache.Cache_dec(max_hits=100, timeout=50)
def double(x):
    return x * 2
Author: Dachmt

Living the dream in San Francisco, CA, from south west of France, loving surfing, outdoors and developing websites.

Updated on February 12, 2022

Comments

  • Dachmt

    I'm trying to pass optional arguments to my class decorator in Python. Below is the code I currently have:

    class Cache(object):
        def __init__(self, function, max_hits=10, timeout=5):
            self.function = function
            self.max_hits = max_hits
            self.timeout = timeout
            self.cache = {}
    
        def __call__(self, *args):
            # Here the code returning the correct thing.
            return self.function(*args)
    
    
    @Cache
    def double(x):
        return x * 2
    
    @Cache(max_hits=100, timeout=50)
    def double(x):
        return x * 2
    

    The second decorator, which passes arguments to override the defaults (max_hits=10, timeout=5 in my __init__ function), is not working: I get the exception TypeError: __init__() takes at least 2 arguments (3 given). I have tried many solutions and read articles about it, but I still can't make it work.

    Any idea how to resolve this? Thanks!