Package lxml :: Package tests :: Module test_threading
[hide private]
[frames] | [no frames]

Source Code for Module lxml.tests.test_threading

  1  # -*- coding: utf-8 -*- 
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  import re 
  8  import sys 
  9  import os.path 
 10  import unittest 
 11  import threading 
 12   
 13  this_dir = os.path.dirname(__file__) 
 14  if this_dir not in sys.path: 
 15      sys.path.insert(0, this_dir) # needed for Py3 
 16   
 17  from common_imports import etree, HelperTestCase, BytesIO, _bytes 
 18   
 19  try: 
 20      from Queue import Queue 
 21  except ImportError: 
 22      from queue import Queue # Py3 
 23   
 24   
class ThreadingTestCase(HelperTestCase):
    """Threading tests for lxml.etree.

    All lxml entry points are accessed through the ``etree`` class
    attribute (``self.etree``).
    """
    etree = etree

    def _run_thread(self, func):
        """Run ``func()`` in one new thread and wait for it to finish.

        An exception raised by ``func`` inside the worker thread is
        re-raised here, in the calling thread.  A plain ``threading.Thread``
        only prints such exceptions, which would let a failing test pass
        (or, for callers that loop on a side effect of ``func``, hang).
        """
        errors = []

        def guarded():
            try:
                func()
            except BaseException as exc:
                errors.append(exc)

        thread = threading.Thread(target=guarded)
        thread.start()
        thread.join()
        if errors:
            raise errors[0]

    def _run_threads(self, count, func, main_func=None):
        """Run ``func()`` concurrently in ``count`` new threads.

        If ``main_func`` is given, it is additionally run in the calling
        thread.  Each participant waits (up to 4 seconds) until all
        participants have started before calling its function, so that the
        calls overlap as much as possible.  Afterwards, asserts that no call
        raised and that every participant that started also finished.
        """
        sync = threading.Event()
        lock = threading.Lock()
        counter = dict(started=0, finished=0, failed=0)
        participants = count + (main_func is not None)

        def sync_start(target):
            with lock:
                started = counter['started'] + 1
                counter['started'] = started
            if started < participants:
                sync.wait(4)  # wait until the other threads have started up
                assert sync.is_set()
            sync.set()  # all waiting => go!
            try:
                target()
            except BaseException:
                with lock:
                    counter['failed'] += 1
                raise
            else:
                with lock:
                    counter['finished'] += 1

        threads = [threading.Thread(target=sync_start, args=(func,))
                   for _ in range(count)]
        for thread in threads:
            thread.start()
        try:
            if main_func is not None:
                sync_start(main_func)
        finally:
            # Always reap the workers, even if main_func raised.
            for thread in threads:
                thread.join()

        self.assertEqual(0, counter['failed'])
        self.assertEqual(counter['finished'], counter['started'])

    def test_subtree_copy_thread(self):
        """A subtree parsed in a thread can be moved into a main-thread tree."""
        tostring = self.etree.tostring
        XML = self.etree.XML
        xml = _bytes("<root><threadtag/></root>")
        main_root = XML(_bytes("<root/>"))

        def run_thread():
            thread_root = XML(xml)
            main_root.append(thread_root[0])
            del thread_root

        self._run_thread(run_thread)
        self.assertEqual(xml, tostring(main_root))

    def test_main_xslt_in_thread(self):
        """An XSLT compiled in the main thread can be applied in another thread."""
        XML = self.etree.XML
        style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
  </xsl:template>
</xsl:stylesheet>'''))
        st = self.etree.XSLT(style)

        result = []

        def run_thread():
            root = XML(_bytes('<a><b>B</b><c>C</c></a>'))
            result.append(st(root))

        self._run_thread(run_thread)
        self.assertEqual('''\
