Package lxml :: Package tests :: Module test_threading
[hide private]
[frames | no frames]

Source Code for Module lxml.tests.test_threading

  1  # -*- coding: utf-8 -*- 
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  import re 
  8  import sys 
  9  import os.path 
 10  import unittest 
 11  import threading 
 12   
 13  this_dir = os.path.dirname(__file__) 
 14  if this_dir not in sys.path: 
 15      sys.path.insert(0, this_dir) # needed for Py3 
 16   
 17  from common_imports import etree, HelperTestCase, BytesIO, _bytes 
 18   
 19  try: 
 20      from Queue import Queue 
 21  except ImportError: 
 22      from queue import Queue # Py3 
 23   
 24   
class ThreadingTestCase(HelperTestCase):
    """Threading tests"""
    etree = etree

    def _run_thread(self, func):
        """Run *func* in a new thread and wait for it to finish.

        An exception raised by *func* inside the thread is re-raised here,
        in the calling thread, once the thread has been joined.  Without
        this, a failing thread body is only reported by the threading
        excepthook, so the test fails indirectly at best (and the
        ``while len(result)`` loop in test_thread_mix would never end).
        """
        errors = []

        def target():
            try:
                func()
            except BaseException as exc:
                errors.append(exc)

        thread = threading.Thread(target=target)
        thread.start()
        thread.join()
        if errors:
            raise errors[0]

    def _run_threads(self, count, func, main_func=None):
        """Run *func* concurrently in *count* new threads.

        If *main_func* is given, it runs in the calling thread at the same
        time.  All participants wait on a shared event until every one of
        them has started, so the functions really overlap.  Afterwards,
        asserts that no participant failed and that every participant that
        started also finished.
        """
        sync = threading.Event()
        lock = threading.Lock()
        counter = dict(started=0, finished=0, failed=0)
        # number of participants that must check in before anyone proceeds
        participants = count + (main_func is not None)

        def sync_start(func):
            with lock:
                started = counter['started'] + 1
                counter['started'] = started
            # wait outside the lock, otherwise the others could not check in
            if started < participants:
                sync.wait(4)  # wait until the other threads have started up
                assert sync.is_set()
            sync.set()  # all waiting => go!
            try:
                func()
            except BaseException:
                with lock:
                    counter['failed'] += 1
                raise
            else:
                with lock:
                    counter['finished'] += 1

        threads = [threading.Thread(target=sync_start, args=(func,))
                   for _ in range(count)]
        for thread in threads:
            thread.start()
        try:
            if main_func is not None:
                sync_start(main_func)
        finally:
            # reap the worker threads even if main_func raised
            for thread in threads:
                thread.join()

        self.assertEqual(0, counter['failed'])
        self.assertEqual(counter['finished'], counter['started'])

    def test_subtree_copy_thread(self):
        tostring = self.etree.tostring
        XML = self.etree.XML
        xml = _bytes("<root><threadtag/></root>")
        main_root = XML(_bytes("<root/>"))

        def run_thread():
            thread_root = XML(xml)
            main_root.append(thread_root[0])
            del thread_root

        self._run_thread(run_thread)
        self.assertEqual(xml, tostring(main_root))

    def test_main_xslt_in_thread(self):
        XML = self.etree.XML
        style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
  </xsl:template>
</xsl:stylesheet>'''))
        st = self.etree.XSLT(style)

        result = []

        def run_thread():
            root = XML(_bytes('<a><b>B</b><c>C</c></a>'))
            result.append( st(root) )

        self._run_thread(run_thread)
        self.assertEqual('''\
