Package lxml :: Package tests :: Module test_threading
[hide private]
[frames] | no frames]

Source Code for Module lxml.tests.test_threading

  1  # -*- coding: utf-8 -*- 
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  from __future__ import absolute_import 
  8   
  9  import re 
 10  import sys 
 11  import unittest 
 12  import threading 
 13   
 14  from .common_imports import etree, HelperTestCase, BytesIO, _bytes 
 15   
 16  try: 
 17      from Queue import Queue 
 18  except ImportError: 
 19      from queue import Queue # Py3 
 20   
 21   
class ThreadingTestCase(HelperTestCase):
    """Threading tests"""
    etree = etree

    def _run_thread(self, func):
        """Run *func* in one separate thread and wait for it to finish."""
        thread = threading.Thread(target=func)
        thread.start()
        thread.join()

    def _run_threads(self, count, func, main_func=None):
        """Run *func* concurrently in *count* threads.

        Each thread waits (up to 4 seconds) until all participants have
        started, then all are released together.  If *main_func* is given,
        the calling thread participates as one more starter and runs it.
        The test fails if any participant raised or did not finish.
        """
        sync = threading.Event()
        lock = threading.Lock()
        counter = dict(started=0, finished=0, failed=0)
        # the calling thread counts as one more starter if it participates
        expected_starters = count + (main_func is not None)

        def sync_start(func):
            with lock:
                started = counter['started'] + 1
                counter['started'] = started
            if started < expected_starters:
                sync.wait(4)  # wait until the other threads have started up
                assert sync.is_set()
            sync.set()  # all waiting => go!
            try:
                func()
            except BaseException:
                with lock:
                    counter['failed'] += 1
                raise
            else:
                with lock:
                    counter['finished'] += 1

        threads = [threading.Thread(target=sync_start, args=(func,))
                   for _ in range(count)]
        for thread in threads:
            thread.start()
        try:
            if main_func is not None:
                sync_start(main_func)
        finally:
            # join even if main_func failed, so no worker outlives this test
            for thread in threads:
                thread.join()

        self.assertEqual(0, counter['failed'])
        self.assertEqual(counter['finished'], counter['started'])

    def test_subtree_copy_thread(self):
        tostring = self.etree.tostring
        XML = self.etree.XML
        xml = _bytes("<root><threadtag/></root>")
        main_root = XML(_bytes("<root/>"))

        def run_thread():
            thread_root = XML(xml)
            main_root.append(thread_root[0])
            del thread_root

        self._run_thread(run_thread)
        self.assertEqual(xml, tostring(main_root))

    def test_main_xslt_in_thread(self):
        XML = self.etree.XML
        style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
  </xsl:template>
</xsl:stylesheet>'''))
        st = etree.XSLT(style)

        result = []

        def run_thread():
            root = XML(_bytes('<a><b>B</b><c>C</c></a>'))
            result.append(st(root))

        self._run_thread(run_thread)
        self.assertEqual('''\
