When to use classes in Python? When you repeat similar sets of functions

September 2023 ∙ 11 minute read ∙

Are you having trouble figuring out when to use classes or how to organize them?

Have you repeatedly searched for "when to use classes in Python", read all the articles and watched all the talks, and still don't know whether you should be using classes in any given situation?

Have you read discussions about it that for all you know may be right, but they're so academic you can't parse the jargon?

Have you read articles that all treat the "obvious" cases, leaving you with no clear answer when you try to apply them to your own code?


My experience is that, unfortunately, the best way to learn this is to look at lots of examples.

Most guidelines tend to either be too vague if you don't already know enough about the subject, or too specific and saying things you already know.

This is one of those things that once you get it seems obvious and intuitive, but it's not, and is quite difficult to explain properly.


So, instead of prescribing a general approach, let's look at:

  • one specific case where you may want to use classes
  • examples from real-world code
  • some considerations you should keep in mind
Contents

The heuristic #

If you repeat similar sets of functions, consider grouping them in a class.

That's it.

In its most basic form, a class is when you group data with functions that operate on that data; sometimes, there is no data, but it can still be useful to group the functions into an abstract object that exists only to make things easier to use / understand.

Depending on whether you choose which class to use at runtime, this is sometimes called the strategy pattern.

Note

As Wikipedia puts it, "A heuristic is a practical way to solve a problem. It is better than chance, but does not always work. A person develops a heuristic by using intelligence, experience, and common sense."

So, this is not the correct thing to do all the time, or even most of the time.

Instead, I hope that this and other heuristics can help build the right intuition for people on their way from "I know the class syntax, now what?" to "proper" object-oriented design.

Example: Retrievers #

My feed reader library retrieves and stores web feeds (Atom, RSS and so on).

Usually, feeds come from the internet, but you can also use local files. The parsers for various formats don't really care where a feed is coming from, so they always take an open file as input.

reader supports conditional requests – that is, only retrieve a feed if it changed. To do this, it stores the ETag HTTP header from a response, and passes it back as the If-None-Match header of the next request; if nothing changed, the server can respond with 304 Not Modified instead of sending back the full content.

Let's have a look at how the code to retrieve feeds evolved over time; this version omits a few details, but it will end up with a structure similar to that of the full version. In the beginning, there was a function – URL and old ETag in, file and new ETag out:

def retrieve(url, etag=None):
    if any(url.startswith(p) for p in ('http://', 'https://')):
        headers = {}
        if etag:
            headers['If-None-Match'] = etag
        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()
        if response.status_code == 304:
            response.close()
            return None, etag
        etag = response.headers.get('ETag', etag)
        response.raw.decode_content = True
        return response.raw, etag

    # fall back to file
    path = extract_path(url)
    return open(path, 'rb'), None

We use Requests to get HTTP URLs, and return the underlying file-like object.1

For local files, we suport both bare paths and file URIs; for the latter, we do a bit of validation – file:feed and file://localhost/feed are OK, but file://invalid/feed and unknown:feed2 are not:

def extract_path(url):
    url_parsed = urllib.parse.urlparse(url)
    if url_parsed.scheme == 'file':
        if url_parsed.netloc not in ('', 'localhost'):
            raise ValueError("unknown authority for file URI")
        return urllib.request.url2pathname(url_parsed.path)
    if url_parsed.scheme:
        raise ValueError("unknown scheme for file URI")
    # no scheme, treat as a path
    return url

Problem: can't add new feed sources #

One of reader's goals is to be extensible. For example, it should be possible to add new feed sources like an FTP server (ftp://...) or Twitter without changing reader code; however, our current implementation makes it hard to do so.

We can fix this by extracting retrieval logic into separate functions, one per protocol:

def http_retriever(url, etag):
    headers = {}
    # ...
    return response.raw, etag

def file_retriever(url, etag):
    path = extract_path(url)
    return open(path, 'rb'), None

...and then routing to the right one depending on the URL prefix:

# sorted by key length (longest first)
RETRIEVERS = {
    'https://': http_retriever,
    'http://': http_retriever,
    # fall back to file
    '': file_retriever,
}

def get_retriever(url):
    for prefix, retriever in RETRIEVERS.items():
        if url.lower().startswith(prefix.lower()):
            return retriever
    raise ValueError("no retriever for URL")

def retrieve(url, etag=None):
    retriever = get_retriever(url)
    return retriever(url, etag)

Now, plugins can register retrievers by adding them to RETRIEVERS (in practice, there's a method for that, so users don't need to care about it staying sorted).

Problem: can't validate URLs until retrieving them #

To add a feed, you call add_feed() with the feed URL.

But what if you pass an invalid URL? The feed gets stored in the database, and you get an "unknown scheme for file URI" error on the next update. However, this can be confusing – a good API should signal errors near the action that triggered them. This means add_feed() needs to validate the URL without actually retrieving it.

For HTTP, Requests can do the validation for us; for files, we can call extract_path() and ignore the result. Of course, we should select the appropriate logic in the same way we select retrievers, otherwise we're back where we started.