<?xml version="1.0"?>
<foo><a>B</a></foo>
''',
                          str(result[0]))

    def test_thread_xslt(self):
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        def run_thread():
            style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
  </xsl:template>
</xsl:stylesheet>'''))
            st = self.etree.XSLT(style)
            root.append( st(root).getroot() )

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'),
                          tostring(root))

    def test_thread_xslt_attr_replace(self):
        # this is the only case in XSLT where the result tree can be
        # modified in-place
        XML = self.etree.XML
        tostring = self.etree.tostring
        style = self.etree.XSLT(XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <root class="abc">
      <xsl:copy-of select="@class" />
      <xsl:attribute name="class">xyz</xsl:attribute>
    </root>
  </xsl:template>
</xsl:stylesheet>''')))

        result = []
        def run_thread():
            root = XML(_bytes('<ROOT class="ABC" />'))
            result.append( style(root).getroot() )

        self._run_thread(run_thread)
        self.assertEqual(_bytes('<root class="xyz"/>'),
                          tostring(result[0]))

    def test_thread_create_xslt(self):
        XML = self.etree.XML
        tostring = self.etree.tostring
        root = XML(_bytes('<a><b>B</b><c>C</c></a>'))

        stylesheets = []

        def run_thread():
            style = XML(_bytes('''\
<xsl:stylesheet
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
    version="1.0">
    <xsl:output method="xml" />
    <xsl:template match="/">
        <div id="test">
            <xsl:apply-templates/>
        </div>
    </xsl:template>
</xsl:stylesheet>'''))
            stylesheets.append( self.etree.XSLT(style) )

        self._run_thread(run_thread)

        st = stylesheets[0]
        result = tostring( st(root) )

        self.assertEqual(_bytes('<div id="test">BC</div>'),
                          result)

    def test_thread_error_log(self):
        XML = self.etree.XML
        ParseError = self.etree.ParseError
        expected_error = [self.etree.ErrorTypes.ERR_TAG_NAME_MISMATCH]
        children = "<a>test</a>" * 100

        def parse_error_test(thread_no):
            tag = "tag%d" % thread_no
            xml = "<%s>%s</%s>" % (tag, children, tag.upper())
            parser = self.etree.XMLParser()
            for _ in range(10):
                errors = None
                try:
                    XML(xml, parser)
                except ParseError as e:
                    errors = e.error_log.filter_types(expected_error)
                self.assertTrue(errors, "Expected error not found")
                for error in errors:
                    self.assertTrue(
                        tag in error.message and tag.upper() in error.message,
                        "%s and %s not found in '%s'" % (
                            tag, tag.upper(), error.message))

        # Assertion failures raised inside worker threads would otherwise
        # be lost; collect them and re-raise in the main thread.
        failures = []

        def run_in_thread(thread_no):
            try:
                parse_error_test(thread_no)
            except BaseException as exc:
                failures.append(exc)

        self.etree.clear_error_log()
        threads = []
        for thread_no in range(1, 10):
            t = threading.Thread(target=run_in_thread,
                                 args=(thread_no,))
            threads.append(t)
            t.start()

        try:
            parse_error_test(0)
        finally:
            for t in threads:
                t.join()

        if failures:
            raise failures[0]

    def test_thread_mix(self):
        XML = self.etree.XML
        Element = self.etree.Element
        SubElement = self.etree.SubElement
        tostring = self.etree.tostring
        xml = _bytes('<a><b>B</b><c xmlns="test">C</c></a>')
        root = XML(xml)
        fragment = XML(_bytes("<other><tags/></other>"))

        result = self.etree.Element("{myns}root", att = "someval")

        def run_XML():
            thread_root = XML(xml)
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_parse():
            thread_root = self.etree.parse(BytesIO(xml)).getroot()
            result.append(thread_root[0])
            result.append(thread_root[-1])

        def run_move_main():
            result.append(fragment[0])

        def run_build():
            result.append(
                Element("{myns}foo", attrib={'{test}attr':'val'}))
            SubElement(result, "{otherns}tasty")

        def run_xslt():
            style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
    xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
  <xsl:template match="*">
    <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy>
  </xsl:template>
</xsl:stylesheet>'''))
            st = self.etree.XSLT(style)
            result.append( st(root).getroot() )

        for test in (run_XML, run_parse, run_move_main, run_xslt, run_build):
            tostring(result)
            self._run_thread(test)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>'
                   '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>'
                   '<a><foo>B</foo></a>'
                   '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>'
                   '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'),
            tostring(result))

        def strip_first():
            root = Element("newroot")
            root.append(result[0])

        # _run_thread re-raises failures, so this loop cannot spin forever
        while len(result):
            self._run_thread(strip_first)

        self.assertEqual(
            _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'),
            tostring(result))

    def test_concurrent_attribute_names_in_dicts(self):
        SubElement = self.etree.SubElement
        names = list('abcdefghijklmnop')
        runs_per_name = range(50)
        result_matches = re.compile(
            br'<thread_root>'
            br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+'
            br'</thread_root>').match

        def testrun():
            for _ in range(3):
                root = self.etree.Element('thread_root')
                for name in names:
                    tag_name = name * 5
                    new = []
                    for _ in runs_per_name:
                        el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'})
                        new.append(el)
                    for el in new:
                        el.set('thread_attr2_' + name, 'value2')
                s = self.etree.tostring(root)
                self.assertTrue(result_matches(s))

        # first, run only in sub-threads
        self._run_threads(10, testrun)

        # then, additionally include the main thread (and its parent dict)
        self._run_threads(10, testrun, main_func=testrun)

    def test_concurrent_proxies(self):
        XML = self.etree.XML
        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'))
        child_count = len(root)

        def testrun():
            for i in range(10000):
                el = root[i % child_count]
                del el

        self._run_threads(10, testrun)

    def test_concurrent_class_lookup(self):
        XML = self.etree.XML

        class TestElement(self.etree.ElementBase):
            pass

        class MyLookup(self.etree.CustomElementClassLookup):
            repeat = range(100)
            def lookup(self, t, d, ns, name):
                # deliberate busy loop so that other threads get a chance
                # to run while a lookup is in progress
                count = 0
                for i in self.repeat:
                    # allow other threads to run
                    count += 1
                return TestElement

        parser = self.etree.XMLParser()
        parser.set_element_class_lookup(MyLookup())

        root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'),
                   parser)

        child_count = len(root)
        def testrun():
            for i in range(1000):
                el = root[i % child_count]
                del el
        self._run_threads(10, testrun)