<?xml version="1.0"?>
<foo><a>B</a></foo>
''',
                         str(result[0]))

    def test_thread_xslt(self):
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        def run_thread():
            style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
  </xsl:template>
</xsl:stylesheet>'''))
            st = etree.XSLT(style)
            root.append(st(root).getroot())

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'),
                         tostring(root))

    def test_thread_xslt_parsing_error_log(self):
        style = self.parse('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="tag" />
  <!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%x" />' % i for i in range(200)) + '''
  <xsl:UnExpectedElement />
</xsl:stylesheet>''')
        self.assertRaises(etree.XSLTParseError,
                          etree.XSLT, style)

        error_logs = []

        def run_thread():
            try:
                etree.XSLT(style)
            except etree.XSLTParseError as e:
                error_logs.append(e.error_log)
            else:
                self.fail("XSLT parsing should have failed but didn't")

        self._run_threads(16, run_thread)

        self.assertEqual(16, len(error_logs))
        last_log = None
        for log in error_logs:
            self.assertTrue(len(log))
            if last_log is not None:
                self.assertEqual(len(last_log), len(log))
            self.assertTrue(len(log) >= 2, len(log))
            for error in log:
                self.assertTrue(':ERROR:XSLT:' in str(error), str(error))
            self.assertTrue(any('UnExpectedElement' in str(error) for error in log), log)
            last_log = log

    def test_thread_xslt_apply_error_log(self):
        tree = self.parse('<tagFF/>')
        style = self.parse('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template name="tag0">
    <xsl:message terminate="yes">FAIL</xsl:message>
  </xsl:template>
  <!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%X" name="tag%x"> <xsl:call-template name="tag%x" /> </xsl:template>' % (i, i, i-1)
                for i in range(1, 256)) + '''
</xsl:stylesheet>''')
        self.assertRaises(etree.XSLTApplyError,
                          etree.XSLT(style), tree)

        error_logs = []

        def run_thread():
            transform = etree.XSLT(style)
            try:
                transform(tree)
            except etree.XSLTApplyError:
                error_logs.append(transform.error_log)
            else:
                # it is the transformation (not the parsing) that must fail here
                self.fail("XSLT application should have failed but didn't")

        self._run_threads(16, run_thread)

        self.assertEqual(16, len(error_logs))
        last_log = None
        for log in error_logs:
            self.assertTrue(len(log))
            if last_log is not None:
                self.assertEqual(len(last_log), len(log))
            self.assertEqual(1, len(log))
            for error in log:
                self.assertTrue(':ERROR:XSLT:' in str(error))
            last_log = log

    def test_thread_xslt_attr_replace(self):
        # this is the only case in XSLT where the result tree can be
        # modified in-place
        XML = self.etree.XML
        tostring = self.etree.tostring
        style = self.etree.XSLT(XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <root class="abc">
      <xsl:copy-of select="@class" />
      <xsl:attribute name="class">xyz</xsl:attribute>
    </root>
  </xsl:template>
</xsl:stylesheet>''')))

        result = []

        def run_thread():
            root = XML(_bytes('<ROOT class="ABC" />'))
            result.append(style(root).getroot())

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<root class="xyz"/>'),
                         tostring(result[0]))

    def test_thread_create_xslt(self):
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        stylesheets = []

        def run_thread():
            style = XML(_bytes('''\
<xsl:stylesheet
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
    version="1.0">
    <xsl:output method="xml" />
    <xsl:template match="/">
        <div id="test">
            <xsl:apply-templates/>
        </div>
    </xsl:template>
</xsl:stylesheet>'''))
            stylesheets.append(etree.XSLT(style))

        self._run_thread(run_thread)

        st = stylesheets[0]
        result = tostring(st(root))

        self.assertEqual(_bytes('<div id="test">BC</div>'),
                         result)

    def test_thread_error_log(self):
        XML = self.etree.XML
        expected_error = [self.etree.ErrorTypes.ERR_TAG_NAME_MISMATCH]
        children = "<a>test</a>" * 100

        def parse_error_test(thread_no):
            # each thread uses its own tag name so that its error messages
            # can be told apart from those of the other threads
            tag = "tag%d" % thread_no
            xml = "<%s>%s</%s>" % (tag, children, tag.upper())
            parser = self.etree.XMLParser()
            for _ in range(10):
                errors = None
                try:
                    XML(xml, parser)
                except self.etree.ParseError as e:
                    errors = e.error_log.filter_types(expected_error)
                self.assertTrue(errors, "Expected error not found")
                for error in errors:
                    self.assertTrue(
                        tag in error.message and tag.upper() in error.message,
                        "%s and %s not found in '%s'" % (
                            tag, tag.upper(), error.message))

        # An exception raised in a threading.Thread target is only reported
        # through the thread excepthook and never reaches this test, so
        # record failures here and assert on them after joining.
        failures = []

        def checked_parse_error_test(thread_no):
            try:
                parse_error_test(thread_no)
            except BaseException as e:
                failures.append((thread_no, e))
                raise

        self.etree.clear_error_log()
        threads = []
        for thread_no in range(1, 10):
            t = threading.Thread(target=checked_parse_error_test,
                                 args=(thread_no,))
            threads.append(t)
            t.start()

        try:
            parse_error_test(0)
        finally:
            for t in threads:
                t.join()

        self.assertEqual([], failures)

    def test_thread_mix(self):
        XML = self.etree.XML
        Element = self.etree.Element
        SubElement = self.etree.SubElement
        tostring = self.etree.tostring
        xml = _bytes('<a><b>B</b><c xmlns="test">C</c></a>')
        root = XML(xml)
        fragment = XML(_bytes("<other><tags/></other>"))

        result = self.etree.Element("{myns}root", att = "someval")

        def run_XML():
            thread_root = XML(xml)
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_parse():
            thread_root = self.etree.parse(BytesIO(xml)).getroot()
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_move_main():
            result.append(fragment[0])

        def run_build():
            result.append(
                Element("{myns}foo", attrib={'{test}attr':'val'}))
            SubElement(result, "{otherns}tasty")

        def run_xslt():
            style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy>
  </xsl:template>
