# Source code for lxml.doctestcompare

"""
lxml-based doctest output comparison.

Note: normally, you should just import the `lxml.usedoctest` and
`lxml.html.usedoctest` modules from within a doctest, instead of this
one::

    >>> import lxml.usedoctest # for XML output

    >>> import lxml.html.usedoctest # for HTML output

To use this module directly, you must call
``lxml.doctestcompare.install()``, which will cause doctest to use this
checker in all subsequent calls.

This changes the way output is checked and comparisons are made for
XML or HTML-like content.

XML or HTML content is noticed because the example starts with ``<``
(it's HTML if it starts with ``<html``).  You can also use the
``PARSE_HTML`` and ``PARSE_XML`` flags to force parsing.

Some rough wildcard-like things are allowed.  Whitespace is generally
ignored (except in attributes).  In text (attributes and text in the
body) you can use ``...`` as a wildcard.  In an example it also
matches any trailing tags in the element, though it does not match
leading tags.  You may create a tag ``<any>`` or include an ``any``
attribute in the tag.  An ``any`` tag matches any tag, while the
attribute matches any and all attributes.

When a match fails, the reformatted example and the actual output are
displayed (indented), and a rough diff-like output is given.  Anything
marked with ``+`` is in the output but wasn't supposed to be, and
similarly ``-`` means it's in the example but wasn't in the output.

You can disable parsing on one line with ``# doctest:+NOPARSE_MARKUP``
"""

from lxml import etree
import sys
import re
import doctest
try:
    from html import escape as html_escape
except ImportError:
    from cgi import escape as html_escape

# Names exported by ``from lxml.doctestcompare import *``.
__all__ = ['PARSE_HTML', 'PARSE_XML', 'NOPARSE_MARKUP', 'LXMLOutputChecker',
           'LHTMLOutputChecker', 'install', 'temp_install']

# ``basestring`` only exists on Python 2; when the name is missing, fall back
# to the (str, bytes) pair so isinstance() checks keep working.
try:
    _basestring = basestring
except NameError:
    _basestring = (str, bytes)

# True when running on Python 3 or later; selects between the
# ``__func__``/``__code__`` and ``im_func``/``func_code`` attribute spellings.
_IS_PYTHON_3 = sys.version_info[0] >= 3

# Custom doctest option flags: force HTML parsing, force XML parsing, or
# disable markup parsing for an example.
PARSE_HTML = doctest.register_optionflag('PARSE_HTML')
PARSE_XML = doctest.register_optionflag('PARSE_XML')
NOPARSE_MARKUP = doctest.register_optionflag('NOPARSE_MARKUP')

# Reference to doctest's checker class as it is at import time; used as the
# base class of the checkers defined in this module.
OutputChecker = doctest.OutputChecker

def strip(v):
    """Return ``v.strip()``, passing ``None`` through unchanged."""
    return None if v is None else v.strip()
def norm_whitespace(v):
    """Replace every match of ``_norm_whitespace_re`` in *v* with one space."""
    collapsed = _norm_whitespace_re.sub(' ', v)
    return collapsed
# Shared HTML parser: error recovery is switched off and blank text is removed.
_html_parser = etree.HTMLParser(remove_blank_text=True, recover=False)


def html_fromstring(html):
    """Parse *html* with the module-level ``_html_parser`` and return the result."""
    return etree.fromstring(html, parser=_html_parser)