Now, there's more than one way of doing this. We could keep a separate validator registry, but that may accidentally become out of sync with the retriever one.

URL_VALIDATORS = {
    'https://': http_url_validator,
    'http://': http_url_validator,
    '': file_url_validator,
}

Or, we could keep a (retriever, validator) pair in the retriever registry. This is better, but it's not all that readable (what if need to add a third thing?); also, it makes customizing behavior that affects both the retriever and validator harder.

RETRIEVERS = {
    'https://': (http_retriever, http_url_validator),
    'http://': (http_retriever, http_url_validator),
    '': (file_retriever, file_url_validator),
}

Better yet, we can use a class to make the grouping explicit:

class HTTPRetriever:

    def retrieve(self, url, etag):
        headers = {}
        # ...
        return response.raw, etag

    def validate_url(self, url):
        session = requests.Session()
        session.get_adapter(url)
        session.prepare_request(requests.Request('GET', url))

class FileRetriever:

    def retrieve(self, url, etag):
        path = extract_path(url)
        return open(path, 'rb'), None

    def validate_url(self, url):
        extract_path(url)

We then instantiate them, and update retrieve() to call the methods:

http_retriever = HTTPRetriever()
file_retriever = FileRetriever()
def retrieve(url, etag=None):
    retriever = get_retriever(url)
    return retriever.retrieve(url, etag)

validate_url() works just the same:

def validate_url(url):
    retriever = get_retriever(url)
    retriever.validate_url(url)

And there you have it – if you repeat similar sets of functions, consider grouping them in a class.

Not just functions, attributes too #

Say you want to update feeds in parallel, using multiple threads.

Retrieving feeds is mostly waiting around for I/O, so it will benefit the most from it. Parsing, on the other hand, is pure Python, CPU bound code, so threads won't help due to the global interpreter lock.

However, because we're streaming the reponse body, I/O is not done when the retriever returns the file, but when the parser finishes reading it.3 We can move all the (network) I/O in retrieve() by reading the response into a temporary file and returning it instead.

We'll allow any retriever to opt into this behavior by using a class attribute:

class HTTPRetriever:
    slow_to_read = True
class FileRetriever:
    slow_to_read = False

If a retriever is slow to read, retrieve() does the swap:

def retrieve(url, etag=None):
    retriever = get_retriever(url)
    file, etag = retriever.retrieve(url, etag)

    if file and retriever.slow_to_read:
        temp = tempfile.TemporaryFile()
        shutil.copyfileobj(file, temp)
        file.close()
        temp.seek(0)
        file = temp

    return file, etag
Liking this so far? Here's another article you might like:

Example: Flask's tagged JSON #

The Flask web framework provides an extendable compact representation for non-standard JSON types called tagged JSON (code). The serializer class delegates most conversion work to methods of various JSONTag subclasses (one per supported type):

  • check() checks if a Python value should be tagged by that tag
  • tag() converts it to tagged JSON
  • to_python() converts a JSON value back to Python (the serializer uses the key tag attribute to find the correct tag)

Interestingly, tag instances have an attribute pointing back to the serializer, likely to allow recursion – when (un)packing a possibly nested collection, you need to recursively (un)pack its values. Passing the serializer to each method would have also worked, but when your functions take the same arguments...

Formalizing this #

OK, the retriever code works. But, how should you communicate to others (readers, implementers, interpreters, type checkers) that an HTTPRetriever is the same kind of thing as a FileRetriever, and as anything else that can go in RETRIEVERS?

Duck typing #

Here's the definition of duck typing:

A programming style which does not look at an object's type to determine if it has the right interface; instead, the method or attribute is simply called or used ("If it looks like a duck and quacks like a duck, it must be a duck.") [...]

This is what we're doing now! If it retrieves like a retriever and validates URLs like a retriever, then it's a retriever.

You see this all the time in Python. For example, json.dump() takes a file-like object; now, the full text file interface has lots methods and attributes, but dump() only cares about write(), and will accept any object implementing it:

>>> class MyFile:
...     def write(self, s):
...         print(f"writing: {s}")
...
>>> f = MyFile()
>>> json.dump({'one': 1}, f)
writing: {
writing: "one"
writing: :
writing: 1
writing: }

The main way to communicate this is through documentation:

Serialize obj [...] to fp (a .write()-supporting file-like object)

Inheritance #

Nevertheless, you may want to be more explicit about the relationships between types. The easiest option is to use a base class, and require retrievers to inherit from it.

class Retriever:
    slow_to_read = False

    def retrieve(self, url, etag):
        raise NotImplementedError

    def validate_url(self, url):
        raise NotImplementedError

This allows you to check you the type with isinstance(), provide default methods and attributes, and will help type checkers and autocompletion, at the expense of forcing a dependency on the base class.

>>> class MyRetriever(Retriever): pass
>>> retriever = MyRetriever()
>>> retriever.slow_to_read
False
>>> isinstance(retriever, Retriever)
True

What it won't do is check subclasses actually define the methods:

