本文主要分析了 Flask 的上下文机制,包括:application contextrequest context

Local 的实现

在 Flask 中,视图函数处理 http 请求时,需要知道具体的请求信息,比如URL、请求方法、参数等,以及应用信息。Flask 将这些信息封装成上下文对象,类似于全局变量,不同的是:它们是动态的,每个线程都有自己独立的上下文信息,不会相互干扰。
那么是如何实现这种机制的?在 Python 多线程编程中,有个 threading.local 类,可以实现访问某个变量时只有该线程自己能看到,实现的原理是,这个对象保存一个字典,以线程 ID 为键,读取该对象时,它动态地查询当前线程 ID 对应的数据。Flask 使用 Local 类实现了类似的效果,下面是它的实现代码。

class Local(object):
    __slots__ = ('__storage__', '__ident_func__')

    def __init__(self):
        object.__setattr__(self, '__storage__', {})
        object.__setattr__(self, '__ident_func__', get_ident)

    def __call__(self, proxy):
        """Create a proxy for a name."""
        return LocalProxy(self, proxy)

    def __release_local__(self):
        self.__storage__.pop(self.__ident_func__(), None)

    def __getattr__(self, name):
        try:
            return self.__storage__[self.__ident_func__()][name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        ident = self.__ident_func__()
        storage = self.__storage__
        try:
            storage[ident][name] = value
        except KeyError:
            storage[ident] = {name: value}

Local 类的构造函数中,初始化了两个属性,一个是 __storage__,这个属性是一个字典,用来保存上下文数据;另一个属性是 __ident_func__,用来获取当前线程的 ID。从 __getattr__ 方法可以看出,我们在通过点号运算获取数据时,根据当前线程 ID: self.__ident_func__(),访问嵌套字典里定义的 name-value 键值对,从而获取保存的上下文数据。存储数据的方法在 __setattr__ 中定义,将数据保存在 storage[ident][name] 字典中,外层键为线程 ID。通过 Local 类,可以实现数据的访问隔离。 另外,__call__ 操作会创建一个 LocalProxy 对象,LocalProxy 会在下面讲到。

LocalStack 的实现

Flask 中,Local 是怎样与上下文关联的呢,这就得研究LocalStackLocalProxy的实现。首先我们来看 LocalStack 的定义。

class LocalStack(object):
    def __init__(self):
        self._local = Local()

    def __call__(self):
        def _lookup():
            rv = self.top
            if rv is None:
                raise RuntimeError('object unbound')
            return rv
        return LocalProxy(_lookup)

    def push(self, obj):
        """Pushes a new item to the stack"""
        rv = getattr(self._local, 'stack', None)
        if rv is None:
            self._local.stack = rv = []
        rv.append(obj)
        return rv

    def pop(self):
        """Removes the topmost item from the stack, will return the
        old value or `None` if the stack was already empty.
        """
        stack = getattr(self._local, 'stack', None)
        if stack is None:
            return None
        elif len(stack) == 1:
            release_local(self._local)
            return stack[-1]
        else:
            return stack.pop()

LocalStack 的属性 _local 是一个 Local 对象,在 push 方法中,将数据保存在 self._local.stack,即属性Local__storage__[ident][stack] 里。实际上 LocalStack 实现了线程隔离的栈访问,它提供了操作栈结构的 pushpoptop 方法。

LocalProxy 的实现

LocalProxyLocal 对象的代理,负责把所有对自己的操作转发给内部的 Local 对象。

class LocalProxy(object):
    __slots__ = ('__local', '__dict__', '__name__')

    def __init__(self, local, name=None):
        object.__setattr__(self, '_LocalProxy__local', local)
        object.__setattr__(self, '__name__', name)

    def _get_current_object(self):
        if not hasattr(self.__local, '__release_local__'):
            return self.__local()
        try:
            return getattr(self.__local, self.__name__)
        except AttributeError:
            raise RuntimeError('no object bound to %s' % self.__name__)

    def __getattr__(self, name):
        if name == '__members__':
            return dir(self._get_current_object())
        return getattr(self._get_current_object(), name)

    def __setitem__(self, key, value):
        self._get_current_object()[key] = value

构造函数中初始化的 _LocalProxy__local属性,后面可以通过 self.__local 来访问。_get_current_object 方法用来获取当前线程对应的对象,当 __local 值为 Local 对象时,_get_current_object 返回的是__storage__[self.__ident_func__()][name]

上下文的定义

接下来研究上下文在 Flask 中是怎样定义的?

def _lookup_req_object(name):
    top = _request_ctx_stack.top
    if top is None:
        raise RuntimeError(_request_ctx_err_msg)
    return getattr(top, name)


def _lookup_app_object(name):
    top = _app_ctx_stack.top
    if top is None:
        raise RuntimeError(_app_ctx_err_msg)
    return getattr(top, name)


def _find_app():
    top = _app_ctx_stack.top
    if top is None:
        raise RuntimeError(_app_ctx_err_msg)
    return top.app


# context locals
_request_ctx_stack = LocalStack()
_app_ctx_stack = LocalStack()
current_app = LocalProxy(_find_app)
request = LocalProxy(partial(_lookup_req_object, 'request'))
session = LocalProxy(partial(_lookup_req_object, 'session'))
g = LocalProxy(partial(_lookup_app_object, 'g'))

request 的值是将 _request_ctx_stack 中的栈顶元素弹出,根据 Local.[self.__ident_func__()]['stack'][-1].request 获取。那么请求上下文是在什么时候入栈的呢? 在 wsgi_app 方法中,将对应的请求信息入栈,每次请求过来的时候,就会调用该代码。

ctx = self.request_context(environ)
ctx.push()

具体是怎样做的,让我们继续查看内部实现。

class RequestContext(object):
    def __init__(self, app, environ, request=None):
        self.app = app
        if request is None:
            request = app.request_class(environ)
        self.request = request
        self.url_adapter = app.create_url_adapter(self.request)
        self.session = None
        self._implicit_app_ctx_stack = []
        self.match_request()

    def push(self):
        top = _request_ctx_stack.top
        if top is not None and top.preserved:
            top.pop(top._preserved_exc)

        app_ctx = _app_ctx_stack.top
        if app_ctx is None or app_ctx.app != self.app:
            app_ctx = self.app.app_context()
            app_ctx.push()
            self._implicit_app_ctx_stack.append(app_ctx)
        else:
            self._implicit_app_ctx_stack.append(None)

        if hasattr(sys, 'exc_clear'):
            sys.exc_clear()

        _request_ctx_stack.push(self)

        self.session = self.app.open_session(self.request)
        if self.session is None:
            self.session = self.app.make_null_session()

RequestContext 保存了当前请求的信息,比如 request 对象和 app 对象。在前面的代码 中,就是获取了当前请求上下文的 request 对象

request = LocalProxy(partial(_lookup_req_object, 'request'))

push 方法把该请求的对应的 AppContextRequestContext 分别压入到 _app_ctx_stack_request_ctx_stack
到这里,上下文的实现就比较清晰了:每次有请求过来的时候,flask 会先创建当前线程或者进程需要处理的两个上下文对象,把它们保存到隔离的栈里面,这样视图函数进行处理的时候就能直接从栈上获取这些信息。