beancount.utils

通用工具包和函数。

beancount.utils.bisect_key

一个支持自定义键函数的 bisect 版本,类似于排序函数。

beancount.utils.bisect_key.bisect_left_with_key(sequence, value, key=None)

在已排序列表中查找小于给定值的最后一个元素。

参数:
  • sequence – 一个已排序的元素序列。

  • value – 要搜索的值。

  • key – 可选函数,用于从 sequence 的元素中提取值。

返回:
  • 返回索引。可能返回 None。

源代码位于 beancount/utils/bisect_key.py
def bisect_left_with_key(sequence, value, key=None):
    """Find the last element before the given value in a sorted list.

    Args:
      sequence: A sorted sequence of elements.
      value: The value to search for.
      key: An optional function used to extract the value from the elements of
        sequence.
    Returns:
      Return the index. May return None.
    """
    if key is None:
        key = lambda x: x  # Identity.

    lo = 0
    hi = len(sequence)

    while lo < hi:
        mid = (lo + hi) // 2
        if key(sequence[mid]) < value:
            lo = mid + 1
        else:
            hi = mid
    return lo

beancount.utils.bisect_key.bisect_right_with_key(a, x, key, lo=0, hi=None)

类似于 bisect.bisect_right,但增加了键查找参数。

参数:
  • a – 要搜索的列表。

  • x – 要查找的元素。

  • key – 用于从列表中提取值的函数。

  • lo – 最小的搜索索引。

  • hi – 最大的搜索索引。

返回:
  • 与 bisect.bisect_right 类似,来自列表 'a' 的一个元素。

源代码位于 beancount/utils/bisect_key.py
def bisect_right_with_key(a, x, key, lo=0, hi=None):
    """Like bisect.bisect_right, but with a key lookup parameter.

    Args:
      a: The list to search in.
      x: The element to search for.
      key: A function, to extract the value from the list.
      lo: The smallest index to search.
      hi: The largest index to search.
    Returns:
      As in bisect.bisect_right, an element from list 'a'.
    """
    # pylint: disable=invalid-name
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo+hi)//2
        if x < key(a[mid]):
            hi = mid
        else:
            lo = mid+1
    return lo

beancount.utils.date_utils

从多种格式中解析日期。

beancount.utils.date_utils.intimezone(tz_value)

临时重置 TZ 的值。

此功能用于测试。

参数:
  • tz_value (str) – 在此上下文期间要设置的 TZ 值。

返回:
  • 一个处于指定时区区域设置的上下文管理器。

源代码位于 beancount/utils/date_utils.py
@contextlib.contextmanager
def intimezone(tz_value: str):
    """Temporarily reset the value of TZ.

    This is used for testing.

    Args:
      tz_value: The value of TZ to set for the duration of this context.
    Returns:
      A contextmanager in the given timezone locale.
    """
    tz_old = os.environ.get('TZ', None)
    os.environ['TZ'] = tz_value
    time.tzset()
    try:
        yield
    finally:
        if tz_old is None:
            del os.environ['TZ']
        else:
            os.environ['TZ'] = tz_old
        time.tzset()

beancount.utils.date_utils.iter_dates(start_date, end_date)

生成 'start_date' 和 'end_date' 之间的所有日期。

参数:
  • start_date – 一个 datetime.date 实例。

  • end_date – 一个 datetime.date 实例。

生成:datetime.date 实例。

源代码位于 beancount/utils/date_utils.py
def iter_dates(start_date, end_date):
    """Yield all the dates between 'start_date' and 'end_date'.

    Args:
      start_date: An instance of datetime.date.
      end_date: An instance of datetime.date.
    Yields:
      Instances of datetime.date.
    """
    oneday = datetime.timedelta(days=1)
    date = start_date
    while date < end_date:
        yield date
        date += oneday

beancount.utils.date_utils.next_month(date)

计算给定日期之后下个月的第一天。

参数:
  • date – 一个 datetime.date 实例。

返回:
  • 一个 datetime.date 实例,表示 'date' 所在月份的下一个月的第一天。

源代码位于 beancount/utils/date_utils.py
def next_month(date):
    """Compute the date at the beginning of the following month from the given date.

    Args:
      date: A datetime.date instance.
    Returns:
      A datetime.date instance, the first day of the month following 'date'.
    """
    # Compute the date at the beginning of the following month.
    year = date.year
    month = date.month + 1
    if date.month == 12:
        year += 1
        month = 1
    return datetime.date(year, month, 1)

beancount.utils.date_utils.render_ofx_date(dtime)

将 datetime 转换为 OFX 格式。

参数:
  • dtime – 一个 datetime.datetime 实例。

返回:
  • 一个字符串,以毫秒为单位渲染。

源代码位于 beancount/utils/date_utils.py
def render_ofx_date(dtime):
    """Render a datetime to the OFX format.

    Args:
      dtime: A datetime.datetime instance.
    Returns:
      A string, rendered to milliseconds.
    """
    return '{}.{:03d}'.format(dtime.strftime('%Y%m%d%H%M%S'),
                              int(dtime.microsecond / 1000))

beancount.utils.defdict

一个 collections.defaultdict 实例,其工厂函数接受键作为参数。

注意:这实际上应该成为 Python 本身的增强功能。我最终应该添加这个功能。

beancount.utils.defdict.DefaultDictWithKey (defaultdict)

一个 defaultdict 的版本,其工厂函数接受键作为参数。注意:collections.defaultdict 如果能直接支持此功能将得到改进,这是一种常见需求。

beancount.utils.defdict.ImmutableDictWithDefault (dict)

一个不可变字典,对缺失的键返回默认值。

这与 defaultdict 不同,因为它在获取缺失值时不会插入默认值,并且此外,set 方法被禁用,以防止在构造后发生突变。

beancount.utils.defdict.ImmutableDictWithDefault.__setitem__(self, key, value) 特殊

禁止以常规方式修改字典。

源代码位于 beancount/utils/defdict.py
def __setitem__(self, key, value):
    """Disallow mutating the dict in the usual way."""
    raise NotImplementedError

beancount.utils.defdict.ImmutableDictWithDefault.get(self, key, _=None)

如果键存在于字典中,则返回该键对应的值;否则返回默认值。

源代码位于 beancount/utils/defdict.py
def get(self, key, _=None):
    return self.__getitem__(key)  # pylint: disable=unnecessary-dunder-call

beancount.utils.encryption

支持加密测试。

beancount.utils.encryption.is_encrypted_file(filename)

如果给定的文件名对应一个加密文件,则返回 True。

参数:
  • filename – 一个路径字符串。

返回:
  • 布尔值,若文件为加密文件则返回 True。

源代码位于 beancount/utils/encryption.py
def is_encrypted_file(filename):
    """Return true if the given filename contains an encrypted file.

    Args:
      filename: A path string.
    Returns:
      A boolean, true if the file contains an encrypted file.
    """
    _, ext = path.splitext(filename)
    if ext == '.gpg':
        return True
    if ext == '.asc':
        with open(filename) as encfile:
            head = encfile.read(1024)
            if re.search('--BEGIN PGP MESSAGE--', head):
                return True
    return False

beancount.utils.encryption.is_gpg_installed()

如果已安装 GPG 1.4.x 或 2.x 版本(我们使用并支持的版本),则返回 True。