>>> retriever.validate_url('myurl')
Traceback (most recent call last):
  ...
NotImplementedError

Abstract base classes #

This is where abstract base classes come in. The decorators in the abc module allow defining abstract methods that must be overriden:

class Retriever(ABC):

    @abstractproperty
    def slow_to_read(self):
        return False

    @abstractmethod
    def retrieve(self, url, etag):
        raise NotImplementedError

    @abstractmethod
    def validate_url(self, url):
        raise NotImplementedError

This is checked at runtime (but only that methods and attributes are present, not their signatures or types):

>>> class MyRetriever(Retriever): pass
>>> MyRetriever()
Traceback (most recent call last):
  ...
TypeError: Can't instantiate abstract class MyRetriever with abstract methods retrieve, slow_to_read, validate_url
>>> class MyRetriever(Retriever):
...     slow_to_read = False
...     def retrieve(self, url, etag): ...
...     def validate_url(self, url): ...
...
>>> MyRetriever()
<__main__.MyRetriever object at 0x1037aac50>

Tip

You can also use ABCs to register arbitrary types as "virtual subclasses"; this allows them to pass isinstance() checks without inheritance, but won't check for required methods:

>>> class MyRetriever: pass
>>> Retriever.register(MyRetriever)
<class '__main__.MyRetriever'>
>>> isinstance(MyRetriever(), Retriever)
True

Protocols #

Finally, we have protocols, aka structural subtyping, aka static duck typing. Introduced in PEP 544, they go in the opposite direction – what if instead declaring what the type of something is, we declare what methods it has to have to be of a specific type?

You define a protocol by inheriting typing.Protocol:

class Retriever(Protocol):

    @property
    def slow_to_read(self) -> bool:
        ...

    def retrieve(self, url: str, etag: str | None) -> tuple[IO[bytes] | None, str | None]:
        ...

    def validate_url(self, url: str) -> None:
        ...

...and then use it in type annotations:

def mount_retriever(prefix: str, retriever: Retriever) -> None:
    raise NotImplementedError

Some other code (not necessarily yours, not necessarily aware the protocol even exists) defines an implementation:

class MyRetriever:
    slow_to_read = False

    def validate_url(self):
        pass

...and then uses it with annotated code:

mount_retriever('my', MyRetriever())

A type checker like mypy will check if the provided instance conforms to the protocol – not only that methods exist, but that their signatures are correct too – all without the implementation having to declare anything.

$ mypy myproto.py
myproto.py:11: error: Argument 2 to "mount_retriever" has incompatible type "MyRetriever"; expected "Retriever"  [arg-type]
myproto.py:11: note: "MyRetriever" is missing following "Retriever" protocol member:
myproto.py:11: note:     retrieve
myproto.py:11: note: Following member(s) of "MyRetriever" have conflicts:
myproto.py:11: note:     Expected:
myproto.py:11: note:         def validate_url(self, url: str) -> None
myproto.py:11: note:     Got:
myproto.py:11: note:         def validate_url(self) -> Any
Found 1 error in 1 file (checked 1 source file)

Tip

If you decorate your protocol with runtime_checkable, you can use it in isinstance() checks, but like ABCs, it only checks methods are present.

Counter-example: modules #

If a class has no state and you don't need inheritance, you can use a module instead:

# module.py

slow_to_read = False

def retrieve(url, etag):
    raise NotImplementedError

def validate_url(url):
    raise NotImplementedError

From a duck typing perspective, this is a valid retriever, since it has all the expected methods and attributes. So much so, that it's also compatible with protocols:

import module

mount_retriever('mod', module)
$ mypy module.py
Success: no issues found in 1 source file

I tried to keep the retriever example stateless, but real world classes rarely are (it may be immutable state, but it's state nonetheless). Also, you're limited to exactly one implementation per module, which is usually too much like Java for my taste.

Try it out #

If you're doing something and you think you need a class, do it and see how it looks. If you think it's better, keep it, otherwise, revert the change. You can always switch in either direction later.

If you got it right the first time, great! If not, by having to fix it you'll learn something, and next time you'll know better.

Also, don't beat yourself up.

Sure, there are nice libraries out there that use classes in just the right way, after spending lots of time to find the right abstraction. But abstraction is difficult and time consuming, and in everyday code good enough is just that – good enough – you don't need to go to the extreme.

Learned something new today? Share this with others, it really helps!

If you've made it this far, you might like:
  1. This code has a potential bug: if we were using a persistent session instead of a transient one, the connection would never be released, since we're not closing the response after we're done with it. In the actual code, we're doing both, but the only way do so reliably is to return a context manager; I omitted this because it doesn't add anything to our discussion about classes. [return]

  2. We're handling unknown URI schemes here because bare paths don't have a scheme, so anything that didn't match a known scheme must be a bare path. Also, on Windows (not supported yet), the drive letter in a path like c:\feed.xml is indistinguishable from a scheme. [return]

  3. Unless the response is small enough to fit in the TCP receive buffer. [return]


This is part of a series: