Package lxml :: Package tests :: Module test_threading
[hide private]
[frames] | no frames]

Source Code for Module lxml.tests.test_threading

  1  # -*- coding: utf-8 -*- 
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  import re 
  8  import sys 
  9  import os.path 
 10  import unittest 
 11  import threading 
 12   
 13  this_dir = os.path.dirname(__file__) 
 14  if this_dir not in sys.path: 
 15      sys.path.insert(0, this_dir) # needed for Py3 
 16   
 17  from common_imports import etree, HelperTestCase, BytesIO, _bytes 
 18   
 19  try: 
 20      from Queue import Queue 
 21  except ImportError: 
 22      from queue import Queue # Py3 
 23   
 24   
class ThreadingTestCase(HelperTestCase):
    """Threading tests for lxml.etree.

    The tests parse, modify, serialise and transform documents from threads
    other than the main thread, and from several threads concurrently.
    """
    etree = etree

    def _run_thread(self, func):
        """Run *func* in a new thread and wait for it to finish.

        An exception raised by *func* is re-raised in the calling thread.
        Otherwise the threading machinery would only print it, and the
        failure could go unnoticed or make a caller loop forever.
        """
        errors = []

        def run():
            try:
                func()
            except BaseException as exc:
                errors.append(exc)

        thread = threading.Thread(target=func if False else run)
        thread.start()
        thread.join()
        if errors:
            raise errors[0]

    def _run_threads(self, count, func, main_func=None):
        """Run *func* concurrently in *count* new threads.

        If *main_func* is given, it also runs in the calling thread at the
        same time.  Each participant first waits (up to 4 seconds) until all
        of them have started, so that the actual work overlaps.  The test
        fails if any participant raised or did not finish.
        """
        sync = threading.Event()
        lock = threading.Lock()
        counter = dict(started=0, finished=0, failed=0)
        participants = count + (main_func is not None)

        def sync_start(target):
            with lock:
                started = counter['started'] + 1
                counter['started'] = started
            if started < participants:
                # wait until the other threads have started up
                sync.wait(4)
                if not sync.is_set():
                    # explicit check instead of 'assert' (stripped under -O)
                    raise AssertionError(
                        "timed out waiting for the other threads to start")
            sync.set()  # all waiting => go!
            try:
                target()
            except BaseException:
                with lock:
                    counter['failed'] += 1
                raise
            else:
                with lock:
                    counter['finished'] += 1

        threads = [threading.Thread(target=sync_start, args=(func,))
                   for _ in range(count)]
        for thread in threads:
            thread.start()
        try:
            if main_func is not None:
                sync_start(main_func)
        finally:
            # Reap the workers even if the main-thread part failed, so that
            # no thread outlives this test.
            for thread in threads:
                thread.join()

        self.assertEqual(0, counter['failed'])
        self.assertEqual(counter['finished'], counter['started'])

    def _assert_xslt_error_logs(self, error_logs, thread_count,
                                entries_per_log):
        """Check that each of *thread_count* threads collected an XSLT error
        log with exactly *entries_per_log* XSLT error entries."""
        self.assertEqual(thread_count, len(error_logs))
        for log in error_logs:
            self.assertEqual(entries_per_log, len(log))
            for error in log:
                self.assertTrue(':ERROR:XSLT:' in str(error), str(error))

    def test_subtree_copy_thread(self):
        tostring = self.etree.tostring
        XML = self.etree.XML
        xml = _bytes("<root><threadtag/></root>")
        main_root = XML(_bytes("<root/>"))

        def run_thread():
            thread_root = XML(xml)
            # move an element created in the thread into the main tree
            main_root.append(thread_root[0])
            del thread_root

        self._run_thread(run_thread)
        self.assertEqual(xml, tostring(main_root))

    def test_main_xslt_in_thread(self):
        XML = self.etree.XML
        style = XML(_bytes('''\
    <xsl:stylesheet version="1.0"
        xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
      <xsl:template match="*">
        <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
      </xsl:template>
    </xsl:stylesheet>'''))
        st = self.etree.XSLT(style)

        result = []

        def run_thread():
            root = XML(_bytes('<a><b>B</b><c>C</c></a>'))
            result.append(st(root))

        self._run_thread(run_thread)
        self.assertEqual('''\
<?xml version="1.0"?>
<foo><a>B</a></foo>
''',
                          str(result[0]))

    def test_thread_xslt(self):
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        def run_thread():
            style = XML(_bytes('''\
    <xsl:stylesheet version="1.0"
        xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
      <xsl:template match="*">
        <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
      </xsl:template>
    </xsl:stylesheet>'''))
            st = self.etree.XSLT(style)
            root.append(st(root).getroot())

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'),
                         tostring(root))

    def test_thread_xslt_parsing_error_log(self):
        style = self.parse('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="tag" />
  <!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%x" />' % i for i in range(200)) + '''
  <xsl:foo />
</xsl:stylesheet>''')
        self.assertRaises(self.etree.XSLTParseError,
                          self.etree.XSLT, style)

        error_logs = []

        def run_thread():
            try:
                self.etree.XSLT(style)
            except self.etree.XSLTParseError as e:
                error_logs.append(e.error_log)
            else:
                self.fail("XSLT parsing should have failed but didn't")

        self._run_threads(16, run_thread)
        self._assert_xslt_error_logs(error_logs, 16, 4)

    def test_thread_xslt_apply_error_log(self):
        tree = self.parse('<tagFF/>')
        style = self.parse('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template name="tag0">
    <xsl:message terminate="yes">FAIL</xsl:message>
  </xsl:template>
  <!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%X" name="tag%x"> <xsl:call-template name="tag%x" /> </xsl:template>' % (i, i, i-1)
                for i in range(1, 256)) + '''
</xsl:stylesheet>''')
        self.assertRaises(self.etree.XSLTApplyError,
                          self.etree.XSLT(style), tree)

        error_logs = []

        def run_thread():
            transform = self.etree.XSLT(style)
            try:
                transform(tree)
            except self.etree.XSLTApplyError:
                error_logs.append(transform.error_log)
            else:
                self.fail("XSLT transformation should have failed but didn't")

        self._run_threads(16, run_thread)
        self._assert_xslt_error_logs(error_logs, 16, 1)

    def test_thread_xslt_attr_replace(self):
        # this is the only case in XSLT where the result tree can be
        # modified in-place
        XML = self.etree.XML
        tostring = self.etree.tostring
        style = self.etree.XSLT(XML(_bytes('''\
    <xsl:stylesheet version="1.0"
        xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
      <xsl:template match="*">
        <root class="abc">
          <xsl:copy-of select="@class" />
          <xsl:attribute name="class">xyz</xsl:attribute>
        </root>
      </xsl:template>
    </xsl:stylesheet>''')))

        result = []

        def run_thread():
            root = XML(_bytes('<ROOT class="ABC" />'))
            result.append(style(root).getroot())

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<root class="xyz"/>'),
                         tostring(result[0]))

    def test_thread_create_xslt(self):
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        stylesheets = []

        def run_thread():
            style = XML(_bytes('''\
    <xsl:stylesheet
        xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
        version="1.0">
      <xsl:output method="xml" />
      <xsl:template match="/">
         <div id="test">
           <xsl:apply-templates/>
         </div>
      </xsl:template>
    </xsl:stylesheet>'''))
            stylesheets.append(self.etree.XSLT(style))

        self._run_thread(run_thread)

        # the stylesheet was compiled in the thread but is applied here
        st = stylesheets[0]
        result = tostring(st(root))

        self.assertEqual(_bytes('<div id="test">BC</div>'),
                         result)

    def test_thread_error_log(self):
        XML = self.etree.XML
        expected_error = [self.etree.ErrorTypes.ERR_TAG_NAME_MISMATCH]
        children = "<a>test</a>" * 100

        def parse_error_test(thread_no):
            # every thread uses its own tag name so that its error messages
            # can be told apart from those of the other threads
            tag = "tag%d" % thread_no
            xml = "<%s>%s</%s>" % (tag, children, tag.upper())
            parser = self.etree.XMLParser()
            for _ in range(10):
                errors = None
                try:
                    XML(xml, parser)
                except self.etree.ParseError as e:
                    errors = e.error_log.filter_types(expected_error)
                self.assertTrue(errors, "Expected error not found")
                for error in errors:
                    self.assertTrue(
                        tag in error.message and tag.upper() in error.message,
                        "%s and %s not found in '%s'" % (
                            tag, tag.upper(), error.message))

        failures = []

        def run_in_thread(thread_no):
            # An exception in a worker thread would otherwise only be
            # printed; record it so that the main thread fails the test.
            try:
                parse_error_test(thread_no)
            except Exception as exc:
                failures.append(exc)

        self.etree.clear_error_log()
        threads = []
        for thread_no in range(1, 10):
            t = threading.Thread(target=run_in_thread, args=(thread_no,))
            threads.append(t)
            t.start()

        try:
            parse_error_test(0)
        finally:
            for t in threads:
                t.join()

        self.assertEqual([], failures)

    def test_thread_mix(self):
        XML = self.etree.XML
        Element = self.etree.Element
        SubElement = self.etree.SubElement
        tostring = self.etree.tostring
        xml = _bytes('<a><b>B</b><c xmlns="test">C</c></a>')
        root = XML(xml)
        fragment = XML(_bytes("<other><tags/></other>"))

        result = self.etree.Element("{myns}root", att="someval")

        def run_XML():
            thread_root = XML(xml)
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_parse():
            thread_root = self.etree.parse(BytesIO(xml)).getroot()
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_move_main():
            result.append(fragment[0])

        def run_build():
            result.append(
                Element("{myns}foo", attrib={'{test}attr': 'val'}))
            SubElement(result, "{otherns}tasty")

        def run_xslt():
            style = XML(_bytes('''\
    <xsl:stylesheet version="1.0"
        xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
      <xsl:template match="*">
        <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy>
      </xsl:template>
    </xsl:stylesheet>'''))
            st = self.etree.XSLT(style)
            result.append(st(root).getroot())

        for thread_func in (run_XML, run_parse, run_move_main, run_xslt,
                            run_build):
            # serialise in the main thread between the threaded steps
            # (result discarded; presumably done just to touch the tree here)
            tostring(result)
            self._run_thread(thread_func)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>'
                   '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>'
                   '<a><foo>B</foo></a>'
                   '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>'
                   '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'),
            tostring(result))

        def strip_first():
            root = Element("newroot")
            root.append(result[0])

        # _run_thread re-raises worker exceptions, so a failing strip_first
        # ends this loop instead of spinning forever
        while len(result):
            self._run_thread(strip_first)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'),
            tostring(result))

    def test_concurrent_attribute_names_in_dicts(self):
        SubElement = self.etree.SubElement
        names = list('abcdefghijklmnop')
        runs_per_name = range(50)
        result_matches = re.compile(
            br'<thread_root>'
            br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+'
            br'</thread_root>').match

        def testrun():
            for _ in range(3):
                root = self.etree.Element('thread_root')
                for name in names:
                    tag_name = name * 5
                    new = []
                    for _ in runs_per_name:
                        el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'})
                        new.append(el)
                    for el in new:
                        el.set('thread_attr2_' + name, 'value2')
                s = self.etree.tostring(root)
                self.assertTrue(result_matches(s))

        # first, run only in sub-threads
        self._run_threads(10, testrun)

        # then, additionally include the main thread (and its parent dict)
        self._run_threads(10, testrun, main_func=testrun)

    def test_concurrent_proxies(self):
        XML = self.etree.XML
        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'))
        child_count = len(root)

        def testrun():
            for i in range(10000):
                # presumably meant to create and drop element proxies
                # concurrently from many threads
                el = root[i % child_count]
                del el

        self._run_threads(10, testrun)

    def test_concurrent_class_lookup(self):
        XML = self.etree.XML

        class TestElement(self.etree.ElementBase):
            pass

        class MyLookup(self.etree.CustomElementClassLookup):
            repeat = range(100)

            def lookup(self, t, d, ns, name):
                count = 0
                for _ in self.repeat:
                    # allow other threads to run
                    count += 1
                return TestElement

        parser = self.etree.XMLParser()
        parser.set_element_class_lookup(MyLookup())

        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'),
                   parser)

        child_count = len(root)

        def testrun():
            for i in range(1000):
                el = root[i % child_count]
                del el

        self._run_threads(10, testrun)
421 -class ThreadPipelineTestCase(HelperTestCase):
422 """Threading tests based on a thread worker pipeline. 423 """ 424 etree = etree 425 item_count = 40 426
class Worker(threading.Thread):
    """Base pipeline stage used by the ThreadPipelineTestCase tests.

    The stage reads exactly *in_count* items from *in_queue*, passes each one
    through :meth:`handle`, and puts the result on ``self.out_queue``.  Any
    extra keyword arguments are stored as instance attributes.
    """
    def __init__(self, in_queue, in_count, **kwargs):
        threading.Thread.__init__(self)
        self.in_queue = in_queue
        self.in_count = in_count
        # bounded to the number of items this stage will ever produce
        self.out_queue = Queue(in_count)
        vars(self).update(kwargs)

    def run(self):
        take = self.in_queue.get
        emit = self.out_queue.put
        process = self.handle
        for _ in range(self.in_count):
            item = take()
            emit(process(item))

    def handle(self, data):
        """Transform one item; subclasses must override this."""
        raise NotImplementedError()
class ParseWorker(Worker):
    """Pipeline stage that parses each incoming XML document."""
    def handle(self, xml, _fromstring=etree.fromstring):
        parsed = _fromstring(xml)
        return parsed
class RotateWorker(Worker):
    """Pipeline stage that moves the first child of each element to its end."""
    def handle(self, element):
        # element[0] is read before element[1:], as in a plain sequence rotate
        head, tail = element[0], element[1:]
        element[:] = tail
        element.append(head)
        return element
class ReverseWorker(Worker):
    """Pipeline stage that reverses the order of each element's children."""
    def handle(self, element):
        reversed_children = element[::-1]
        element[:] = reversed_children
        return element
class ParseAndExtendWorker(Worker):
    """Pipeline stage that parses ``self.xml`` and appends the parsed root's
    children to each incoming element."""
    def handle(self, element, _fromstring=etree.fromstring):
        extra = _fromstring(self.xml)
        element.extend(extra)
        return element
class ParseAndInjectWorker(Worker):
    """Pipeline stage that parses ``self.xml`` and moves the children of each
    incoming element into that freshly parsed root, which it returns."""
    def handle(self, element, _fromstring=etree.fromstring):
        new_root = _fromstring(self.xml)
        new_root.extend(element)
        return new_root
471 - class Validate(Worker):
472 - def handle(self, element):
475
class SerialiseWorker(Worker):
    """Final pipeline stage that serialises each element to bytes."""
    def handle(self, element):
        serialised = etree.tostring(element)
        return serialised
479 480 xml = (b'''\ 481 <!DOCTYPE threadtest [ 482 <!ELEMENT threadtest (thread-tag1,thread-tag2)+> 483 <!ATTLIST threadtest 484 version CDATA "1.0" 485 > 486 <!ELEMENT thread-tag1 EMPTY> 487 <!ELEMENT thread-tag2 (div)> 488 <!ELEMENT div (threaded)> 489 <!ATTLIST div 490 huhu CDATA #IMPLIED 491 > 492 <!ELEMENT threaded EMPTY> 493 <!ATTLIST threaded 494 host CDATA #REQUIRED 495 > 496 ]> 497 <threadtest version="123"> 498 ''' + (b''' 499 <thread-tag1 /> 500 <thread-tag2> 501 <div huhu="true"> 502 <threaded host="here" /> 503 </div> 504 </thread-tag2> 505 ''') * 20 + b''' 506 </threadtest>''') 507
508 - def _build_pipeline(self, item_count, *classes, **kwargs):
509 in_queue = Queue(item_count) 510 start = last = classes[0](in_queue, item_count, **kwargs) 511 start.setDaemon(True) 512 for worker_class in classes[1:]: 513 last = worker_class(last.out_queue, item_count, **kwargs) 514 last.setDaemon(True) 515 last.start() 516 return (in_queue, start, last)
517
519 item_count = self.item_count 520 xml = self.xml.replace(b'thread', b'THREAD') # use fresh tag names 521 522 # build and start the pipeline 523 in_queue, start, last = self._build_pipeline( 524 item_count, 525 self.ParseWorker, 526 self.RotateWorker, 527 self.ReverseWorker, 528 self.ParseAndExtendWorker, 529 self.Validate, 530 self.ParseAndInjectWorker, 531 self.SerialiseWorker, 532 xml=xml) 533 534 # fill the queue 535 put = start.in_queue.put 536 for _ in range(item_count): 537 put(xml) 538 539 # start the first thread and thus everything 540 start.start() 541 # make sure the last thread has terminated 542 last.join(60) # time out after 60 seconds 543 self.assertEqual(item_count, last.out_queue.qsize()) 544 # read the results 545 get = last.out_queue.get 546 results = [get() for _ in range(item_count)] 547 548 comparison = results[0] 549 for i, result in enumerate(results[1:]): 550 self.assertEqual(comparison, result)
551
553 item_count = self.item_count 554 xml = self.xml.replace(b'thread', b'GLOBAL') # use fresh tag names 555 XML = self.etree.XML 556 # build and start the pipeline 557 in_queue, start, last = self._build_pipeline( 558 item_count, 559 self.RotateWorker, 560 self.ReverseWorker, 561 self.ParseAndExtendWorker, 562 self.Validate, 563 self.SerialiseWorker, 564 xml=xml) 565 566 # fill the queue 567 put = start.in_queue.put 568 for _ in range(item_count): 569 put(XML(xml)) 570 571 # start the first thread and thus everything 572 start.start() 573 # make sure the last thread has terminated 574 last.join(60) # time out after 90 seconds 575 self.assertEqual(item_count, last.out_queue.qsize()) 576 # read the results 577 get = last.out_queue.get 578 results = [get() for _ in range(item_count)] 579 580 comparison = results[0] 581 for i, result in enumerate(results[1:]): 582 self.assertEqual(comparison, result)
583 584
def test_suite():
    """Return a TestSuite with all tests of ThreadingTestCase and
    ThreadPipelineTestCase.

    Uses TestLoader.loadTestsFromTestCase instead of unittest.makeSuite,
    which is deprecated and was removed in Python 3.13.
    """
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    suite.addTests([loader.loadTestsFromTestCase(ThreadingTestCase)])
    suite.addTests([loader.loadTestsFromTestCase(ThreadPipelineTestCase)])
    return suite


if __name__ == '__main__':
    print('to test use test.py %s' % __file__)