<?xml version="1.0"?>
<foo><a>B</a></foo>
''',
                         str(result[0]))

    def test_thread_xslt(self):
        """An XSLT built in a thread can transform (and extend) a main-thread tree."""
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        def run_thread():
            style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
  </xsl:template>
</xsl:stylesheet>'''))
            st = self.etree.XSLT(style)
            root.append(st(root).getroot())

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'),
                         tostring(root))

    def test_thread_xslt_parsing_error_log(self):
        """Concurrent XSLT compile failures each yield a complete error log."""
        style = self.parse('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="tag" />
  <!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%x" />' % i for i in range(200)) + '''
  <xsl:UnExpectedElement />
</xsl:stylesheet>''')
        self.assertRaises(self.etree.XSLTParseError,
                          self.etree.XSLT, style)

        error_logs = []

        def run_thread():
            try:
                self.etree.XSLT(style)
            except self.etree.XSLTParseError as e:
                error_logs.append(e.error_log)
            else:
                self.fail("XSLT parsing should have failed but didn't")

        self._run_threads(16, run_thread)

        self.assertEqual(16, len(error_logs))
        last_log = None
        for log in error_logs:
            self.assertTrue(len(log))
            if last_log is not None:
                # every thread must see the same number of errors
                self.assertEqual(len(last_log), len(log))
            self.assertTrue(len(log) >= 2, len(log))
            for error in log:
                self.assertTrue(':ERROR:XSLT:' in str(error), str(error))
            self.assertTrue(any('UnExpectedElement' in str(error) for error in log), log)
            last_log = log

    def test_thread_xslt_apply_error_log(self):
        """Concurrent XSLT runtime failures each record exactly one error."""
        tree = self.parse('<tagFF/>')
        style = self.parse('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template name="tag0">
    <xsl:message terminate="yes">FAIL</xsl:message>
  </xsl:template>
  <!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%X" name="tag%x"> <xsl:call-template name="tag%x" /> </xsl:template>' % (i, i, i-1)
                for i in range(1, 256)) + '''
</xsl:stylesheet>''')
        self.assertRaises(self.etree.XSLTApplyError,
                          self.etree.XSLT(style), tree)

        error_logs = []

        def run_thread():
            transform = self.etree.XSLT(style)
            try:
                transform(tree)
            except self.etree.XSLTApplyError:
                error_logs.append(transform.error_log)
            else:
                self.fail("XSLT transformation should have failed but didn't")

        self._run_threads(16, run_thread)

        self.assertEqual(16, len(error_logs))
        last_log = None
        for log in error_logs:
            self.assertTrue(len(log))
            if last_log is not None:
                self.assertEqual(len(last_log), len(log))
            self.assertEqual(1, len(log))
            for error in log:
                self.assertTrue(':ERROR:XSLT:' in str(error))
            last_log = log

    def test_thread_xslt_attr_replace(self):
        """Replacing an attribute in the XSLT result tree works from a thread."""
        # this is the only case in XSLT where the result tree can be
        # modified in-place
        XML = self.etree.XML
        tostring = self.etree.tostring
        style = self.etree.XSLT(XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <root class="abc">
      <xsl:copy-of select="@class" />
      <xsl:attribute name="class">xyz</xsl:attribute>
    </root>
  </xsl:template>
</xsl:stylesheet>''')))

        result = []

        def run_thread():
            root = XML(_bytes('<ROOT class="ABC" />'))
            result.append(style(root).getroot())

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<root class="xyz"/>'),
                         tostring(result[0]))

    def test_thread_create_xslt(self):
        """An XSLT created in a thread can be used in the main thread."""
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        stylesheets = []

        def run_thread():
            style = XML(_bytes('''\
<xsl:stylesheet
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
    version="1.0">
  <xsl:output method="xml" />
  <xsl:template match="/">
    <div id="test">
      <xsl:apply-templates/>
    </div>
  </xsl:template>
</xsl:stylesheet>'''))
            stylesheets.append(self.etree.XSLT(style))

        self._run_thread(run_thread)

        st = stylesheets[0]
        result = tostring(st(root))

        self.assertEqual(_bytes('<div id="test">BC</div>'),
                         result)

    def test_thread_error_log(self):
        """Parse errors are reported to the thread that caused them.

        Ten parsers (nine worker threads plus the main thread) parse
        documents whose mismatched end tag names the thread; each error
        message must mention that thread's own tag.
        """
        XML = self.etree.XML
        expected_error = [self.etree.ErrorTypes.ERR_TAG_NAME_MISMATCH]
        children = "<a>test</a>" * 100

        def parse_error_test(thread_no):
            tag = "tag%d" % thread_no
            xml = "<%s>%s</%s>" % (tag, children, tag.upper())
            parser = self.etree.XMLParser()
            for _ in range(10):
                errors = None
                try:
                    XML(xml, parser)
                except self.etree.ParseError as e:
                    errors = e.error_log.filter_types(expected_error)
                self.assertTrue(errors, "Expected error not found")
                for error in errors:
                    self.assertTrue(
                        tag in error.message and tag.upper() in error.message,
                        "%s and %s not found in '%s'" % (
                            tag, tag.upper(), error.message))

        # Assertion failures in worker threads would otherwise only be
        # printed; collect them so the test actually fails.
        failures = []

        def checked_parse_error_test(thread_no):
            try:
                parse_error_test(thread_no)
            except BaseException as exc:
                failures.append(exc)

        self.etree.clear_error_log()
        threads = [threading.Thread(target=checked_parse_error_test,
                                    args=(thread_no,))
                   for thread_no in range(1, 10)]
        for t in threads:
            t.start()

        try:
            parse_error_test(0)
        finally:
            for t in threads:
                t.join()

        if failures:
            raise failures[0]

    def test_thread_mix(self):
        """Trees built/parsed/transformed in several threads can be merged.

        Each step runs in its own thread and appends to the shared
        ``result`` element; afterwards all children are moved away again,
        one thread per child.
        """
        XML = self.etree.XML
        Element = self.etree.Element
        SubElement = self.etree.SubElement
        tostring = self.etree.tostring
        xml = _bytes('<a><b>B</b><c xmlns="test">C</c></a>')
        root = XML(xml)
        fragment = XML(_bytes("<other><tags/></other>"))

        result = self.etree.Element("{myns}root", att="someval")

        def run_XML():
            thread_root = XML(xml)
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_parse():
            thread_root = self.etree.parse(BytesIO(xml)).getroot()
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_move_main():
            result.append(fragment[0])

        def run_build():
            result.append(
                Element("{myns}foo", attrib={'{test}attr': 'val'}))
            SubElement(result, "{otherns}tasty")

        def run_xslt():
            style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy>
  </xsl:template>
