diff --git a/junitparser/junitparser.py b/junitparser/junitparser.py index a91c6a0..c73f4cb 100644 --- a/junitparser/junitparser.py +++ b/junitparser/junitparser.py @@ -514,6 +514,7 @@ def __init__(self, name=None): super().__init__(self._tag) self.name = name self.filepath = None + self.root = JUnitXml def __iter__(self) -> Iterator[TestCase]: return itertools.chain( @@ -552,7 +553,7 @@ def __add__(self, other): result.update_statistics() else: # Create a new test result containing two testsuites - result = JUnitXml() + result = self.root() result.add_testsuite(self) result.add_testsuite(other) return result @@ -566,7 +567,7 @@ def __iadd__(self, other): self.update_statistics() return self - result = JUnitXml() + result = self.root() result.filepath = self.filepath result.add_testsuite(self) result.add_testsuite(other) @@ -653,8 +654,7 @@ def remove_property(self, property_: Property): def testsuites(self): """Iterate through all testsuites.""" - for suite in self.iterchildren(TestSuite): - yield suite + yield from self.iterchildren(type(self)) def write(self, file_or_filename: Optional[Union[str, IO]] = None, *, pretty: bool = False): write_xml(self, file_or_filename=file_or_filename, pretty=pretty) @@ -696,7 +696,7 @@ def __len__(self): return len(list(self.__iter__())) def __add__(self, other): - result = JUnitXml() + result = type(self)() for suite in self: result.add_testsuite(suite) for suite in other: @@ -708,7 +708,7 @@ def __iadd__(self, other): for suite in other: self.add_testsuite(suite) elif other._elem.tag == "testsuite": - suite = TestSuite(name=other.name) + suite = self.testsuite(name=other.name) for case in other: suite._add_testcase_no_update_stats(case) self.add_testsuite(suite) diff --git a/junitparser/xunit2.py b/junitparser/xunit2.py index 96f4ad6..4f60827 100644 --- a/junitparser/xunit2.py +++ b/junitparser/xunit2.py @@ -168,6 +168,10 @@ class TestSuite(junitparser.TestSuite): testcase = TestCase + def __init__(self, name=None): + super().__init__(name) + self.root = JUnitXml + @property def system_out(self): """""" diff --git a/tests/test_xunit2.py b/tests/test_xunit2.py index 6499a52..2bc6c74 100644 --- a/tests/test_xunit2.py +++ b/tests/test_xunit2.py @@ -162,6 +162,7 @@ def test_iterate_suite(self): suite = TestSuite("mySuite") suite.add_testsuite(TestSuite("suite1")) suite = next(suite.testsuites()) + assert isinstance(suite, TestSuite) assert suite.name == "suite1" def test_remove_case(self): @@ -171,6 +172,20 @@ def test_remove_case(self): suite.remove_testcase(test) assert len(suite) == 0 + def test_add_testsuites(self): + suite1 = TestSuite("suite1") + suite2 = TestSuite("suite2") + suites = suite1 + suite2 + assert isinstance(suites, JUnitXml) + assert len(list(iter(suites))) == 2 + + def test_iadd_testsuites(self): + suite1 = TestSuite("suite1") + suite2 = TestSuite("suite2") + suite1 += suite2 + assert isinstance(suite1, JUnitXml) + assert len(list(iter(suite1))) == 2 + class Test_JUnitXml: def test_init(self):