# We use this to distinguish repr()s from elements: _repr_re = re.compile(r'^<[^>]+ (at|object) ') _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
class LXMLOutputChecker(OutputChecker):
    """Doctest output checker that compares XML/HTML output structurally.

    When both the expected and the actual output look like markup they are
    parsed and compared as element trees (with ``...``, ``<any>`` and the
    ``any`` attribute acting as wildcards); otherwise the comparison is
    delegated to the stock ``OutputChecker``.
    """

    # HTML elements that are formatted without content or an end tag.
    empty_tags = (
        'param', 'img', 'area', 'br', 'basefont', 'input',
        'base', 'meta', 'link', 'col')

    def get_default_parser(self):
        """Return the parser used for markup that is not explicitly HTML."""
        return etree.XML

    def check_output(self, want, got, optionflags):
        """Return True if *got* matches *want* under *optionflags*.

        Returns False when either text fails to parse with the chosen
        parser, and defers to the plain doctest comparison when no parser
        applies.
        """
        # NOTE: temp_install() copies this method's code object onto
        # doctest's own check_output, so this body may execute with
        # doctest's globals.  Keep its global references limited to
        # ``OutputChecker`` and ``etree`` and keep the statements as-is.
        alt_self = getattr(self, '_temp_override_self', None)
        if alt_self is not None:
            # Running as the temporarily patched doctest checker: act on
            # behalf of the lxml checker instance stored on it.
            super_method = self._temp_call_super_check_output
            self = alt_self
        else:
            super_method = OutputChecker.check_output
        parser = self.get_parser(want, got, optionflags)
        if not parser:
            return super_method(
                self, want, got, optionflags)
        try:
            want_doc = parser(want)
        except etree.XMLSyntaxError:
            return False
        try:
            got_doc = parser(got)
        except etree.XMLSyntaxError:
            return False
        return self.compare_docs(want_doc, got_doc)

    def get_parser(self, want, got, optionflags):
        """Pick the parser for this comparison, or None for plain text.

        Precedence: ``NOPARSE_MARKUP`` disables parsing, ``PARSE_HTML`` and
        ``PARSE_XML`` force a parser, then content sniffing decides.
        """
        if NOPARSE_MARKUP & optionflags:
            return None
        if PARSE_HTML & optionflags:
            return html_fromstring
        if PARSE_XML & optionflags:
            return etree.XML
        # Sniff case-insensitively on BOTH sides; previously only ``want``
        # was lower-cased, so ``<HTML>`` output was not recognised as HTML.
        if (want.strip().lower().startswith('<html')
                and got.strip().lower().startswith('<html')):
            return html_fromstring
        if self._looks_like_markup(want) and self._looks_like_markup(got):
            return self.get_default_parser()
        return None

    def _looks_like_markup(self, s):
        """Return True if *s* starts with ``<`` and is not an object repr."""
        s = s.strip()
        return (s.startswith('<')
                and not _repr_re.search(s))

    def compare_docs(self, want, got):
        """Recursively compare two elements; return True when they match."""
        if not self.tag_compare(want.tag, got.tag):
            return False
        if not self.text_compare(want.text, got.text, True):
            return False
        if not self.text_compare(want.tail, got.tail, True):
            return False
        # An ``any`` attribute on the expected element matches any and all
        # attributes on the actual element.
        if 'any' not in want.attrib:
            want_keys = sorted(want.attrib.keys())
            got_keys = sorted(got.attrib.keys())
            if want_keys != got_keys:
                return False
            for key in want_keys:
                if not self.text_compare(
                        want.attrib[key], got.attrib[key], False):
                    return False
        # A childless element whose text is exactly ``...`` matches any
        # content, so the children are only compared in the other cases.
        if want.text != '...' or len(want):
            want_children = list(want)
            got_children = list(got)
            while want_children or got_children:
                if not want_children or not got_children:
                    return False
                want_first = want_children.pop(0)
                got_first = got_children.pop(0)
                if not self.compare_docs(want_first, got_first):
                    return False
                # Once the actual children are used up, a ``...`` tail on
                # the expected child ends the comparison successfully.
                if not got_children and want_first.tail == '...':
                    break
        return True

    def text_compare(self, want, got, strip):
        """Return True if *got* matches *want*, with ``...`` as a wildcard.

        None is treated as the empty string.  When *strip* is true both
        texts are whitespace-normalised and stripped first.
        """
        want = want or ''
        got = got or ''
        if strip:
            want = norm_whitespace(want).strip()
            got = norm_whitespace(got).strip()
        # Escape everything, then turn the escaped ellipsis into ``.*``.
        pattern = '^%s$' % re.escape(want)
        pattern = pattern.replace(r'\.\.\.', '.*')
        return bool(re.search(pattern, got))

    def tag_compare(self, want, got):
        """Return True if tag *got* satisfies the expected tag *want*.

        ``any`` matches every tag and a ``{...}`` namespace prefix matches
        any namespace.  Non-string tags are compared with ``==``.
        """
        if want == 'any':
            return True
        if (not isinstance(want, _basestring)
                or not isinstance(got, _basestring)):
            return want == got
        want = want or ''
        got = got or ''
        if want.startswith('{...}'):
            # Ellipsis on the namespace
            return want.split('}')[-1] == got.split('}')[-1]
        return want == got

    def output_difference(self, example, got, optionflags):
        """Return a description of how *got* differs from ``example.want``.

        Falls back to doctest's own report (prefixed with any parse errors)
        when no parser applies or when parsing fails.
        """
        want = example.want
        parser = self.get_parser(want, got, optionflags)
        errors = []
        if parser is not None:
            try:
                want_doc = parser(want)
            except etree.XMLSyntaxError as e:
                errors.append('In example: %s' % e)
            try:
                got_doc = parser(got)
            except etree.XMLSyntaxError as e:
                errors.append('In actual output: %s' % e)
        if parser is None or errors:
            value = OutputChecker.output_difference(
                self, example, got, optionflags)
            if errors:
                errors.append(value)
                return '\n'.join(errors)
            return value
        html = parser is html_fromstring
        diff_parts = ['Expected:',
                      self.format_doc(want_doc, html, 2),
                      'Got:',
                      self.format_doc(got_doc, html, 2),
                      'Diff:',
                      self.collect_diff(want_doc, got_doc, html, 2)]
        return '\n'.join(diff_parts)

    def html_empty_tag(self, el, html=True):
        """Return True if *el* is a content-free element from ``empty_tags``.

        Always False when *html* is false.
        """
        if not html:
            return False
        if el.tag not in self.empty_tags:
            return False
        if el.text or len(el):
            # This shouldn't happen (contents in an empty tag)
            return False
        return True

    def format_doc(self, doc, html, indent, prefix=''):
        """Render element *doc* as indented text, one element per line.

        *prefix* is written before the opening tag of *doc* only; it is not
        passed on to child elements.
        """
        parts = []
        if not len(doc):
            # No children...
            parts.append(' ' * indent)
            parts.append(prefix)
            parts.append(self.format_tag(doc))
            if not self.html_empty_tag(doc, html):
                if strip(doc.text):
                    parts.append(self.format_text(doc.text))
                parts.append(self.format_end_tag(doc))
            if strip(doc.tail):
                parts.append(self.format_text(doc.tail))
            parts.append('\n')
            return ''.join(parts)
        parts.append(' ' * indent)
        parts.append(prefix)
        parts.append(self.format_tag(doc))
        if not self.html_empty_tag(doc, html):
            parts.append('\n')
            if strip(doc.text):
                parts.append(' ' * indent)
                parts.append(self.format_text(doc.text))
                parts.append('\n')
            for el in doc:
                parts.append(self.format_doc(el, html, indent + 2))
            parts.append(' ' * indent)
            parts.append(self.format_end_tag(doc))
            parts.append('\n')
        if strip(doc.tail):
            parts.append(' ' * indent)
            parts.append(self.format_text(doc.tail))
            parts.append('\n')
        return ''.join(parts)

    def format_text(self, text, strip=True):
        """Return *text* HTML-escaped (quotes included); None becomes ''."""
        if text is None:
            return ''
        if strip:
            text = text.strip()
        return html_escape(text, True)

    def format_tag(self, el):
        """Return the opening tag of *el* with its attributes sorted."""
        if isinstance(el, etree.CommentBase):
            # FIXME: probably PIs should be handled specially too?
            return '<!--'
        attrs = ['%s="%s"' % (name, self.format_text(value, False))
                 for name, value in sorted(el.attrib.items())]
        if not attrs:
            return '<%s>' % el.tag
        return '<%s %s>' % (el.tag, ' '.join(attrs))

    def format_end_tag(self, el):
        """Return the closing tag of *el* (``-->`` for comments)."""
        if isinstance(el, etree.CommentBase):
            # FIXME: probably PIs should be handled specially too?
            return '-->'
        return '</%s>' % el.tag

    def collect_diff(self, want, got, html, indent):
        """Render a merged view of *want* and *got* with differences marked.

        Mismatches are shown as ``want (got: got)``; children present on
        only one side are prefixed with ``+`` (only in the output) or ``-``
        (only in the example).
        """
        parts = []
        if not len(want) and not len(got):
            parts.append(' ' * indent)
            parts.append(self.collect_diff_tag(want, got))
            if not self.html_empty_tag(got, html):
                parts.append(self.collect_diff_text(want.text, got.text))
                parts.append(self.collect_diff_end_tag(want, got))
            parts.append(self.collect_diff_text(want.tail, got.tail))
            parts.append('\n')
            return ''.join(parts)
        parts.append(' ' * indent)
        parts.append(self.collect_diff_tag(want, got))
        parts.append('\n')
        if strip(want.text) or strip(got.text):
            parts.append(' ' * indent)
            parts.append(self.collect_diff_text(want.text, got.text))
            parts.append('\n')
        want_children = list(want)
        got_children = list(got)
        while want_children or got_children:
            if not want_children:
                parts.append(self.format_doc(
                    got_children.pop(0), html, indent + 2, '+'))
                continue
            if not got_children:
                parts.append(self.format_doc(
                    want_children.pop(0), html, indent + 2, '-'))
                continue
            parts.append(self.collect_diff(
                want_children.pop(0), got_children.pop(0), html, indent + 2))
        parts.append(' ' * indent)
        parts.append(self.collect_diff_end_tag(want, got))
        parts.append('\n')
        if strip(want.tail) or strip(got.tail):
            parts.append(' ' * indent)
            parts.append(self.collect_diff_text(want.tail, got.tail))
            parts.append('\n')
        return ''.join(parts)

    def collect_diff_tag(self, want, got):
        """Return the opening tag for the diff view, marking differences.

        Attributes only in *got* are prefixed with ``+`` and attributes
        only in *want* with ``-``, unless the expected element is an
        ``any`` wildcard.
        """
        if not self.tag_compare(want.tag, got.tag):
            tag = '%s (got: %s)' % (want.tag, got.tag)
        else:
            tag = got.tag
        attrs = []
        # Renamed from ``any`` so the builtin is not shadowed.
        match_any = want.tag == 'any' or 'any' in want.attrib
        for name, value in sorted(got.attrib.items()):
            if name not in want.attrib and not match_any:
                attrs.append(
                    '+%s="%s"' % (name, self.format_text(value, False)))
            else:
                if name in want.attrib:
                    text = self.collect_diff_text(
                        want.attrib[name], value, False)
                else:
                    text = self.format_text(value, False)
                attrs.append('%s="%s"' % (name, text))
        if not match_any:
            for name, value in sorted(want.attrib.items()):
                if name in got.attrib:
                    continue
                attrs.append(
                    '-%s="%s"' % (name, self.format_text(value, False)))
        if attrs:
            return '<%s %s>' % (tag, ' '.join(attrs))
        return '<%s>' % tag

    def collect_diff_end_tag(self, want, got):
        """Return the closing tag for the diff view, marking a mismatch."""
        # Use tag_compare() so ``any`` and ``{...}`` wildcards are treated
        # as matches here exactly as they are for the opening tag.
        if not self.tag_compare(want.tag, got.tag):
            tag = '%s (got: %s)' % (want.tag, got.tag)
        else:
            tag = got.tag
        return '</%s>' % tag

    def collect_diff_text(self, want, got, strip=True):
        """Return *got* if it matches *want*, else ``want (got: got)``."""
        if self.text_compare(want, got, strip):
            if not got:
                return ''
            return self.format_text(got, strip)
        text = '%s (got: %s)' % (want, got)
        return self.format_text(text, strip)