</xsl:stylesheet>'''))
            st = self.etree.XSLT(style)
            result.append(st(root).getroot())

        for test in (run_XML, run_parse, run_move_main, run_xslt, run_build):
            # serialise between steps; the return value is intentionally unused
            tostring(result)
            self._run_thread(test)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>'
                   '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>'
                   '<a><foo>B</foo></a>'
                   '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>'
                   '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'),
            tostring(result))

        def strip_first():
            root = Element("newroot")
            root.append(result[0])

        # _run_thread re-raises failures, so this loop cannot spin forever
        # on a strip_first that keeps failing.
        while len(result):
            self._run_thread(strip_first)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'),
            tostring(result))

    def test_concurrent_attribute_names_in_dicts(self):
        """Threads concurrently creating new tag/attribute names serialise correctly."""
        SubElement = self.etree.SubElement
        tostring = self.etree.tostring
        names = list('abcdefghijklmnop')
        runs_per_name = range(50)
        result_matches = re.compile(
            br'<thread_root>'
            br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+'
            br'</thread_root>').match

        def testrun():
            for _ in range(3):
                root = self.etree.Element('thread_root')
                for name in names:
                    tag_name = name * 5
                    new = []
                    for _ in runs_per_name:
                        el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'})
                        new.append(el)
                    for el in new:
                        el.set('thread_attr2_' + name, 'value2')
                s = tostring(root)
                self.assertTrue(result_matches(s))

        # first, run only in sub-threads
        self._run_threads(10, testrun)

        # then, additionally include the main thread (and its parent dict)
        self._run_threads(10, testrun, main_func=testrun)

    def test_concurrent_proxies(self):
        """Many threads concurrently create and drop proxies for the same children."""
        XML = self.etree.XML
        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'))
        child_count = len(root)

        def testrun():
            for i in range(10000):
                el = root[i % child_count]
                del el

        self._run_threads(10, testrun)

    def test_concurrent_class_lookup(self):
        """A slow custom element class lookup is safe under concurrent access."""
        XML = self.etree.XML

        class TestElement(self.etree.ElementBase):
            pass

        class MyLookup(self.etree.CustomElementClassLookup):
            repeat = range(100)

            def lookup(self, t, d, ns, name):
                count = 0
                for _ in self.repeat:
                    # allow other threads to run
                    count += 1
                return TestElement

        parser = self.etree.XMLParser()
        parser.set_element_class_lookup(MyLookup())

        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'),
                   parser)

        child_count = len(root)

        def testrun():
            for i in range(1000):
                el = root[i % child_count]
                del el

        self._run_threads(10, testrun)
422 -class ThreadPipelineTestCase(HelperTestCase):
423 """Threading tests based on a thread worker pipeline. 424 """ 425 etree = etree 426 item_count = 40 427
class Worker(threading.Thread):
    """Base class for a pipeline stage that runs in its own thread.

    The thread takes exactly ``in_count`` items from ``in_queue``, passes
    each one through :meth:`handle` and puts the result on ``out_queue``,
    a new ``Queue`` bounded to ``in_count`` entries.  Any extra keyword
    arguments are stored directly in the instance ``__dict__`` so that a
    stage can read them as attributes (e.g. ``self.xml``).
    """

    def __init__(self, in_queue, in_count, **kwargs):
        threading.Thread.__init__(self)
        self.in_count = in_count
        self.in_queue = in_queue
        self.out_queue = Queue(in_count)
        # Applied last, so keyword arguments can override any of the above.
        self.__dict__.update(kwargs)

    def run(self):
        receive = self.in_queue.get
        send = self.out_queue.put
        transform = self.handle
        for _ in range(self.in_count):
            item = receive()
            send(transform(item))

    def handle(self, data):
        """Transform one item; concrete stages must override this."""
        raise NotImplementedError()
444
class ParseWorker(Worker):
    """Pipeline stage that parses each incoming XML document into an element."""

    def handle(self, xml, _fromstring=etree.fromstring):
        # ``_fromstring`` is bound once, at definition time, as a default
        # argument (presumably to skip the ``etree`` lookup on every call).
        parsed = _fromstring(xml)
        return parsed
448
class RotateWorker(Worker):
    """Pipeline stage that moves the first child of each element to its end."""

    def handle(self, element):
        head = element[0]
        rest = element[1:]
        element[:] = rest
        element.append(head)
        return element
455
class ReverseWorker(Worker):
    """Pipeline stage that reverses the order of each element's children in place."""

    def handle(self, element):
        backwards = element[::-1]
        element[:] = backwards
        return element
460
class ParseAndExtendWorker(Worker):
    """Pipeline stage that appends the children of a freshly parsed ``self.xml``."""

    def handle(self, element, _fromstring=etree.fromstring):
        # ``self.xml`` comes from the keyword arguments given to the worker.
        fresh_root = _fromstring(self.xml)
        element.extend(fresh_root)
        return element
465
class ParseAndInjectWorker(Worker):
    """Pipeline stage that appends each element's children to a freshly parsed ``self.xml`` root.

    The new root, not the incoming element, is passed on.
    """

    def handle(self, element, _fromstring=etree.fromstring):
        new_root = _fromstring(self.xml)
        new_root.extend(element)
        return new_root
471
472 - class Validate(Worker):
473 - def handle(self, element):
476
class SerialiseWorker(Worker):
    """Final pipeline stage: serialise each element with ``etree.tostring``."""

    def handle(self, element):
        serialised = etree.tostring(element)
        return serialised

# Test document: a DTD-declared ``threadtest`` root followed by 20 copies of
# the thread-tag1/thread-tag2 group.  Tests rewrite the substring ``thread``
# in it to obtain fresh tag names.
xml = (b'''\
<!DOCTYPE threadtest [
<!ELEMENT threadtest (thread-tag1,thread-tag2)+>
<!ATTLIST threadtest
version CDATA "1.0"
>
<!ELEMENT thread-tag1 EMPTY>
<!ELEMENT thread-tag2 (div)>
<!ELEMENT div (threaded)>
<!ATTLIST div
huhu CDATA #IMPLIED
>
<!ELEMENT threaded EMPTY>
<!ATTLIST threaded
host CDATA #REQUIRED
>
]>
<threadtest version="123">
''' + (b'''
<thread-tag1 />
<thread-tag2>
<div huhu="true">
<threaded host="here" />
</div>
</thread-tag2>
''') * 20 + b'''
</threadtest>''')
def _build_pipeline(self, item_count, *classes, **kwargs):
    """Chain worker threads into a pipeline.

    ``classes[0]`` reads from a new ``Queue(item_count)``; each following
    class reads from the previous worker's ``out_queue``.  Every worker is
    constructed with ``item_count`` and ``**kwargs`` and marked as a daemon
    thread.  All workers except the first are started here; the caller
    fills ``in_queue`` and then starts the first one.

    Returns ``(in_queue, first_worker, last_worker)``.
    """
    in_queue = Queue(item_count)
    start = last = classes[0](in_queue, item_count, **kwargs)
    # ``Thread.setDaemon()`` is deprecated since Python 3.10; the ``daemon``
    # attribute has the same effect and exists since Python 2.6.
    start.daemon = True
    for worker_class in classes[1:]:
        last = worker_class(last.out_queue, item_count, **kwargs)
        last.daemon = True
        last.start()
    return in_queue, start, last
518
520 item_count = self.item_count 521 xml = self.xml.replace(b'thread', b'THREAD') # use fresh tag names 522 523 # build and start the pipeline 524 in_queue, start, last = self._build_pipeline( 525 item_count, 526 self.ParseWorker, 527 self.RotateWorker, 528 self.ReverseWorker, 529 self.ParseAndExtendWorker, 530 self.Validate, 531 self.ParseAndInjectWorker, 532 self.SerialiseWorker, 533 xml=xml) 534 535 # fill the queue 536 put = start.in_queue.put 537 for _ in range(item_count): 538 put(xml) 539 540 # start the first thread and thus everything 541 start.start() 542 # make sure the last thread has terminated 543 last.join(60) # time out after 60 seconds 544 self.assertEqual(item_count, last.out_queue.qsize()) 545 # read the results 546 get = last.out_queue.get 547 results = [get() for _ in range(item_count)] 548 549 comparison = results[0] 550 for i, result in enumerate(results[1:]): 551 self.assertEqual(comparison, result)
552
554 item_count = self.item_count 555 xml = self.xml.replace(b'thread', b'GLOBAL') # use fresh tag names 556 XML = self.etree.XML 557 # build and start the pipeline 558 in_queue, start, last = self._build_pipeline( 559 item_count, 560 self.RotateWorker, 561 self.ReverseWorker, 562 self.ParseAndExtendWorker, 563 self.Validate, 564 self.SerialiseWorker, 565 xml=xml) 566 567 # fill the queue 568 put = start.in_queue.put 569 for _ in range(item_count): 570 put(XML(xml)) 571 572 # start the first thread and thus everything 573 start.start() 574 # make sure the last thread has terminated 575 last.join(60) # time out after 90 seconds 576 self.assertEqual(item_count, last.out_queue.qsize()) 577 # read the results 578 get = last.out_queue.get 579 results = [get() for _ in range(item_count)] 580 581 comparison = results[0] 582 for i, result in enumerate(results[1:]): 583 self.assertEqual(comparison, result)
584 585
def test_suite():
    """Return a ``unittest.TestSuite`` holding all threading tests of this module."""
    # ``unittest.makeSuite`` is deprecated since Python 3.11 and removed in
    # 3.13; ``TestLoader.loadTestsFromTestCase`` is the equivalent.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    suite.addTest(loader.loadTestsFromTestCase(ThreadingTestCase))
    suite.addTest(loader.loadTestsFromTestCase(ThreadPipelineTestCase))
    return suite


if __name__ == '__main__':
    print('to test use test.py %s' % __file__)