</xsl:stylesheet>'''))
            st = etree.XSLT(style)
            result.append(st(root).getroot())

        for test in (run_XML, run_parse, run_move_main, run_xslt, run_build):
            # serialise in the main thread between the thread runs; the
            # output itself is not needed
            tostring(result)
            self._run_thread(test)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>'
                   '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>'
                   '<a><foo>B</foo></a>'
                   '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>'
                   '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'),
            tostring(result))

        def strip_first():
            root = Element("newroot")
            root.append(result[0])

        while len(result):
            self._run_thread(strip_first)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'),
            tostring(result))

    def test_concurrent_attribute_names_in_dicts(self):
        SubElement = self.etree.SubElement
        names = list('abcdefghijklmnop')
        runs_per_name = range(50)
        result_matches = re.compile(
            br'<thread_root>'
            br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+'
            br'</thread_root>').match

        def testrun():
            for _ in range(3):
                root = self.etree.Element('thread_root')
                for name in names:
                    tag_name = name * 5
                    new = []
                    for _ in runs_per_name:
                        el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'})
                        new.append(el)
                    for el in new:
                        el.set('thread_attr2_' + name, 'value2')
                s = etree.tostring(root)
                self.assertTrue(result_matches(s))

        # first, run only in sub-threads
        self._run_threads(10, testrun)

        # then, additionally include the main thread (and its parent dict)
        self._run_threads(10, testrun, main_func=testrun)

    def test_concurrent_proxies(self):
        XML = self.etree.XML
        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'))
        child_count = len(root)

        def testrun():
            for i in range(10000):
                el = root[i % child_count]
                del el

        self._run_threads(10, testrun)

    def test_concurrent_class_lookup(self):
        XML = self.etree.XML

        class TestElement(etree.ElementBase):
            pass

        class MyLookup(etree.CustomElementClassLookup):
            repeat = range(100)

            def lookup(self, t, d, ns, name):
                count = 0
                for i in self.repeat:
                    # allow other threads to run
                    count += 1
                return TestElement

        parser = self.etree.XMLParser()
        parser.set_element_class_lookup(MyLookup())

        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'),
                   parser)

        child_count = len(root)

        def testrun():
            for i in range(1000):
                el = root[i % child_count]
                del el

        self._run_threads(10, testrun)
419 -class ThreadPipelineTestCase(HelperTestCase):
420 """Threading tests based on a thread worker pipeline. 421 """ 422 etree = etree 423 item_count = 40 424
class Worker(threading.Thread):
    """One pipeline stage running in its own thread.

    The stage takes exactly ``in_count`` items from ``in_queue``, passes
    each one through ``handle()`` and puts the result on its own
    ``out_queue`` (bounded to ``in_count`` items), which the next stage
    can use as its input.  Any extra keyword arguments are stored as
    instance attributes.
    """

    def __init__(self, in_queue, in_count, **kwargs):
        threading.Thread.__init__(self)
        self.in_queue = in_queue
        self.in_count = in_count
        self.out_queue = Queue(in_count)
        # store extras directly in the instance dict (not via setattr)
        vars(self).update(kwargs)

    def run(self):
        fetch = self.in_queue.get
        emit = self.out_queue.put
        process = self.handle
        for _item_no in range(self.in_count):
            item = fetch()
            emit(process(item))

    def handle(self, data):
        """Transform one item; concrete stages must override this."""
        raise NotImplementedError()
class ParseWorker(Worker):
    """Pipeline stage that parses each incoming XML string into a root element."""

    def handle(self, xml, _fromstring=etree.fromstring):
        # the parser function is bound once as a default argument
        parsed_root = _fromstring(xml)
        return parsed_root
class RotateWorker(Worker):
    """Pipeline stage that moves the first child of each element to the end."""

    def handle(self, element):
        head = element[0]
        # re-assign all remaining children in place, then re-attach the
        # former first child as the last one
        element[:] = element[1:]
        element.append(head)
        return element
class ReverseWorker(Worker):
    """Pipeline stage that reverses the child order of each element in place."""

    def handle(self, element):
        reversed_children = element[::-1]
        element[:] = reversed_children
        return element
class ParseAndExtendWorker(Worker):
    """Pipeline stage that parses ``self.xml`` and extends each element with the result."""

    def handle(self, element, _fromstring=etree.fromstring):
        # ``self.xml`` is supplied as a keyword argument to the Worker constructor
        extra_root = _fromstring(self.xml)
        element.extend(extra_root)
        return element
class ParseAndInjectWorker(Worker):
    """Pipeline stage that parses ``self.xml`` and extends that new root with each element.

    The freshly parsed root, not the incoming element, is passed on.
    """

    def handle(self, element, _fromstring=etree.fromstring):
        fresh_root = _fromstring(self.xml)
        fresh_root.extend(element)
        return fresh_root
469 - class Validate(Worker):
470 - def handle(self, element):
473
class SerialiseWorker(Worker):
    """Final pipeline stage: serialise each element to bytes for comparison."""

    def handle(self, element):
        serialised = etree.tostring(element)
        return serialised

# Test document with an internal DTD; the repeated body is built by
# multiplying one chunk 20 times.
xml = (b'''\
<!DOCTYPE threadtest [
    <!ELEMENT threadtest (thread-tag1,thread-tag2)+>
    <!ATTLIST threadtest
        version    CDATA  "1.0"
    >
    <!ELEMENT thread-tag1 EMPTY>
    <!ELEMENT thread-tag2 (div)>
    <!ELEMENT div (threaded)>
    <!ATTLIST div
        huhu  CDATA  #IMPLIED
    >
    <!ELEMENT threaded EMPTY>
    <!ATTLIST threaded
        host  CDATA  #REQUIRED
    >
]>
<threadtest version="123">
''' + (b'''
  <thread-tag1 />
  <thread-tag2>
    <div huhu="true">
       <threaded host="here" />
    </div>
  </thread-tag2>