class LHTMLOutputChecker(LXMLOutputChecker):
    """Variant of ``LXMLOutputChecker`` whose default parser is the HTML one."""

    def get_default_parser(self):
        """Return ``html_fromstring`` as the default parser."""
        return html_fromstring
def install(html=False):
    """
    Install doctestcompare for all future doctests.

    With a true *html* the checker that defaults to the HTML parser is
    installed; otherwise the one that defaults to the XML parser is.
    """
    doctest.OutputChecker = (
        LHTMLOutputChecker if html else LXMLOutputChecker)
def temp_install(html=False, del_module=None):
    """
    Use this *inside* a doctest to enable this checker for this
    doctest only.

    If html is true, then by default the HTML parser will be used;
    otherwise the XML parser is used.

    ``del_module`` is handed to ``_RestoreChecker``; when set, it names a
    module that is removed from ``sys.modules`` again once the checker has
    been restored.
    """
    if html:
        Checker = LHTMLOutputChecker
    else:
        Checker = LXMLOutputChecker
    # Locate the running doctest frame and the runner object ('self') in it.
    frame = _find_doctest_frame()
    dt_self = frame.f_locals['self']
    checker = Checker()
    old_checker = dt_self._checker
    dt_self._checker = checker
    # The unfortunate thing is that there is a local variable 'check'
    # in the function that runs the doctests, that is a bound method
    # into the output checker.  We have to update that.  We can't
    # modify the frame, so we have to modify the object in place.  The
    # only way to do this is to actually change the func_code
    # attribute of the method.  We change it, and then wait for
    # __record_outcome to be run, which signals the end of the __run
    # method, at which point we restore the previous check_output
    # implementation.
    if _IS_PYTHON_3:
        check_func = frame.f_locals['check'].__func__
        checker_check_func = checker.check_output.__func__
    else:
        check_func = frame.f_locals['check'].im_func
        checker_check_func = checker.check_output.im_func
    # Because we can't patch up func_globals, this is the only global
    # in check_output that we care about:
    doctest.etree = etree
    # The instance is not bound to a name here: its constructor installs
    # itself as the runner's __record_outcome hook.
    _RestoreChecker(dt_self, old_checker, checker,
                    check_func, checker_check_func,
                    del_module)