源代码位于 beancount/utils/encryption.py
def is_gpg_installed():
    """Return true if GPG 1.4.x or 2.x are installed, which is what we use and support."""
    try:
        pipe = subprocess.Popen(['gpg', '--version'], shell=0,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = pipe.communicate()
        version_text = out.decode('utf8')
        return pipe.returncode == 0 and re.match(r'gpg \(GnuPG\) (1\.4|2)\.', version_text)
    except OSError:
        return False

beancount.utils.encryption.read_encrypted_file(filename)

在不使用临时存储的情况下解密并读取加密文件。

参数:
  • filename – 一个字符串,表示加密文件的路径。

返回:
  • 一个字符串,即文件的内容。

异常:
  • OSError – 如果无法正确解密文件。

源代码位于 beancount/utils/encryption.py
def read_encrypted_file(filename):
    """Decrypt and read an encrypted file without temporary storage.

    Args:
      filename: A string, the path to the encrypted file.
    Returns:
      A string, the contents of the file.
    Raises:
      OSError: If we could not properly decrypt the file.
    """
    command = ['gpg', '--batch', '--decrypt', path.realpath(filename)]
    pipe = subprocess.Popen(command,
                            shell=False,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    contents, errors = pipe.communicate()
    if pipe.returncode != 0:
        raise OSError("Could not decrypt file ({}): {}".format(pipe.returncode,
                                                               errors.decode('utf8')))
    return contents.decode('utf-8')

beancount.utils.file_utils

文件工具。

beancount.utils.file_utils.chdir(directory)

临时切换到指定目录。

参数:
  • directory – 要切换到的目录。

返回:
  • 一个上下文管理器,在执行后恢复当前工作目录。

源代码位于 beancount/utils/file_utils.py
@contextlib.contextmanager
def chdir(directory):
    """Temporarily chdir to the given directory.

    Args:
      directory: The directory to switch do.
    Returns:
      A context manager which restores the cwd after running.
    """
    cwd = os.getcwd()
    os.chdir(directory)
    try:
        yield cwd
    finally:
        os.chdir(cwd)

beancount.utils.file_utils.find_files(fords, ignore_dirs=('.hg', '.svn', '.git'), ignore_files=('.DS_Store',))

稳定地枚举指定目录下的所有文件。

无效的文件或目录名称将被记录到错误日志中。

参数:
  • fords – 一个字符串列表,包含文件或目录名称。

  • ignore_dirs – 一个字符串列表,指定要忽略的文件或目录名称。

生成:来自指定根目录的完整文件名字符串。

源代码位于 beancount/utils/file_utils.py
def find_files(fords,
               ignore_dirs=('.hg', '.svn', '.git'),
               ignore_files=('.DS_Store',)):
    """Enumerate the files under the given directories, stably.

    Invalid file or directory names will be logged to the error log.

    Args:
      fords: A list of strings, file or directory names.
      ignore_dirs: A list of strings, filenames or directories to be ignored.
    Yields:
      Strings, full filenames from the given roots.
    """
    if isinstance(fords, str):
        fords = [fords]
    assert isinstance(fords, (list, tuple))
    for ford in fords:
        if path.isdir(ford):
            for root, dirs, filenames in os.walk(ford):
                dirs[:] = sorted(dirname for dirname in dirs if dirname not in ignore_dirs)
                for filename in sorted(filenames):
                    if filename in ignore_files:
                        continue
                    yield path.join(root, filename)
        elif path.isfile(ford) or path.islink(ford):
            yield ford
        elif not path.exists(ford):
            logging.error("File or directory '{}' does not exist.".format(ford))

beancount.utils.file_utils.guess_file_format(filename, default=None)

根据文件名猜测文件格式。

参数:
  • filename – 一个字符串,表示文件名。此参数可以为 None。

返回:
  • 一个字符串,表示格式的扩展名,不包含前导点号。

源代码位于 beancount/utils/file_utils.py
def guess_file_format(filename, default=None):
    """Guess the file format from the filename.

    Args:
      filename: A string, the name of the file. This can be None.
    Returns:
      A string, the extension of the format, without a leading period.
    """
    if filename:
        if filename.endswith('.txt') or filename.endswith('.text'):
            format = 'text'
        elif filename.endswith('.csv'):
            format = 'csv'
        elif filename.endswith('.html') or filename.endswith('.xhtml'):
            format = 'html'
        else:
            format = default
    else:
        format = default
    return format

beancount.utils.file_utils.path_greedy_split(filename)

分割路径,返回尽可能长的扩展名。

参数:
  • filename – 一个字符串,表示要分割的文件名。

返回:
  • 一个包含基名和扩展名的元组(扩展名包含前导点号)。

源代码位于 beancount/utils/file_utils.py
def path_greedy_split(filename):
    """Split a path, returning the longest possible extension.

    Args:
      filename: A string, the filename to split.
    Returns:
      A pair of basename, extension (which includes the leading period).
    """
    basename = path.basename(filename)
    index = basename.find('.')
    if index == -1:
        extension = None
    else:
        extension = basename[index:]
        basename = basename[:index]
    return (path.join(path.dirname(filename), basename), extension)

beancount.utils.file_utils.touch_file(filename, *otherfiles)

触碰一个文件,并等待其时间戳发生更改。

参数:
  • filename – 一个字符串路径,表示要触碰的文件名。

  • otherfiles – 其他文件的列表,用于确保该文件的时间戳晚于这些文件。

源代码位于 beancount/utils/file_utils.py
def touch_file(filename, *otherfiles):
    """Touch a file and wait until its timestamp has been changed.

    Args:
      filename: A string path, the name of the file to touch.
      otherfiles: A list of other files to ensure the timestamp is beyond of.
    """
    # Note: You could set os.stat_float_times() but then the main function would
    # have to set that up as well. It doesn't help so much, however, since
    # filesystems tend to have low resolutions, e.g. one second.
    orig_mtime_ns = max(os.stat(minfile).st_mtime_ns
                        for minfile in (filename,) + otherfiles)
    delay_secs = 0.05
    while True:
        with open(filename, 'a'):
            os.utime(filename)
        time.sleep(delay_secs)
        new_stat = os.stat(filename)
        if new_stat.st_mtime_ns > orig_mtime_ns:
            break

beancount.utils.import_utils

用于以编程方式导入符号的工具函数。

beancount.utils.import_utils.import_symbol(dotted_name)

导入任意模块中的符号。

参数:
  • dotted_name – 一个点分路径,指向某个符号。

返回:
  • 由给定名称引用的对象。

异常:
  • ImportError – 如果模块无法导入。

  • AttributeError – 如果在模块中找不到该符号。

beancount/utils/import_utils.py 中的源代码
def import_symbol(dotted_name):
    """Import a symbol in an arbitrary module.

    Args:
      dotted_name: A dotted path to a symbol.
    Returns:
      The object referenced by the given name.
    Raises:
      ImportError: If the module not not be imported.
      AttributeError: If the symbol could not be found in the module.
    """
    comps = dotted_name.split('.')
    module_name = '.'.join(comps[:-1])
    symbol_name = comps[-1]
    module = importlib.import_module(module_name)
    return getattr(module, symbol_name)

beancount.utils.invariants

用于在类的方法上注册辅助函数以检查不变量的函数。

此功能旨在用于测试中,测试会设置一个类,使其在每次函数调用前后自动运行不变量验证函数,以确保一些仅在测试中使用的额外合理性检查。

示例:使用 check_inventory_invariants() 函数对 Inventory 类进行仪器化。

def setUp(module): instrument_invariants(Inventory, check_inventory_invariants, check_inventory_invariants)

def tearDown(module): uninstrument_invariants(Inventory)

beancount.utils.invariants.instrument_invariants(klass, prefun, postfun)

为类 'klass' 注入前置和后置不变量检查函数。

参数:
  • klass – 一个类对象,其方法将被仪器化。

  • prefun – 一个在调用前检查不变量的函数。

  • postfun – 一个在调用前检查不变量的函数。

源代码位于 beancount/utils/invariants.py
def instrument_invariants(klass, prefun, postfun):
    """Instrument the class 'klass' with pre/post invariant
    checker functions.

    Args:
      klass: A class object, whose methods to be instrumented.
      prefun: A function that checks invariants pre-call.
      postfun: A function that checks invariants pre-call.
    """
    instrumented = {}
    for attrname, object_ in klass.__dict__.items():
        if attrname.startswith('_'):
            continue
        if not isinstance(object_, types.FunctionType):
            continue
        instrumented[attrname] = object_
        setattr(klass, attrname,
                invariant_check(object_, prefun, postfun))
    klass.__instrumented = instrumented

beancount.utils.invariants.invariant_check(method, prefun, postfun)

为方法装饰预/后不变量检查器。

参数:
  • method – 待 instrument 的无绑定方法。

  • prefun – 一个在调用前检查不变量的函数。

  • postfun – 一个在调用后检查不变量的函数。

返回:
  • 一个经过装饰的无绑定方法。

源代码位于 beancount/utils/invariants.py
def invariant_check(method, prefun, postfun):
    """Decorate a method with the pre/post invariant checkers.

    Args:
      method: An unbound method to instrument.
      prefun: A function that checks invariants pre-call.
      postfun: A function that checks invariants post-call.
    Returns:
      An unbound method, decorated.
    """
    reentrant = []
    def new_method(self, *args, **kw):
        reentrant.append(None)
        if len(reentrant) == 1:
            prefun(self)
        result = method(self, *args, **kw)
        if len(reentrant) == 1:
            postfun(self)
        reentrant.pop()
        return result
    return new_method

beancount.utils.invariants.uninstrument_invariants(klass)

撤销对不变量的 instrumentation。

参数:
  • klass – 要撤销 instrumentation 的方法所属的类对象。

源代码位于 beancount/utils/invariants.py
def uninstrument_invariants(klass):
    """Undo the instrumentation for invariants.

    Args:
      klass: A class object, whose methods to be uninstrumented.
    """
    instrumented = getattr(klass, '__instrumented', None)
    if instrumented:
        for attrname, object_ in instrumented.items():
            setattr(klass, attrname, object_)
    del klass.__instrumented

beancount.utils.memo

记忆化工具。

beancount.utils.memo.memoize_recent_fileobj(function, cache_filename, expiration=None)

对返回文件对象的给定函数进行记忆化,缓存最近的调用结果。

缓存的结果会在一段时间后过期。

参数:
  • function – 一个可调用对象。

  • cache_filename – 一个字符串,表示用于缓存的数据库文件路径。

  • expiration – 结果保持有效的时长。使用 'None' 表示永不过期(默认值)。

返回:
  • 该函数的记忆化版本。

源代码位于 beancount/utils/memo.py
def memoize_recent_fileobj(function, cache_filename, expiration=None):
    """Memoize recent calls to the given function which returns a file object.

    The results of the cache expire after some time.

    Args:
      function: A callable object.
      cache_filename: A string, the path to the database file to cache to.
      expiration: The time during which the results will be kept valid. Use
        'None' to never expire the cache (this is the default).
    Returns:
      A memoized version of the function.
    """
    urlcache = shelve.open(cache_filename, 'c')
    urlcache.lock = threading.Lock()  # Note: 'shelve' is not thread-safe.
    @functools.wraps(function)
    def memoized(*args, **kw):
        # Encode the arguments, including a date string in order to invalidate
        # results over some time.
        md5 = hashlib.md5()
        md5.update(str(args).encode('utf-8'))
        md5.update(str(sorted(kw.items())).encode('utf-8'))

        hash_ = md5.hexdigest()
        time_now = now()
        try:
            with urlcache.lock:
                time_orig, contents = urlcache[hash_]
            if expiration is not None and (time_now - time_orig) > expiration:
                raise KeyError
        except KeyError:
            fileobj = function(*args, **kw)
            if fileobj:
                contents = fileobj.read()
                with urlcache.lock:
                    urlcache[hash_] = (time_now, contents)
            else:
                contents = None

        return io.BytesIO(contents) if contents else None
    return memoized

beancount.utils.memo.now()

对 datetime.datetime.now() 的封装,便于测试。

源代码位于 beancount/utils/memo.py
def now():
    "Indirection on datetime.datetime.now() for testing."
    return datetime.datetime.now()

beancount.utils.misc_utils

通用工具包和函数。

beancount.utils.misc_utils.LineFileProxy

一个文件对象,将完整行的写入委托给另一个日志函数。可用于在无需手动处理行的情况下,将数据写入特定日志级别。

beancount.utils.misc_utils.LineFileProxy.__init__(self, line_writer, prefix=None, write_newlines=False) 特殊

构造一个新的行委托文件对象代理。

参数:
  • line_writer – 一个可调用函数,用于将内容写入委托输出。

  • prefix – 可选字符串,将在每一行前插入的前缀。

  • write_newlines – 布尔值,若为 true 则输出换行符。

源代码位于 beancount/utils/misc_utils.py
def __init__(self, line_writer, prefix=None, write_newlines=False):
    """Construct a new line delegator file object proxy.

    Args:
      line_writer: A callable function, used to write to the delegated output.
      prefix: An optional string, the prefix to insert before every line.
      write_newlines: A boolean, true if we should output the newline characters.
    """
    self.line_writer = line_writer
    self.prefix = prefix
    self.write_newlines = write_newlines
    self.data = []

beancount.utils.misc_utils.LineFileProxy.close(self)

关闭行委托器。

源代码位于 beancount/utils/misc_utils.py
def close(self):
    """Close the line delegator."""
    self.flush()

beancount.utils.misc_utils.LineFileProxy.flush(self)

将数据刷新到行写入器。

源代码位于 beancount/utils/misc_utils.py
def flush(self):
    """Flush the data to the line writer."""
    data = ''.join(self.data)
    if data:
        lines = data.splitlines()
        self.data = [lines.pop(-1)] if data[-1] != '\n' else []
        for line in lines:
            if self.prefix:
                line = self.prefix + line
            if self.write_newlines:
                line += '\n'
            self.line_writer(line)

beancount.utils.misc_utils.LineFileProxy.write(self, data)

将某些字符串数据写入输出。

参数:
  • data – 一个字符串,可能包含或不包含换行符。

源代码位于 beancount/utils/misc_utils.py
def write(self, data):
    """Write some string data to the output.

    Args:
      data: A string, with or without newlines.
    """
    if '\n' in data:
        self.data.append(data)
        self.flush()
    else:
        self.data.append(data)

beancount.utils.misc_utils.TypeComparable

一个基类,其相等性比较包含对实例自身类型的比较。

beancount.utils.misc_utils.box(name=None, file=None)

一个上下文管理器,用于在代码块周围打印边框。这在测试中以可读方式打印内容时非常有用。

参数:
  • name – 字符串,用于指定边框的名称。

  • file – 用于打印的文件对象。

返回:无。

源代码位于 beancount/utils/misc_utils.py
@contextlib.contextmanager
def box(name=None, file=None):
    """A context manager that prints out a box around a block.
    This is useful for printing out stuff from tests in a way that is readable.

    Args:
      name: A string, the name of the box to use.
      file: The file object to print to.
    Yields:
      None.
    """
    file = file or sys.stdout
    file.write('\n')
    if name:
        header = ',--------({})--------\n'.format(name)
        footer = '`{}\n'.format('-' * (len(header)-2))
    else:
        header = ',----------------\n'
        footer = '`----------------\n'

    file.write(header)
    yield
    file.write(footer)
    file.flush()

beancount.utils.misc_utils.cmptuple(name, attributes)

创建一个可比较的 namedtuple 类,类似于 collections.namedtuple。

可比较的命名元组是一个元组,当内容相等但数据类型不同时,比较结果为 False。我们定义此功能是为了补充 collections.namedtuple,因为默认情况下命名元组会忽略类型,而我们希望在测试中进行精确比较。

参数:
  • name – 类的指定名称。

  • attributes – 字符串或字符串元组,表示属性名称。

返回:
  • 一个新的继承自 namedtuple 的类型,当与其他具有相同内容的元组比较时返回 False。

源代码位于 beancount/utils/misc_utils.py
def cmptuple(name, attributes):
    """Manufacture a comparable namedtuple class, similar to collections.namedtuple.

    A comparable named tuple is a tuple which compares to False if contents are
    equal but the data types are different. We define this to supplement
    collections.namedtuple because by default a namedtuple disregards the type
    and we want to make precise comparisons for tests.

    Args:
      name: The given name of the class.
      attributes: A string or tuple of strings, with the names of the
        attributes.
    Returns:
      A new namedtuple-derived type that compares False with other
      tuples with same contents.
    """
    base = collections.namedtuple('_{}'.format(name), attributes)
    return type(name, (TypeComparable, base,), {})

beancount.utils.misc_utils.compute_unique_clean_ids(strings)

给定一组字符串,将其转换为不含特殊字符的对应 ID,并确保 ID 列表唯一。返回结果的 (id, string) 对序列。

参数:
  • strings – 字符串列表。

返回:
  • 由 (id, string) 对组成的列表。

源代码位于 beancount/utils/misc_utils.py
def compute_unique_clean_ids(strings):
    """Given a sequence of strings, reduce them to corresponding ids without any
    funny characters and insure that the list of ids is unique. Yields pairs
    of (id, string) for the result.

    Args:
      strings: A list of strings.
    Returns:
      A list of (id, string) pairs.
    """
    string_set = set(strings)

    # Try multiple methods until we get one that has no collisions.
    for regexp, replacement in [(r'[^A-Za-z0-9.-]', '_'),
                                (r'[^A-Za-z0-9_]', ''),]:
        seen = set()
        idmap = {}
        mre = re.compile(regexp)
        for string in string_set:
            id_ = mre.sub(replacement, string)
            if id_ in seen:
                break  # Collision.
            seen.add(id_)
            idmap[id_] = string
        else:
            break
    else:
        return None # Could not find a unique mapping.

    return idmap

beancount.utils.misc_utils.deprecated(message)

一个装饰器生成器,用于标记函数为已弃用并记录警告。

源代码位于 beancount/utils/misc_utils.py
def deprecated(message):
    """A decorator generator to mark functions as deprecated and log a warning."""
    def decorator(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            warnings.warn("Call to deprecated function {}: {}".format(func.__name__,
                                                                      message),
                          category=DeprecationWarning,
                          stacklevel=2)
            return func(*args, **kwargs)
        return new_func
    return decorator

beancount.utils.misc_utils.dictmap(mdict, keyfun=None, valfun=None)

映射字典的值。

参数:
  • mdict – 一个字典。

  • key – 用于应用到键的可调用对象。

  • value – 用于应用到值的可调用对象。

源代码位于 beancount/utils/misc_utils.py
def dictmap(mdict, keyfun=None, valfun=None):
    """Map a dictionary's value.

    Args:
      mdict: A dict.
      key: A callable to apply to the keys.
      value: A callable to apply to the values.
    """
    if keyfun is None:
        keyfun = lambda x: x
    if valfun is None:
        valfun = lambda x: x
    return {keyfun(key): valfun(val) for key, val in mdict.items()}

beancount.utils.misc_utils.escape_string(string)

转义收款人和描述中的引号和反斜杠。

参数:
  • string – 任意字符串。

返回:输入字符串,其中违规字符已被替换。

源代码位于 beancount/utils/misc_utils.py
def escape_string(string):
    """Escape quotes and backslashes in payee and narration.

    Args:
      string: Any string.
    Returns.
      The input string, with offending characters replaced.
    """
    return string.replace('\\', r'\\')\
                 .replace('"', r'\"')

beancount.utils.misc_utils.filter_type(elist, types)

过滤给定列表,仅保留指定类型的实例。

参数:
  • elist – 元素序列。

  • types – 要包含在输出列表中的类型序列。

生成:如果元素是 'types' 的实例,则输出该元素。

源代码位于 beancount/utils/misc_utils.py
def filter_type(elist, types):
    """Filter the given list to yield only instances of the given types.

    Args:
      elist: A sequence of elements.
      types: A sequence of types to include in the output list.
    Yields:
      Each element, if it is an instance of 'types'.
    """
    for element in elist:
        if not isinstance(element, types):
            continue
        yield element

beancount.utils.misc_utils.first_paragraph(docstring)

返回文档字符串的第一句话。该句子必须由空行分隔。

参数:
  • docstring – 一个文档字符串。

返回:
  • 一个仅包含第一句话的字符串,位于单行上。

源代码位于 beancount/utils/misc_utils.py
def first_paragraph(docstring):
    """Return the first sentence of a docstring.
    The sentence has to be delimited by an empty line.

    Args:
      docstring: A doc string.
    Returns:
      A string with just the first sentence on a single line.
    """
    lines = []
    for line in docstring.strip().splitlines():
        if not line:
            break
        lines.append(line.rstrip())
    return ' '.join(lines)

beancount.utils.misc_utils.get_screen_height()

返回运行此程序的终端高度。

返回:
  • 一个整数,表示屏幕的高度(字符数)。如果终端无法初始化,则返回 0。

源代码位于 beancount/utils/misc_utils.py
def get_screen_height():
    """Return the height of the terminal that runs this program.

    Returns:
      An integer, the number of characters the screen is high.
      Return 0 if the terminal cannot be initialized.
    """
    return _get_screen_value('lines', 0)

beancount.utils.misc_utils.get_screen_width()

返回运行此程序的终端宽度。

返回:
  • 一个整数,表示屏幕的宽度(字符数)。如果终端无法初始化,则返回 0。

源代码位于 beancount/utils/misc_utils.py
def get_screen_width():
    """Return the width of the terminal that runs this program.

    Returns:
      An integer, the number of characters the screen is wide.
      Return 0 if the terminal cannot be initialized.
    """
    return _get_screen_value('cols', 0)

beancount.utils.misc_utils.get_tuple_values(ntuple, predicate, memo=None)

返回此命名元组实例中满足给定谓词的所有成员。此函数还会递归处理其成员中的列表或元组,因此可用于 Transaction 实例。

参数:
  • ntuple – 一个元组或命名元组。

  • predicate – 一个谓词函数,当属性应被输出时返回 True。

  • memo – 可选的记忆字典。如果某个元组已见过,则避免递归。

生成:如果谓词为真,则输出元组及其子元素的属性。

源代码位于 beancount/utils/misc_utils.py
def get_tuple_values(ntuple, predicate, memo=None):
    """Return all members referred to by this namedtuple instance that satisfy the
    given predicate. This function also works recursively on its members which
    are lists or tuples, and so it can be used for Transaction instances.

    Args:
      ntuple: A tuple or namedtuple.
      predicate: A predicate function that returns true if an attribute is to be
        output.
      memo: An optional memoizing dictionary. If a tuple has already been seen, the
        recursion will be avoided.
    Yields:
      Attributes of the tuple and its sub-elements if the predicate is true.
    """
    if memo is None:
        memo = set()
    id_ntuple = id(ntuple)
    if id_ntuple in memo:
        return
    memo.add(id_ntuple)

    if predicate(ntuple):
        yield
    for attribute in ntuple:
        if predicate(attribute):
            yield attribute
        if isinstance(attribute, (list, tuple)):
            for value in get_tuple_values(attribute, predicate, memo):
                yield value

beancount.utils.misc_utils.groupby(keyfun, elements)

将元素分组为字典形式的列表,其中键由函数 'keyfun' 计算得出。

参数:
  • keyfun – 一个可调用对象,用于从每个元素中获取分组键。

  • elements – 待分组的元素可迭代对象。

返回:
  • 一个将键映射到序列列表的字典。

源代码位于 beancount/utils/misc_utils.py
def groupby(keyfun, elements):
    """Group the elements as a dict of lists, where the key is computed using the
    function 'keyfun'.

    Args:
      keyfun: A callable, used to obtain the group key from each element.
      elements: An iterable of the elements to group.
    Returns:
      A dict of key to list of sequences.
    """
    # Note: We could allow a custom aggregation function. Another option is
    # provide another method to reduce the list values of a dict, but that can
    # be accomplished using a dict comprehension.
    grouped = defaultdict(list)
    for element in elements:
        grouped[keyfun(element)].append(element)
    return grouped

beancount.utils.misc_utils.idify(string)

将文件名中不允许的字符替换为下划线。

参数:
  • string – 任意字符串。

返回:
  • 输入字符串,其中违规字符已被替换。

源代码位于 beancount/utils/misc_utils.py
def idify(string):
    """Replace characters objectionable for a filename with underscores.

    Args:
      string: Any string.
    Returns:
      The input string, with offending characters replaced.
    """
    for sfrom, sto in [(r'[ \(\)]+', '_'),
                       (r'_*\._*', '.')]:
        string = re.sub(sfrom, sto, string)
    string = string.strip('_')
    return string

beancount.utils.misc_utils.import_curses()

尝试导入 'curses' 模块。(此处用于在测试中覆盖该模块。)

返回:
  • 如果成功导入,则返回 curses 模块。

异常:
  • ImportError – 如果无法导入该模块。

源代码位于 beancount/utils/misc_utils.py
def import_curses():
    """Try to import the 'curses' module.
    (This is used here in order to override for tests.)

    Returns:
      The curses module, if it was possible to import it.
    Raises:
      ImportError: If the module could not be imported.
    """
    # Note: There's a recipe for getting terminal size on Windows here, without
    # curses, I should probably implement that at some point:
    # https://stackoverflow.com/questions/263890/how-do-i-find-the-width-height-of-a-terminal-window
    # Also, consider just using 'blessings' instead, which provides this across
    # multiple platforms.
    # pylint: disable=import-outside-toplevel
    import curses
    return curses

beancount.utils.misc_utils.is_sorted(iterable, key= at 0x7f6b92fae700>, cmp= at 0x7f6b92fae7a0>)

如果序列已排序,则返回 True。

参数:
  • iterable – 一个可迭代序列。

  • key – 用于提取排序依据值的函数。

  • cmp – 用于比较序列中两个元素的函数。

返回:
  • 布尔值,若序列已排序则为 True。

源代码位于 beancount/utils/misc_utils.py
def is_sorted(iterable, key=lambda x: x, cmp=lambda x, y: x <= y):
    """Return true if the sequence is sorted.

    Args:
      iterable: An iterable sequence.
      key: A function to extract the quantity by which to sort.
      cmp: A function that compares two elements of a sequence.
    Returns:
      A boolean, true if the sequence is sorted.
    """
    iterator = map(key, iterable)
    prev = next(iterator)
    for element in iterator:
        if not cmp(prev, element):
            return False
        prev = element
    return True

beancount.utils.misc_utils.log_time(operation_name, log_timings, indent=0)

一个上下文管理器,用于计时代码块并将结果以 info 级别记录。

参数:
  • operation_name – 字符串,操作的名称标签。

  • log_timings – 用于写入日志消息的函数。若设为 None,则不记录计时信息(此操作将无任何效果)。

  • indent – 整数,用于格式化计时行的缩进级别。在记录分层操作的计时信息时非常有用。

返回:操作的开始时间。

源代码位于 beancount/utils/misc_utils.py
@contextlib.contextmanager
def log_time(operation_name, log_timings, indent=0):
    """A context manager that times the block and logs it to info level.

    Args:
      operation_name: A string, a label for the name of the operation.
      log_timings: A function to write log messages to. If left to None,
        no timings are written (this becomes a no-op).
      indent: An integer, the indentation level for the format of the timing
        line. This is useful if you're logging timing to a hierarchy of
        operations.
    Yields:
      The start time of the operation.
    """
    time1 = time()
    yield time1
    time2 = time()
    if log_timings:
        log_timings("Operation: {:48} Time: {}{:6.0f} ms".format(
            "'{}'".format(operation_name), '      '*indent, (time2 - time1) * 1000))

beancount.utils.misc_utils.longest(seq)

返回给定子序列中最长的一个。

参数:
  • seq – 一个包含列表的可迭代序列。

返回:
  • 序列中最长的列表。

源代码位于 beancount/utils/misc_utils.py
def longest(seq):
    """Return the longest of the given subsequences.

    Args:
      seq: An iterable sequence of lists.
    Returns:
      The longest list from the sequence.
    """
    longest, length = None, -1
    for element in seq:
        len_element = len(element)
        if len_element > length:
            longest, length = element, len_element
    return longest

beancount.utils.misc_utils.map_namedtuple_attributes(attributes, mapper, object_)

使用 mapper 映射对象的命名属性值。

参数:
  • attributes – 字符串序列,表示要映射的属性名称。

  • mapper – 一个可调用对象,接受字段值并返回新值。

  • object_ – 一个具有属性的 namedtuple 对象。

返回:
  • 一个具有相同 namedtuple 结构的新实例,其命名字段通过 mapper 映射。

源代码位于 beancount/utils/misc_utils.py
def map_namedtuple_attributes(attributes, mapper, object_):
    """Map the value of the named attributes of object by mapper.

    Args:
      attributes: A sequence of string, the attribute names to map.
      mapper: A callable that accepts the value of a field and returns
        the new value.
      object_: Some namedtuple object with attributes on it.
    Returns:
      A new instance of the same namedtuple with the named fields mapped by
      mapper.
    """
    return object_._replace(**{attribute: mapper(getattr(object_, attribute))
                               for attribute in attributes})

beancount.utils.misc_utils.replace_namedtuple_values(ntuple, predicate, mapper, memo=None)

递归遍历所有 namedtuple 和列表的成员,对于匹配指定 predicate 的成员,使用给定的 mapper 进行处理。

参数:
  • ntuple – 一个 namedtuple 实例。

  • predicate – 一个谓词函数,当属性应被输出时返回 True。

  • mapper – 一个可调用对象,接受单个参数并返回其替换值。

  • memo – 可选的记忆字典。如果某个元组已见过,则避免递归。

生成:如果谓词为真,则输出元组及其子元素的属性。

源代码位于 beancount/utils/misc_utils.py
def replace_namedtuple_values(ntuple, predicate, mapper, memo=None):
    """Recurse through all the members of namedtuples and lists, and for
    members that match the given predicate, run them through the given mapper.

    Args:
      ntuple: A namedtuple instance.
      predicate: A predicate function that returns true if an attribute is to be
        output.
      mapper: A callable, that will accept a single argument and return its
        replacement value.
      memo: An optional memoizing dictionary. If a tuple has already been seen, the
        recursion will be avoided.
    Yields:
      Attributes of the tuple and its sub-elements if the predicate is true.
    """
    if memo is None:
        memo = set()
    id_ntuple = id(ntuple)
    if id_ntuple in memo:
        return None
    memo.add(id_ntuple)

    # pylint: disable=unidiomatic-typecheck
    if not (type(ntuple) is not tuple and isinstance(ntuple, tuple)):
        return ntuple
    replacements = {}
    for attribute_name, attribute in zip(ntuple._fields, ntuple):
        if predicate(attribute):
            replacements[attribute_name] = mapper(attribute)
        elif type(attribute) is not tuple and isinstance(attribute, tuple):
            replacements[attribute_name] = replace_namedtuple_values(
                attribute, predicate, mapper, memo)
        elif type(attribute) in (list, tuple):
            replacements[attribute_name] = [
                replace_namedtuple_values(member, predicate, mapper, memo)
                for member in attribute]
    return ntuple._replace(**replacements)

beancount.utils.misc_utils.skipiter(iterable, num_skip)

跳过迭代器中的若干元素。

参数:
  • iterable – 一个迭代器。

  • num_skip – 要跳过的元素数量。

生成:从迭代器中输出元素,跳过 num_skip 个元素。例如,skipiter(range(10), 3) 生成 [0, 3, 6, 9]。

源代码位于 beancount/utils/misc_utils.py
def skipiter(iterable, num_skip):
    """Skip some elements from an iterator.

    Args:
      iterable: An iterator.
      num_skip: The number of elements in the period.
    Yields:
      Elements from the iterable, with num_skip elements skipped.
      For example, skipiter(range(10), 3) yields [0, 3, 6, 9].
    """
    assert num_skip > 0
    sit = iter(iterable)
    while 1:
        try:
            value = next(sit)
        except StopIteration:
            return
        yield value
        for _ in range(num_skip-1):
            try:
                next(sit)
            except StopIteration:
                return

beancount.utils.misc_utils.sorted_uniquify(iterable, keyfunc=None, last=False)

给定一个元素序列,根据指定的键进行排序并去重。保留键值相同的元素序列中的第一个或最后一个元素(由 key 决定)。此函数保持原始元素的顺序,而是按键排序后返回。

参数:
  • iterable – 一个可迭代序列。

  • keyfunc – 一个函数,用于从元素中提取用作排序和去重依据的键。若未指定,则使用恒等函数,直接对元素本身进行去重。

  • last – 布尔值,若为 True,则保留相同键的最后一个元素;否则保留第一个。

生成:来自迭代器的元素。

源代码位于 beancount/utils/misc_utils.py
def sorted_uniquify(iterable, keyfunc=None, last=False):
    """Given a sequence of elements, sort and remove duplicates of the given key.
    Keep either the first or the last (by key) element of a sequence of
    key-identical elements. This does _not_ maintain the ordering of the
    original elements, they are returned sorted (by key) instead.

    Args:
      iterable: An iterable sequence.
      keyfunc: A function that extracts from the elements the sort key
        to use and uniquify on. If left unspecified, the identify function
        is used and the uniquification occurs on the elements themselves.
      last: A boolean, True if we should keep the last item of the same keys.
        Otherwise keep the first.
    Yields:
      Elements from the iterable.
    """
    if keyfunc is None:
        keyfunc = lambda x: x
    if last:
        prev_obj = UNSET
        prev_key = UNSET
        for obj in sorted(iterable, key=keyfunc):
            key = keyfunc(obj)
            if key != prev_key and prev_obj is not UNSET:
                yield prev_obj
            prev_obj = obj
            prev_key = key
        if prev_obj is not UNSET:
            yield prev_obj
    else:
        prev_key = UNSET
        for obj in sorted(iterable, key=keyfunc):
            key = keyfunc(obj)
            if key != prev_key:
                yield obj
                prev_key = key

beancount.utils.misc_utils.staticvar(varname, initial_value)

返回一个装饰器,用于定义 Python 函数的属性。

此功能用于在 Python 中模拟静态函数变量。

参数:
  • varname – 字符串,要定义的变量名称。

  • initial_value – 变量的初始值。

返回:
  • 一个函数装饰器。

源代码位于 beancount/utils/misc_utils.py
def staticvar(varname, initial_value):
    """Returns a decorator that defines a Python function attribute.

    This is used to simulate a static function variable in Python.

    Args:
      varname: A string, the name of the variable to define.
      initial_value: The value to initialize the variable to.
    Returns:
      A function decorator.
    """
    def deco(fun):
        setattr(fun, varname, initial_value)
        return fun
    return deco

beancount.utils.misc_utils.swallow(*exception_types)

捕获并忽略某些异常。

参数:
  • exception_types – 要忽略的异常类组成的元组。

返回:无。

源代码位于 beancount/utils/misc_utils.py
@contextlib.contextmanager
def swallow(*exception_types):
    """Catch and ignore certain exceptions.

    Args:
      exception_types: A tuple of exception classes to ignore.
    Yields:
      None.
    """
    try:
        yield
    except Exception as exc:
        if not isinstance(exc, exception_types):
            raise

beancount.utils.misc_utils.uniquify(iterable, keyfunc=None, last=False)

给定一个元素序列,根据指定的键去除重复项。保留键值相同的元素序列中的第一个或最后一个元素。尽可能保持原始顺序。此函数保持原始元素的顺序,返回的元素顺序与原始序列一致。

参数:
  • iterable – 一个可迭代序列。

  • keyfunc – 一个函数,用于从元素中提取用作排序和去重依据的键。若未指定,则使用恒等函数,直接对元素本身进行去重。

  • last – 布尔值,若为 True,则保留相同键的最后一个元素;否则保留第一个。

生成:来自迭代器的元素。

源代码位于 beancount/utils/misc_utils.py
def uniquify(iterable, keyfunc=None, last=False):
    """Given a sequence of elements, remove duplicates of the given key. Keep either
    the first or the last element of a sequence of key-identical elements. Order
    is maintained as much as possible. This does maintain the ordering of the
    original elements, they are returned in the same order as the original
    elements.

    Args:
      iterable: An iterable sequence.
      keyfunc: A function that extracts from the elements the sort key
        to use and uniquify on. If left unspecified, the identify function
        is used and the uniquification occurs on the elements themselves.
      last: A boolean, True if we should keep the last item of the same keys.
        Otherwise keep the first.
    Yields:
      Elements from the iterable.
    """
    if keyfunc is None:
        keyfunc = lambda x: x
    seen = set()
    if last:
        unique_reversed_list = []
        for obj in reversed(iterable):
            key = keyfunc(obj)
            if key not in seen:
                seen.add(key)
                unique_reversed_list.append(obj)
        yield from reversed(unique_reversed_list)
    else:
        for obj in iterable:
            key = keyfunc(obj)
            if key not in seen:
                seen.add(key)
                yield obj

beancount.utils.pager

将输出写入分页器的代码。

本模块包含一个对象,它会累积行内容,直到达到最小阈值,然后决定:若行数低于阈值,则直接输出到原始输出(不使用分页器);若超过阈值,则创建一个分页器并将行内容输出到其中,后续所有行都将转发至该分页器。该对象的目的是仅在要打印的行数超过最小阈值时才将输出重定向到分页器。

上下文管理器旨在将输出管道传输到分页器,并等待分页器完成后再继续执行。只需向文件对象写入内容,退出时我们会自动关闭该文件对象。这还能屏蔽用户退出子进程时触发的管道断裂错误,并在分页器命令失败时回退到使用标准输出。

beancount.utils.pager.ConditionalPager

一个代理文件对象,仅在写入的行数达到最小阈值后才启动分页器。

beancount.utils.pager.ConditionalPager.__enter__(self) 特殊

初始化上下文管理器,并返回该实例本身。

源代码位于 beancount/utils/pager.py
def __enter__(self):
    """Initialize the context manager and return this instance as it."""

    # The file and pipe object we're writing to. This gets set after the
    # number of accumulated lines reaches the threshold.
    if self.minlines:
        self.file = None
        self.pipe = None
    else:
        self.file, self.pipe = create_pager(self.command, self.default_file)

    # Lines accumulated before the threshold.
    self.accumulated_data = []
    self.accumulated_lines = 0

    # Return this object to be used as the context manager itself.
    return self

beancount.utils.pager.ConditionalPager.__exit__(self, type, value, unused_traceback) 特殊

上下文管理器退出。此方法会将输出刷新到输出文件。

参数:
  • type – 可选的异常类型,符合上下文管理器规范。

  • value – 可选的异常值,符合上下文管理器规范。

  • unused_traceback – 可选的跟踪信息。

源代码位于 beancount/utils/pager.py
def __exit__(self, type, value, unused_traceback):
    """Context manager exit. This flushes the output to our output file.

    Args:
      type: Optional exception type, as per context managers.
      value: Optional exception value, as per context managers.
      unused_traceback: Optional trace.
    """
    try:
        if self.file:
            # Flush the output file and close it.
            self.file.flush()
        else:
            # Oops... we never reached the threshold. Flush the accumulated
            # output to the file.
            self.flush_accumulated(self.default_file)

        # Wait for the subprocess (if we have one).
        if self.pipe:
            self.file.close()
            self.pipe.wait()

    # Absorb broken pipes that may occur on flush or close above.
    except BrokenPipeError:
        return True

    # Absorb broken pipes.
    if isinstance(value, BrokenPipeError):
        return True
    elif value:
        raise

beancount.utils.pager.ConditionalPager.__init__(self, command, minlines=None) 特殊

创建一个条件分页器。

参数:
  • command – 字符串,作为分页器运行的 shell 命令。

  • minlines – 如果设置,则表示在行数低于该值时无需启动分页器。这可以避免在屏幕高度足以完整显示内容时启动分页器。若未设置,则始终启动分页器(这也是合理的行为)。

源代码位于 beancount/utils/pager.py
def __init__(self, command, minlines=None):
    """Create a conditional pager.

    Args:
      command: A string, the shell command to run as a pager.
      minlines: If set, the number of lines under which you should not bother starting
        a pager. This avoids kicking off a pager if the screen is high enough to
        render the contents. If the value is unset, always starts a pager (which is
        fine behavior too).
    """
    self.command = command
    self.minlines = minlines
    self.default_file = (codecs.getwriter("utf-8")(sys.stdout.buffer)
                         if hasattr(sys.stdout, 'buffer') else
                         sys.stdout)

beancount.utils.pager.ConditionalPager.flush_accumulated(self, file)

将已累积的行刷新到新创建的分页器中。同时禁用累积器。

参数:
  • file – 要将累积数据刷新到的文件对象。

源代码位于 beancount/utils/pager.py
def flush_accumulated(self, file):
    """Flush the existing lines to the newly created pager.
    This also disabled the accumulator.

    Args:
      file: A file object to flush the accumulated data to.
    """
    if self.accumulated_data:
        write = file.write
        for data in self.accumulated_data:
            write(data)
    self.accumulated_data = None
    self.accumulated_lines = None

beancount.utils.pager.ConditionalPager.write(self, data)

写入数据。重写了文件对象接口的方法。

参数:
  • data – 字符串,要写入输出的数据。

源代码位于 beancount/utils/pager.py
def write(self, data):
    """Write the data out. Overridden from the file object interface.

    Args:
      data: A string, data to write to the output.
    """
    if self.file is None:
        # Accumulate the new lines.
        self.accumulated_lines += data.count('\n')
        self.accumulated_data.append(data)

        # If we've reached the threshold, create a file.
        if self.accumulated_lines > self.minlines:
            self.file, self.pipe = create_pager(self.command, self.default_file)
            self.flush_accumulated(self.file)
    else:
        # We've already created a pager subprocess... flush the lines to it.
        self.file.write(data)
        # try:
        # except BrokenPipeError:
        #     # Make sure we don't barf on __exit__().
        #     self.file = self.pipe = None
        #     raise

beancount.utils.pager.create_pager(command, file)

尝试创建并返回一个分页器子进程。

参数:
  • command – 字符串,作为分页器运行的 shell 命令。

  • file – 分页器写入的目标文件对象。如果无法创建分页器子进程,则使用此对象作为默认输出。

返回:
  • 一个包含 (file, pipe) 的元组,其中 file 是文件对象,pipe 是可选的 subprocess.Popen 实例,用于等待。如果未能创建子进程,则 pipe 可能为 None。

源代码位于 beancount/utils/pager.py
def create_pager(command, file):
    """Try to create and return a pager subprocess.

    Args:
      command: A string, the shell command to run as a pager.
      file: The file object for the pager write to. This is also used as a
        default if we failed to create the pager subprocess.
    Returns:
      A pair of (file, pipe), a file object and an optional subprocess.Popen instance
      to wait on. The pipe instance may be set to None if we failed to create a subprocess.
    """

    if command is None:
        command = os.environ.get('PAGER', DEFAULT_PAGER)
    if not command:
        command = DEFAULT_PAGER

    pipe = None

    # In case of using 'less', make sure the charset is set properly. In theory
    # you could override this by setting PAGER to "LESSCHARSET=utf-8 less" but
    # this shouldn't affect other programs and is unlikely to cause problems, so
    # we set it here to make default behavior work for most people (we always
    # write UTF-8).
    env = os.environ.copy()
    env['LESSCHARSET'] = "utf-8"

    try:
        pipe = subprocess.Popen(command, shell=True,
                                stdin=subprocess.PIPE,
                                stdout=file,
                                env=env)
    except OSError as exc:
        logging.error("Invalid pager: {}".format(exc))
    else:
        stdin_wrapper = io.TextIOWrapper(pipe.stdin, 'utf-8')
        file = stdin_wrapper
    return file, pipe

beancount.utils.pager.flush_only(fileobj)

围绕文件对象的上下文管理器,不关闭该文件。

用于返回一个不关闭文件的上下文管理器,而是仅刷新它。这在需要提供上述分页器类的替代方案时非常有用。

参数:
  • fileobj – 文件对象,在上下文管理器执行后仍保持打开状态。

返回:一个上下文管理器,yield 此对象。

源代码位于 beancount/utils/pager.py
@contextlib.contextmanager
def flush_only(fileobj):
    """A contextmanager around a file object that does not close the file.

    This is used to return a context manager on a file object but not close it.
    We flush it instead. This is useful in order to provide an alternative to a
    pager class as above.

    Args:
      fileobj: A file object, to remain open after running the context manager.
    Yields:
      A context manager that yields this object.
    """
    try:
        yield fileobj
    finally:
        fileobj.flush()

beancount.utils.snoop

文本操作工具。

beancount.utils.snoop.Snoop

一个仅保存函数返回值的窥探器可调用对象。这在条件语句中使用 re.match 和 re.search 时特别有用,例如:

snoop = Snoop() ... if snoop(re.match(r"(\d+)-(\d+)-(\d+)", text)): year, month, date = snoop.value.group(1, 2, 3)

属性:

名称 类型 描述
value

从函数调用中窥探到的最后一个值。

历史记录

如果指定了 'maxlen',则为最近窥探到的几个值。

beancount.utils.snoop.Snoop.__call__(self, value) 特殊

将值保存到窥探器中。此方法用于包装函数调用。

参数:
  • value – 要推送/保存的值。

返回:
  • 值本身。

源代码位于 beancount/utils/snoop.py
def __call__(self, value):
    """Save a value to the snooper. This is meant to wrap
    a function call.

    Args:
      value: The value to push/save.
    Returns:
      Value itself.
    """
    self.value = value
    if self.history is not None:
        self.history.append(value)
    return value

beancount.utils.snoop.Snoop.__getattr__(self, attr) 特殊

将属性转发到值上。

参数:
  • attr – 字符串,属性的名称。

返回:
  • 属性的值。

源代码位于 beancount/utils/snoop.py
def __getattr__(self, attr):
    """Forward the attribute to the value.

    Args:
      attr: A string, the name of the attribute.
    Returns:
      The value of the attribute.
    """
    return getattr(self.value, attr)

beancount.utils.snoop.Snoop.__init__(self, maxlen=None) 特殊

创建一个新的窥探器。

参数:
  • maxlen – 如果指定,则为一个整数,用于启用保存该数量的

源代码位于 beancount/utils/snoop.py
def __init__(self, maxlen=None):
    """Create a new snooper.

    Args:
      maxlen: If specified, an integer, which enables the saving of that
      number of last values in the history attribute.
    """
    self.value = None
    self.history = (collections.deque(maxlen=maxlen)
                    if maxlen
                    else None)

beancount.utils.snoop.snoopify(function)

将函数装饰为可窥探的。

此方法用于将现有函数重新赋值为其可窥探版本。例如,若希望 're.match' 自动变为可窥探的,只需像这样装饰它:

re.match = snoopify(re.match)

然后你就可以在条件语句中直接调用 're.match',并通过 're.match.value' 访问其最后一次返回的值。

源代码位于 beancount/utils/snoop.py
def snoopify(function):
    """Decorate a function as snoopable.

    This is meant to reassign existing functions to a snoopable version of them.
    For example, if you wanted 're.match' to be automatically snoopable, just
    decorate it like this:

      re.match = snoopify(re.match)

    and then you can just call 're.match' in a conditional and then access
    're.match.value' to get to the last returned value.
    """
    @functools.wraps(function)
    def wrapper(*args, **kw):
        value = function(*args, **kw)
        wrapper.value = value
        return value
    wrapper.value = None
    return wrapper

beancount.utils.table

表格渲染。

beancount.utils.table.Table (元组)

Table(columns, header, body)

beancount.utils.table.Table.__getnewargs__(self) 特殊

将自身返回为一个普通元组。供 copy 和 pickle 使用。

源代码位于 beancount/utils/table.py
def __getnewargs__(self):
    'Return self as a plain tuple.  Used by copy and pickle.'
    return _tuple(self)

beancount.utils.table.Table.__new__(_cls, columns, header, body) 特殊 静态方法

创建 Table(columns, header, body) 的新实例

beancount.utils.table.Table.__replace__(/, self, **kwds) 特殊

返回一个新的 Table 对象,用指定的新值替换字段

源代码位于 beancount/utils/table.py
def _replace(self, /, **kwds):
    result = self._make(_map(kwds.pop, field_names, self))
    if kwds:
        raise TypeError(f'Got unexpected field names: {list(kwds)!r}')
    return result

beancount.utils.table.Table.__repr__(self) 特殊

返回一个格式良好的表示字符串

源代码位于 beancount/utils/table.py
def __repr__(self):
    'Return a nicely formatted representation string'
    return self.__class__.__name__ + repr_fmt % self

beancount.utils.table.attribute_to_title(fieldname)

将编程标识符转换为可读的字段名称。

参数:
  • fieldname – 一个字符串,表示编程标识符,例如 'book_value'。

返回:
  • 一个可读的字符串,例如 '账面价值'。

源代码位于 beancount/utils/table.py
def attribute_to_title(fieldname):
    """Convert programming id into readable field name.

    Args:
      fieldname: A string, a programming ids, such as 'book_value'.
    Returns:
      A readable string, such as 'Book Value.'
   """
    return fieldname.replace('_', ' ').title()

beancount.utils.table.compute_table_widths(rows)

计算一组行的最大字符宽度。

参数:
  • rows – 一组行,每行是字符串的序列。

返回:
  • 一个整数列表,表示渲染该表各列所需的最小宽度。

异常:
  • IndexError – 如果各行长度不一致。

源代码位于 beancount/utils/table.py
def compute_table_widths(rows):
    """Compute the max character widths of a list of rows.

    Args:
      rows: A list of rows, which are sequences of strings.
    Returns:
      A list of integers, the maximum widths required to render the columns of
      this table.
    Raises:
      IndexError: If the rows are of different lengths.
    """
    row_iter = iter(rows)
    first_row = next(row_iter)
    num_columns = len(first_row)
    column_widths = [len(cell) for cell in first_row]
    for row in row_iter:
        for i, cell in enumerate(row):
            if not isinstance(cell, str):
                cell = str(cell)
            cell_len = len(cell)
            if cell_len > column_widths[i]:
                column_widths[i] = cell_len
        if i+1 != num_columns:
            raise IndexError("Invalid number of rows")
    return column_widths

beancount.utils.table.create_table(rows, field_spec=None)

将元组列表转换为表格报告对象。

参数:
  • rows – 一个元组列表。

  • field_spec – 一个字符串列表,或由 (FIELDNAME-OR-INDEX, HEADER, FORMATTER-FUNCTION) 三元组组成的列表,用于选择要渲染的字段子集及其顺序。如果为字典,则值为用于渲染字段的函数;若函数设为 None,则直接对字段调用 str()。

返回:
  • 一个 Table 实例。

源代码位于 beancount/utils/table.py
def create_table(rows, field_spec=None):
    """Convert a list of tuples to an table report object.

    Args:
      rows: A list of tuples.
      field_spec: A list of strings, or a list of
        (FIELDNAME-OR-INDEX, HEADER, FORMATTER-FUNCTION)
        triplets, that selects a subset of the fields is to be rendered as well
        as their ordering. If this is a dict, the values are functions to call
        on the fields to render them. If a function is set to None, we will just
        call str() on the field.
    Returns:
      A Table instance.
    """
    # Normalize field_spec to a dict.
    if field_spec is None:
        namedtuple_class = type(rows[0])
        field_spec = [(field, None, None)
                      for field in namedtuple_class._fields]

    elif isinstance(field_spec, (list, tuple)):
        new_field_spec = []
        for field in field_spec:
            if isinstance(field, tuple):
                assert len(field) <= 3, field
                if len(field) == 1:
                    field = field[0]
                    new_field_spec.append((field, None, None))
                elif len(field) == 2:
                    field, header = field
                    new_field_spec.append((field, header, None))
                elif len(field) == 3:
                    new_field_spec.append(field)
            else:
                if isinstance(field, str):
                    title = attribute_to_title(field)
                elif isinstance(field, int):
                    title = "Field {}".format(field)
                else:
                    raise ValueError("Invalid type for column name")
                new_field_spec.append((field, title, None))

        field_spec = new_field_spec

    # Ensure a nicely formatted header.
    field_spec = [((name, attribute_to_title(name), formatter)
                   if header_ is None
                   else (name, header_, formatter))
                  for (name, header_, formatter) in field_spec]

    assert isinstance(field_spec, list), field_spec
    assert all(len(x) == 3 for x in field_spec), field_spec

    # Compute the column names.
    columns = [name for (name, _, __) in field_spec]

    # Compute the table header.
    header = [header_column for (_, header_column, __) in field_spec]

    # Compute the table body.
    body = []
    for row in rows:
        body_row = []
        for name, _, formatter in field_spec:
            if isinstance(name, str):
                value = getattr(row, name)
            elif isinstance(name, int):
                value = row[name]
            else:
                raise ValueError("Invalid type for column name")
            if value is not None:
                if formatter is not None:
                    value = formatter(value)
                else:
                    value = str(value)
            else:
                value = ''
            body_row.append(value)
        body.append(body_row)

    return Table(columns, header, body)

beancount.utils.table.render_table(table_, output, output_format, css_id=None, css_class=None)

将给定表格以请求的格式输出到输出文件对象。

表格将被写入 'output' 文件。

参数:
  • table_ – 一个 Table 实例。

  • output – 一个可写入的文件对象。

  • output_format – 一个字符串,指定表格输出格式,可选值为 'csv'、'txt' 或 'html'。

  • css_id – 一个字符串,表格对象的可选 CSS ID(仅用于 HTML)。

  • css_class – 一个字符串,表格对象的可选 CSS 类(仅用于 HTML)。

源代码位于 beancount/utils/table.py
def render_table(table_, output, output_format, css_id=None, css_class=None):
    """Render the given table to the output file object in the requested format.

    The table gets written out to the 'output' file.

    Args:
      table_: An instance of Table.
      output: A file object you can write to.
      output_format: A string, the format to write the table to,
        either 'csv', 'txt' or 'html'.
      css_id: A string, an optional CSS id for the table object (only used for HTML).
      css_class: A string, an optional CSS class for the table object (only used for HTML).
    """
    if output_format in ('txt', 'text'):
        text = table_to_text(table_, "  ", formats={'*': '>', 'account': '<'})
        output.write(text)

    elif output_format in ('csv',):
        table_to_csv(table_, file=output)

    elif output_format in ('htmldiv', 'html'):

        if output_format == 'html':
            output.write('<html>\n')
            output.write('<body>\n')

        output.write('<div id="{}">\n'.format(css_id) if css_id else '<div>\n')
        classes = [css_class] if css_class else None
        table_to_html(table_, file=output, classes=classes)
        output.write('</div>\n')

        if output_format == 'html':
            output.write('</body>\n')
            output.write('</html>\n')

    else:
        raise NotImplementedError("Unsupported format: {}".format(output_format))

beancount.utils.table.table_to_csv(table, file=None, **kwargs)

将 Table 输出为 CSV 文件。

参数:
  • table – 一个 Table 实例。

  • file – 用于写入的文件对象。如果未提供对象,则此函数返回一个字符串。

  • **kwargs – 传递给 csv.writer() 的可选参数。

返回:
  • 一个字符串,即渲染后的表格;或者如果提供了文件对象用于写入,则返回 None。

源代码位于 beancount/utils/table.py
def table_to_csv(table, file=None, **kwargs):
    """Render a Table to a CSV file.

    Args:
      table: An instance of a Table.
      file: A file object to write to. If no object is provided, this
        function returns a string.
      **kwargs: Optional arguments forwarded to csv.writer().
    Returns:
      A string, the rendered table, or None, if a file object is provided
      to write to.
    """
    output_file = file or io.StringIO()

    writer = csv.writer(output_file, **kwargs)
    if table.header:
        writer.writerow(table.header)
    writer.writerows(table.body)

    if not file:
        return output_file.getvalue()

beancount.utils.table.table_to_html(table, classes=None, file=None)

将 Table 渲染为 HTML。

参数:
  • table – 一个 Table 实例。

  • classes – 一个字符串列表,指定表格的 CSS 类。

  • file – 用于写入的文件对象。如果未提供对象,则此函数返回一个字符串。

返回:
  • 一个字符串,即渲染后的表格;或者如果提供了文件对象用于写入,则返回 None。

源代码位于 beancount/utils/table.py
def table_to_html(table, classes=None, file=None):
    """Render a Table to HTML.

    Args:
      table: An instance of a Table.
      classes: A list of string, CSS classes to set on the table.
      file: A file object to write to. If no object is provided, this
        function returns a string.
    Returns:
      A string, the rendered table, or None, if a file object is provided
      to write to.
    """
    # Initialize file.
    oss = io.StringIO() if file is None else file
    oss.write('<table class="{}">\n'.format(' '.join(classes or [])))

    # Render header.
    if table.header:
        oss.write('  <thead>\n')
        oss.write('    <tr>\n')
        for header in table.header:
            oss.write('      <th>{}</th>\n'.format(header))
        oss.write('    </tr>\n')
        oss.write('  </thead>\n')

    # Render body.
    oss.write('  <tbody>\n')
    for row in table.body:
        oss.write('    <tr>\n')
        for cell in row:
            oss.write('      <td>{}</td>\n'.format(cell))
        oss.write('    </tr>\n')
    oss.write('  </tbody>\n')

    # Render footer.
    oss.write('</table>\n')
    if file is None:
        return oss.getvalue()

beancount.utils.table.table_to_text(table, column_interspace=' ', formats=None)

将 Table 渲染为 ASCII 文本。

参数:
  • table – 一个 Table 实例。

  • column_interspace – 用于在列之间作为分隔符显示的字符串。

  • formats – 一个可选字典,键为列名,值为格式字符,该字符将插入到格式字符串中,格式如:{:}。键为 '' 时提供默认值,例如:(... formats={'': '>'})。

返回:
  • 一个字符串,即渲染后的文本表格。

源代码位于 beancount/utils/table.py
def table_to_text(table,
                  column_interspace=" ",
                  formats=None):
    """Render a Table to ASCII text.

    Args:
      table: An instance of a Table.
      column_interspace: A string to render between the columns as spacer.
      formats: An optional dict of column name to a format character that gets
        inserted in a format string specified, like this (where '<char>' is):
        {:<char><width>}. A key of '*' will provide a default value, like
        this, for example: (... formats={'*': '>'}).
    Returns:
      A string, the rendered text table.
    """
    column_widths = compute_table_widths(itertools.chain([table.header],
                                                         table.body))

    # Insert column format chars and compute line formatting string.
    column_formats = []
    if formats:
        default_format = formats.get('*', None)
    for column, width in zip(table.columns, column_widths):
        if column and formats:
            format_ = formats.get(column, default_format)
            if format_:
                column_formats.append("{{:{}{:d}}}".format(format_, width))
            else:
                column_formats.append("{{:{:d}}}".format(width))
        else:
            column_formats.append("{{:{:d}}}".format(width))

    line_format = column_interspace.join(column_formats) + "\n"
    separator = line_format.format(*[('-' * width) for width in column_widths])

    # Render the header.
    oss = io.StringIO()
    if table.header:
        oss.write(line_format.format(*table.header))

    # Render the body.
    oss.write(separator)
    for row in table.body:
        oss.write(line_format.format(*row))
    oss.write(separator)

    return oss.getvalue()

beancount.utils.test_utils

用于测试脚本的辅助工具。

beancount.utils.test_utils.ClickTestCase (TestCase)

命令行程序测试用例的基类。

beancount.utils.test_utils.RCall (tuple)

RCall(args, kwargs, return_value)

beancount.utils.test_utils.RCall.__getnewargs__(self) 特殊

将自身返回为一个普通元组。供 copy 和 pickle 使用。

源代码位于 beancount/utils/test_utils.py
def __getnewargs__(self):
    'Return self as a plain tuple.  Used by copy and pickle.'
    return _tuple(self)

beancount.utils.test_utils.RCall.__new__(_cls, args, kwargs, return_value) 特殊 静态方法

创建 RCall(args, kwargs, return_value) 的新实例

beancount.utils.test_utils.RCall.__replace__(/, self, **kwds) 特殊

返回一个新的 RCall 对象,用指定的新值替换字段

源代码位于 beancount/utils/test_utils.py
def _replace(self, /, **kwds):
    result = self._make(_map(kwds.pop, field_names, self))
    if kwds:
        raise TypeError(f'Got unexpected field names: {list(kwds)!r}')
    return result

beancount.utils.test_utils.RCall.__repr__(self) 特殊

返回一个格式良好的表示字符串

源代码位于 beancount/utils/test_utils.py
def __repr__(self):
    'Return a nicely formatted representation string'
    return self.__class__.__name__ + repr_fmt % self

beancount.utils.test_utils.TestCase (TestCase)

beancount.utils.test_utils.TestCase.assertLines(self, text1, text2, message=None)

比较 text1 和 text2 的行,忽略空白字符。

参数:
  • text1 – 一个字符串,预期的文本。

  • text2 – 一个字符串,实际的文本。

  • message – 断言失败时可选的字符串消息。

异常:
  • AssertionError – 如果断言失败。

源代码位于 beancount/utils/test_utils.py
def assertLines(self, text1, text2, message=None):
    """Compare the lines of text1 and text2, ignoring whitespace.

    Args:
      text1: A string, the expected text.
      text2: A string, the actual text.
      message: An optional string message in case the assertion fails.
    Raises:
      AssertionError: If the exception fails.
    """
    clean_text1 = textwrap.dedent(text1.strip())
    clean_text2 = textwrap.dedent(text2.strip())
    lines1 = [line.strip() for line in clean_text1.splitlines()]
    lines2 = [line.strip() for line in clean_text2.splitlines()]

    # Compress all space longer than 4 spaces to exactly 4.
    # This affords us to be even looser.
    lines1 = [re.sub('    [ \t]*', '    ', line) for line in lines1]
    lines2 = [re.sub('    [ \t]*', '    ', line) for line in lines2]
    self.assertEqual(lines1, lines2, message)

beancount.utils.test_utils.TestCase.assertOutput(self, expected_text)

期望输出到标准输出的文本。

参数:
  • expected_text – 一个字符串,表示应被输出到标准输出的文本。

异常:
  • AssertionError – 如果文本不匹配时抛出。

源代码位于 beancount/utils/test_utils.py
@contextlib.contextmanager
def assertOutput(self, expected_text):
    """Expect text printed to stdout.

    Args:
      expected_text: A string, the text that should have been printed to stdout.
    Raises:
      AssertionError: If the text differs.
    """
    with capture() as oss:
        yield oss
    self.assertLines(textwrap.dedent(expected_text), oss.getvalue())

beancount.utils.test_utils.TmpFilesTestBase (TestCase)

一个测试工具基类,用于创建和清理目录层次结构。此便利功能适用于测试处理文件的函数,例如文档测试或账户遍历。

beancount.utils.test_utils.TmpFilesTestBase.create_file_hierarchy(test_files, subdir='root') 静态方法

一个测试工具,用于创建文件层次结构。

参数:
  • test_files – 一个字符串列表,表示相对于临时根目录的相对文件名。如果文件名以“/”结尾,则创建目录;否则创建普通文件。

  • subdir – 一个字符串,表示在临时目录下创建层次结构的子目录名称。

返回:
  • 一对字符串,分别是临时目录和其下承载树根的子目录。

源代码位于 beancount/utils/test_utils.py
@staticmethod
def create_file_hierarchy(test_files, subdir='root'):
    """A test utility that creates a hierarchy of files.

    Args:
      test_files: A list of strings, relative filenames to a temporary root
        directory. If the filename ends with a '/', we create a directory;
        otherwise, we create a regular file.
      subdir: A string, the subdirectory name under the temporary directory
        location, to create the hierarchy under.
    Returns:
      A pair of strings, the temporary directory, and the subdirectory under
        that which hosts the root of the tree.
    """
    tempdir = tempfile.mkdtemp(prefix="beancount-test-tmpdir.")
    root = path.join(tempdir, subdir)
    for filename in test_files:
        abs_filename = path.join(tempdir, filename)
        if filename.endswith('/'):
            os.makedirs(abs_filename)
        else:
            parent_dir = path.dirname(abs_filename)
            if not path.exists(parent_dir):
                os.makedirs(parent_dir)
            with open(abs_filename, 'w'): pass
    return tempdir, root

beancount.utils.test_utils.TmpFilesTestBase.setUp(self)

在执行测试前设置测试环境的钩子方法。

源代码位于 beancount/utils/test_utils.py
def setUp(self):
    self.tempdir, self.root = self.create_file_hierarchy(self.TEST_DOCUMENTS)

beancount.utils.test_utils.TmpFilesTestBase.tearDown(self)

在测试完成后拆除测试环境的钩子方法。

源代码位于 beancount/utils/test_utils.py
def tearDown(self):
    shutil.rmtree(self.tempdir, ignore_errors=True)

beancount.utils.test_utils.capture(*attributes)

一个上下文管理器,用于捕获输出到标准输出的内容。

参数:
  • *attributes – 一个字符串元组,表示要被替换为 StringIO 实例的 sys 属性名称。

返回:一个 StringIO 字符串累加器。

源代码位于 beancount/utils/test_utils.py
def capture(*attributes):
    """A context manager that captures what's printed to stdout.

    Args:
      *attributes: A tuple of strings, the name of the sys attributes to override
        with StringIO instances.
    Yields:
      A StringIO string accumulator.
    """
    if not attributes:
        attributes = 'stdout'
    elif len(attributes) == 1:
        attributes = attributes[0]
    return patch(sys, attributes, io.StringIO)

beancount.utils.test_utils.create_temporary_files(root, contents_map)

在 'root' 下创建多个临时文件。

此例程用于初始化临时目录下多个文件的内容。

参数:
  • root – 一个字符串,表示创建文件的目录名称。

  • contents_map – 一个字典,键为相对文件名,值为对应内容。内容字符串将自动去缩进以方便使用。此外,内容中的字符串“ROOT”将自动替换为根目录名称。

源代码位于 beancount/utils/test_utils.py
def create_temporary_files(root, contents_map):
    """Create a number of temporary files under 'root'.

    This routine is used to initialize the contents of multiple files under a
    temporary directory.

    Args:
      root: A string, the name of the directory under which to create the files.
      contents_map: A dict of relative filenames to their contents. The content
        strings will be automatically dedented for convenience. In addition, the
        string 'ROOT' in the contents will be automatically replaced by the root
        directory name.
    """
    os.makedirs(root, exist_ok=True)
    for relative_filename, contents in contents_map.items():
        assert not path.isabs(relative_filename)
        filename = path.join(root, relative_filename)
        os.makedirs(path.dirname(filename), exist_ok=True)

        clean_contents = textwrap.dedent(contents.replace('{root}', root))
        with open(filename, 'w') as f:
            f.write(clean_contents)

beancount.utils.test_utils.docfile(function, **kwargs)

一个装饰器,将函数的文档字符串写入临时文件,并以该临时文件名调用被装饰的函数。这在编写测试时非常有用。

参数:
  • function – 要装饰的函数。

返回:
  • 被装饰的函数。

源代码位于 beancount/utils/test_utils.py
def docfile(function, **kwargs):
    """A decorator that write the function's docstring to a temporary file
    and calls the decorated function with the temporary filename.  This is
    useful for writing tests.

    Args:
      function: A function to decorate.
    Returns:
      The decorated function.
    """
    contents = kwargs.pop('contents', None)

    @functools.wraps(function)
    def new_function(self):
        allowed = ('buffering', 'encoding', 'newline', 'dir', 'prefix', 'suffix')
        if any(key not in allowed for key in kwargs):
            raise ValueError("Invalid kwarg to docfile_extra")
        with tempfile.NamedTemporaryFile('w', **kwargs) as file:
            text = contents or function.__doc__
            file.write(textwrap.dedent(text))
            file.flush()
            return function(self, file.name)
    new_function.__doc__ = None
    return new_function

beancount.utils.test_utils.docfile_extra(**kwargs)

一个与 @docfile 相同的装饰器,但还接受用于临时文件的关键字参数。参数示例:buffering、encoding、newline、dir、prefix 和 suffix。

返回:
  • docfile

源代码位于 beancount/utils/test_utils.py
def docfile_extra(**kwargs):
    """
    A decorator identical to @docfile,
    but it also takes kwargs for the temporary file,
    Kwargs:
      e.g. buffering, encoding, newline, dir, prefix, and suffix.
    Returns:
      docfile
    """
    return functools.partial(docfile, **kwargs)

beancount.utils.test_utils.environ(varname, newvalue)

一个上下文管理器,用于推送 varname 的值并在之后恢复它。

参数:
  • varname – 一个字符串,表示环境变量的名称。

  • newvalue – 一个字符串,表示期望的值。

源代码位于 beancount/utils/test_utils.py
@contextlib.contextmanager
def environ(varname, newvalue):
    """A context manager which pushes varname's value and restores it later.

    Args:
      varname: A string, the environ variable name.
      newvalue: A string, the desired value.
    """
    oldvalue = os.environ.get(varname, None)
    os.environ[varname] = newvalue
    yield
    if oldvalue is not None:
        os.environ[varname] = oldvalue
    else:
        del os.environ[varname]

beancount.utils.test_utils.find_python_lib()

返回 Python 库根目录的路径。

返回:
  • 一个字符串,表示根目录。

源代码位于 beancount/utils/test_utils.py
def find_python_lib():
    """Return the path to the root of the Python libraries.

    Returns:
      A string, the root directory.
    """
    return path.dirname(path.dirname(path.dirname(__file__)))

beancount.utils.test_utils.find_repository_root(filename=None)

返回代码仓库根目录的路径。

参数:
  • filename – 一个字符串,表示仓库内某个文件的名称。

返回:
  • 一个字符串,表示根目录。

源代码位于 beancount/utils/test_utils.py
def find_repository_root(filename=None):
    """Return the path to the repository root.

    Args:
      filename: A string, the name of a file within the repository.
    Returns:
      A string, the root directory.
    """
    if filename is None:
        filename = __file__

    # Support root directory under Bazel.
    match = re.match(r"(.*\.runfiles/beancount)/", filename)
    if match:
        return match.group(1)

    while not path.exists(path.join(filename, 'pyproject.toml')):
        prev_filename = filename
        filename = path.dirname(filename)
        if prev_filename == filename:
            raise ValueError("Failed to find the root directory.")
    return filename

beancount.utils.test_utils.make_failing_importer(*removed_module_names)

创建一个在导入某些模块时抛出 ImportError 的导入器。

使用方式如下:

@mock.patch('builtins.import', make_failing_importer('setuptools')) def test_...

参数:
  • removed_module_name – 应该抛出异常的模块名称。

返回:
  • 一个装饰器,用于修饰测试函数。

源代码位于 beancount/utils/test_utils.py
def make_failing_importer(*removed_module_names):
    """Make an importer that raise an ImportError for some modules.

    Use it like this:

      @mock.patch('builtins.__import__', make_failing_importer('setuptools'))
      def test_...

    Args:
      removed_module_name: The name of the module import that should raise an exception.
    Returns:
      A decorated test decorator.
    """
    def failing_import(name, *args, **kwargs):
        if name in removed_module_names:
            raise ImportError("Could not import {}".format(name))
        return builtins.__import__(name, *args, **kwargs)
    return failing_import

beancount.utils.test_utils.nottest(func)

使给定函数不可被测试。

源代码位于 beancount/utils/test_utils.py
def nottest(func):
    "Make the given function not testable."
    func.__test__ = False
    return func

beancount.utils.test_utils.patch(obj, attributes, replacement_type)

一个上下文管理器,临时替换对象的属性。

在 'attributes' 中的所有属性都会被保存,并替换为 'replacement_type' 类型的新实例。

参数:
  • obj – 需要被修改的对象。

  • attributes – 一个字符串或字符串序列,表示要替换的属性名称。

  • replacement_type – 用于创建替换对象的可调用对象。

返回:一个由 'replacement_type' 实例组成的列表。

源代码位于 beancount/utils/test_utils.py
@contextlib.contextmanager
def patch(obj, attributes, replacement_type):
    """A context manager that temporarily patches an object's attributes.

    All attributes in 'attributes' are saved and replaced by new instances
    of type 'replacement_type'.

    Args:
      obj: The object to patch up.
      attributes: A string or a sequence of strings, the names of attributes to replace.
      replacement_type: A callable to build replacement objects.
    Yields:
      An instance of a list of sequences of 'replacement_type'.
    """
    single = isinstance(attributes, str)
    if single:
        attributes = [attributes]

    saved = []
    replacements = []
    for attribute in attributes:
        replacement = replacement_type()
        replacements.append(replacement)
        saved.append(getattr(obj, attribute))
        setattr(obj, attribute, replacement)

    yield replacements[0] if single else replacements

    for attribute, saved_attr in zip(attributes, saved):
        setattr(obj, attribute, saved_attr)

beancount.utils.test_utils.record(fun)

装饰函数,以拦截并记录所有调用及其返回值。

参数:
  • fun – 需要被装饰的可调用对象。

返回:
  • 一个带有 .calls 属性的包装函数,该属性是一个 RCall 实例列表。

源代码位于 beancount/utils/test_utils.py
def record(fun):
    """Decorates the function to intercept and record all calls and return values.

    Args:
      fun: A callable to be decorated.
    Returns:
      A wrapper function with a .calls attribute, a list of RCall instances.
    """
    @functools.wraps(fun)
    def wrapped(*args, **kw):
        return_value = fun(*args, **kw)
        wrapped.calls.append(RCall(args, kw, return_value))
        return return_value
    wrapped.calls = []
    return wrapped

beancount.utils.test_utils.search_words(words, line)

在一行中搜索一组单词。

参数:
  • words – 一个字符串列表,表示要查找的单词,或一个用空格分隔的字符串。

  • line – 一个字符串,表示要搜索的行。

返回:
  • 一个 MatchObject 对象,或 None。

源代码位于 beancount/utils/test_utils.py
def search_words(words, line):
    """Search for a sequence of words in a line.

    Args:
      words: A list of strings, the words to look for, or a space-separated string.
      line: A string, the line to search into.
    Returns:
      A MatchObject, or None.
    """
    if isinstance(words, str):
        words = words.split()
    return re.search('.*'.join(r'\b{}\b'.format(word) for word in words), line)

beancount.utils.test_utils.skipIfRaises(*exc_types)

一个上下文管理器(或装饰器),当抛出异常时跳过测试。

yield:无,供你执行函数代码。

异常:
  • SkipTest – 如果测试抛出了预期的异常。

源代码位于 beancount/utils/test_utils.py
@contextlib.contextmanager
def skipIfRaises(*exc_types):
    """A context manager (or decorator) that skips a test if an exception is raised.

    Args:
      exc_type
    Yields:
      Nothing, for you to execute the function code.
    Raises:
      SkipTest: if the test raised the expected exception.
    """
    try:
        yield
    except exc_types as exception:
        raise unittest.SkipTest(exception)

beancount.utils.test_utils.subprocess_env()

返回一个用于运行子进程的环境变量字典。

返回:
  • 一个字符串,表示根目录。

源代码位于 beancount/utils/test_utils.py
def subprocess_env():
    """Return a dict to use as environment for running subprocesses.

    Returns:
      A string, the root directory.
    """
    # Ensure we have locations to invoke our Python executable and our
    # runnable binaries in the test environment to run subprocesses.
    binpath = ':'.join([
        path.dirname(sys.executable),
        path.join(find_repository_root(__file__), 'bin'),
        os.environ.get('PATH', '').strip(':')]).strip(':')
    return {'PATH': binpath,
            'PYTHONPATH': find_python_lib()}

beancount.utils.test_utils.tempdir(delete=True, **kw)

一个上下文管理器,创建一个临时目录,并在完成后无条件删除其内容。

参数:
  • delete – 布尔值,表示运行后是否删除该目录。

  • **kw – 传递给 mkdtemp 的关键字参数。

yield:一个字符串,表示创建的临时目录的名称。

源代码位于 beancount/utils/test_utils.py
@contextlib.contextmanager
def tempdir(delete=True, **kw):
    """A context manager that creates a temporary directory and deletes its
    contents unconditionally once done.

    Args:
      delete: A boolean, true if we want to delete the directory after running.
      **kw: Keyword arguments for mkdtemp.
    Yields:
      A string, the name of the temporary directory created.
    """
    tempdir = tempfile.mkdtemp(prefix="beancount-test-tmpdir.", **kw)
    try:
        yield tempdir
    finally:
        if delete:
            shutil.rmtree(tempdir, ignore_errors=True)