class ThreadPipelineTestCase(HelperTestCase):
    """Threading tests based on a thread worker pipeline.

    Items are handed from one worker thread to the next through bounded
    queues (see the nested Worker class).
    """
    etree = etree
    # how many items each pipeline test pushes through its workers
    item_count = 40
354 - class Worker(threading.Thread):
355 - def __init__(self, in_queue, in_count, **kwargs):
356 threading.Thread.__init__(self) 357 self.in_queue = in_queue 358 self.in_count = in_count 359 self.out_queue = Queue(in_count) 360 self.__dict__.update(kwargs)
361
362 - def run(self):
363 get, put = self.in_queue.get, self.out_queue.put 364 handle = self.handle 365 for _ in range(self.in_count): 366 put(handle(get()))
367
368 - def handle(self, data):
369 raise NotImplementedError()
370
371 - class ParseWorker(Worker):
372 - def handle(self, xml, _fromstring=etree.fromstring):
373 return _fromstring(xml)
374
375 - class RotateWorker(Worker):
376 - def handle(self, element):
377 first = element[0] 378 element[:] = element[1:] 379 element.append(first) 380 return element
381
382 - class ReverseWorker(Worker):
383 - def handle(self, element):
384 element[:] = element[::-1] 385 return element
386
387 - class ParseAndExtendWorker(Worker):
388 - def handle(self, element, _fromstring=etree.fromstring):
389 element.extend(_fromstring(self.xml)) 390 return element
391
392 - class ParseAndInjectWorker(Worker):
393 - def handle(self, element, _fromstring=etree.fromstring):
394 root = _fromstring(self.xml) 395 root.extend(element) 396 return root
397
398 - class Validate(Worker):
399 - def handle(self, element):
402
    class SerialiseWorker(Worker):
        """Pipeline stage: serialise each element with etree.tostring()."""
        def handle(self, element):
            return etree.tostring(element)

    # Input document for the pipeline tests: an internal DTD followed by
    # 20 copies of a thread-tag1/thread-tag2 pair.  The pipeline tests
    # replace b'thread' in it to obtain fresh tag names.
    xml = (b'''\
<!DOCTYPE threadtest [
    <!ELEMENT threadtest (thread-tag1,thread-tag2)+>
    <!ATTLIST threadtest
        version CDATA "1.0"
    >
    <!ELEMENT thread-tag1 EMPTY>
    <!ELEMENT thread-tag2 (div)>
    <!ELEMENT div (threaded)>
    <!ATTLIST div
        huhu CDATA #IMPLIED
    >
    <!ELEMENT threaded EMPTY>
    <!ATTLIST threaded
        host CDATA #REQUIRED
    >
]>
<threadtest version="123">
''' + (b'''
  <thread-tag1 />
  <thread-tag2>
    <div huhu="true">
      <threaded host="here" />
    </div>
  </thread-tag2>
''') * 20 + b'''
</threadtest>''')
435 - def _build_pipeline(self, item_count, *classes, **kwargs):
436 in_queue = Queue(item_count) 437 start = last = classes[0](in_queue, item_count, **kwargs) 438 start.setDaemon(True) 439 for worker_class in classes[1:]: 440 last = worker_class(last.out_queue, item_count, **kwargs) 441 last.setDaemon(True) 442 last.start() 443 return (in_queue, start, last)
444
446 item_count = self.item_count 447 xml = self.xml.replace(b'thread', b'THREAD') # use fresh tag names 448 449 # build and start the pipeline 450 in_queue, start, last = self._build_pipeline( 451 item_count, 452 self.ParseWorker, 453 self.RotateWorker, 454 self.ReverseWorker, 455 self.ParseAndExtendWorker, 456 self.Validate, 457 self.ParseAndInjectWorker, 458 self.SerialiseWorker, 459 xml=xml) 460 461 # fill the queue 462 put = start.in_queue.put 463 for _ in range(item_count): 464 put(xml) 465 466 # start the first thread and thus everything 467 start.start() 468 # make sure the last thread has terminated 469 last.join(60) # time out after 60 seconds 470 self.assertEqual(item_count, last.out_queue.qsize()) 471 # read the results 472 get = last.out_queue.get 473 results = [get() for _ in range(item_count)] 474 475 comparison = results[0] 476 for i, result in enumerate(results[1:]): 477 self.assertEqual(comparison, result)
478
480 item_count = self.item_count 481 xml = self.xml.replace(b'thread', b'GLOBAL') # use fresh tag names 482 XML = self.etree.XML 483 # build and start the pipeline 484 in_queue, start, last = self._build_pipeline( 485 item_count, 486 self.RotateWorker, 487 self.ReverseWorker, 488 self.ParseAndExtendWorker, 489 self.Validate, 490 self.SerialiseWorker, 491 xml=xml) 492 493 # fill the queue 494 put = start.in_queue.put 495 for _ in range(item_count): 496 put(XML(xml)) 497 498 # start the first thread and thus everything 499 start.start() 500 # make sure the last thread has terminated 501 last.join(60) # time out after 90 seconds 502 self.assertEqual(item_count, last.out_queue.qsize()) 503 # read the results 504 get = last.out_queue.get 505 results = [get() for _ in range(item_count)] 506 507 comparison = results[0] 508 for i, result in enumerate(results[1:]): 509 self.assertEqual(comparison, result)
510 511
def test_suite():
    """Return a suite containing all tests of this module."""
    suite = unittest.TestSuite()
    # unittest.makeSuite() was deprecated in Python 3.11 and removed in 3.13.
    # The default loader uses the same 'test' prefix and name-sorted order.
    loader = unittest.defaultTestLoader
    suite.addTests([loader.loadTestsFromTestCase(ThreadingTestCase)])
    suite.addTests([loader.loadTestsFromTestCase(ThreadPipelineTestCase)])
    return suite


if __name__ == '__main__':
    print('to test use test.py %s' % __file__)