class _RestoreChecker(object):
    """Temporarily patch a doctest runner and undo the patch afterwards.

    On construction the code of ``check_func`` is replaced by that of
    ``clone_func`` and the instance installs itself as the runner's
    ``__record_outcome``.  When called in that role it restores everything,
    forwards the call to the original ``__record_outcome`` and finally
    removes ``del_module`` (if given) from ``sys.modules``.
    """

    def __init__(self, dt_self, old_checker, new_checker, check_func,
                 clone_func, del_module):
        self.dt_self = dt_self
        self.checker = old_checker
        # Markers read by the cloned check_output to find the real checker
        # and a way to reach the original implementation.
        self.checker._temp_call_super_check_output = self.call_super
        self.checker._temp_override_self = new_checker
        self.check_func = check_func
        self.clone_func = clone_func
        self.del_module = del_module
        self.install_clone()
        self.install_dt_self()

    def install_clone(self):
        """Save the code of ``check_func`` and swap in the clone's code."""
        if _IS_PYTHON_3:
            self.func_code = self.check_func.__code__
            self.func_globals = self.check_func.__globals__
            self.check_func.__code__ = self.clone_func.__code__
        else:
            self.func_code = self.check_func.func_code
            self.func_globals = self.check_func.func_globals
            self.check_func.func_code = self.clone_func.func_code

    def uninstall_clone(self):
        """Put the saved original code back onto ``check_func``."""
        if _IS_PYTHON_3:
            self.check_func.__code__ = self.func_code
        else:
            self.check_func.func_code = self.func_code

    def install_dt_self(self):
        """Replace the runner's ``__record_outcome`` with this instance."""
        self.prev_func = self.dt_self._DocTestRunner__record_outcome
        self.dt_self._DocTestRunner__record_outcome = self

    def uninstall_dt_self(self):
        """Restore the runner's original ``__record_outcome``."""
        self.dt_self._DocTestRunner__record_outcome = self.prev_func

    def uninstall_module(self):
        """Remove ``del_module`` from ``sys.modules`` and from its package.

        Does nothing when ``del_module`` is unset.  Entries that are already
        gone are ignored so that cleanup can never raise from inside the
        doctest runner's outcome hook.
        """
        name = self.del_module
        if not name:
            return
        import sys
        sys.modules.pop(name, None)
        if '.' in name:
            package, module = name.rsplit('.', 1)
            package_mod = sys.modules.get(package)
            if package_mod is not None and hasattr(package_mod, module):
                delattr(package_mod, module)

    def __call__(self, *args, **kw):
        """Undo all patches, then delegate to the original hook."""
        self.uninstall_clone()
        self.uninstall_dt_self()
        del self.checker._temp_override_self
        del self.checker._temp_call_super_check_output
        try:
            return self.prev_func(*args, **kw)
        finally:
            # Run the module cleanup even if the original hook raises.
            self.uninstall_module()

    def call_super(self, *args, **kw):
        """Call the original ``check_func`` code, then re-install the clone."""
        self.uninstall_clone()
        try:
            return self.check_func(*args, **kw)
        finally:
            self.install_clone()
def _find_doctest_frame():
    """Return the nearest calling frame that has a local named ``BOOM``.

    Raises LookupError when no frame on the call stack has such a local.
    """
    import sys
    candidate = sys._getframe(1)
    while candidate is not None:
        # Sign of doctest
        if 'BOOM' in candidate.f_locals:
            return candidate
        candidate = candidate.f_back
    raise LookupError(
        "Could not find doctest (only use this function *inside* a doctest)")
# Self-test doctests, run when this module is executed as a script.  The
# examples call print() as a function so that they compile on Python 3 as
# well as Python 2.
__test__ = {
    'basic': '''
    >>> temp_install()
    >>> print("""<xml a="1" b="2">stuff</xml>""")
    <xml b="2" a="1">...</xml>
    >>> print("""<xml xmlns="http://example.com"><tag attr="bar" /></xml>""")
    <xml xmlns="...">
      <tag attr="..." />
    </xml>
    >>> print("""<xml>blahblahblah<foo /></xml>""") # doctest: +NOPARSE_MARKUP, +ELLIPSIS
    <xml>...foo /></xml>
    '''}

if __name__ == '__main__':
    import doctest
    doctest.testmod()