''') * 20 + b'''
</threadtest>''')
506 - def _build_pipeline(self, item_count, *classes, **kwargs):
507 in_queue = Queue(item_count) 508 start = last = classes[0](in_queue, item_count, **kwargs) 509 start.setDaemon(True) 510 for worker_class in classes[1:]: 511 last = worker_class(last.out_queue, item_count, **kwargs) 512 last.setDaemon(True) 513 last.start() 514 return in_queue, start, last
515
517 item_count = self.item_count 518 xml = self.xml.replace(b'thread', b'THREAD') # use fresh tag names 519 520 # build and start the pipeline 521 in_queue, start, last = self._build_pipeline( 522 item_count, 523 self.ParseWorker, 524 self.RotateWorker, 525 self.ReverseWorker, 526 self.ParseAndExtendWorker, 527 self.Validate, 528 self.ParseAndInjectWorker, 529 self.SerialiseWorker, 530 xml=xml) 531 532 # fill the queue 533 put = start.in_queue.put 534 for _ in range(item_count): 535 put(xml) 536 537 # start the first thread and thus everything 538 start.start() 539 # make sure the last thread has terminated 540 last.join(60) # time out after 60 seconds 541 self.assertEqual(item_count, last.out_queue.qsize()) 542 # read the results 543 get = last.out_queue.get 544 results = [get() for _ in range(item_count)] 545 546 comparison = results[0] 547 for i, result in enumerate(results[1:]): 548 self.assertEqual(comparison, result)
549
551 item_count = self.item_count 552 xml = self.xml.replace(b'thread', b'GLOBAL') # use fresh tag names 553 XML = self.etree.XML 554 # build and start the pipeline 555 in_queue, start, last = self._build_pipeline( 556 item_count, 557 self.RotateWorker, 558 self.ReverseWorker, 559 self.ParseAndExtendWorker, 560 self.Validate, 561 self.SerialiseWorker, 562 xml=xml) 563 564 # fill the queue 565 put = start.in_queue.put 566 for _ in range(item_count): 567 put(XML(xml)) 568 569 # start the first thread and thus everything 570 start.start() 571 # make sure the last thread has terminated 572 last.join(60) # time out after 90 seconds 573 self.assertEqual(item_count, last.out_queue.qsize()) 574 # read the results 575 get = last.out_queue.get 576 results = [get() for _ in range(item_count)] 577 578 comparison = results[0] 579 for i, result in enumerate(results[1:]): 580 self.assertEqual(comparison, result)
581 582
def test_suite():
    """Return a suite with all test cases of this module."""
    # unittest.makeSuite() is deprecated since Python 3.11 and removed in
    # 3.13; TestLoader.loadTestsFromTestCase() is the supported equivalent
    # (same "test" method prefix and ordering).
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    suite.addTests([loader.loadTestsFromTestCase(ThreadingTestCase)])
    suite.addTests([loader.loadTestsFromTestCase(ThreadPipelineTestCase)])
    return suite


if __name__ == '__main__':
    print('to test use test.py %s' % __file__)