@@ -5584,3 +5584,87 @@ def test_mixed_sample_status(self):
55845584 expected = np .array ([[0 , 1 ]])
55855585 assert result .shape == (1 , 2 )
55865586 assert_array_equal (result , expected )
5587+
5588+
5589+ class TestSampleNodesByPloidy :
5590+ @pytest .mark .parametrize (
5591+ "n_samples,ploidy,expected" ,
5592+ [
5593+ (6 , 2 , np .array ([[0 , 1 ], [2 , 3 ], [4 , 5 ]])), # Basic diploid
5594+ (9 , 3 , np .array ([[0 , 1 , 2 ], [3 , 4 , 5 ], [6 , 7 , 8 ]])), # Triploid
5595+ (5 , 1 , np .array ([[0 ], [1 ], [2 ], [3 ], [4 ]])), # Ploidy of 1
5596+ (4 , 4 , np .array ([[0 , 1 , 2 , 3 ]])), # Ploidy equals number of samples
5597+ ],
5598+ )
5599+ def test_various_ploidy_scenarios (self , n_samples , ploidy , expected ):
5600+ tables = tskit .TableCollection (sequence_length = 100 )
5601+ for _ in range (n_samples ):
5602+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5603+ ts = tables .tree_sequence ()
5604+
5605+ result = ts .sample_nodes_by_ploidy (ploidy )
5606+ expected_shape = (n_samples // ploidy , ploidy )
5607+ assert result .shape == expected_shape
5608+ assert_array_equal (result , expected )
5609+
5610+ def test_mixed_sample_status (self ):
5611+ tables = tskit .TableCollection (sequence_length = 100 )
5612+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5613+ tables .nodes .add_row (flags = 0 , time = 0 )
5614+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5615+ tables .nodes .add_row (flags = 0 , time = 0 )
5616+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5617+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5618+ ts = tables .tree_sequence ()
5619+
5620+ result = ts .sample_nodes_by_ploidy (2 )
5621+ assert result .shape == (2 , 2 )
5622+ expected = np .array ([[0 , 2 ], [4 , 5 ]])
5623+ assert_array_equal (result , expected )
5624+
5625+ def test_no_sample_nodes (self ):
5626+ tables = tskit .TableCollection (sequence_length = 100 )
5627+ tables .nodes .add_row (flags = 0 , time = 0 )
5628+ tables .nodes .add_row (flags = 0 , time = 0 )
5629+ ts = tables .tree_sequence ()
5630+
5631+ with pytest .raises (ValueError , match = "No sample nodes in tree sequence" ):
5632+ ts .sample_nodes_by_ploidy (2 )
5633+
5634+ def test_not_multiple_of_ploidy (self ):
5635+ tables = tskit .TableCollection (sequence_length = 100 )
5636+ for _ in range (5 ):
5637+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5638+ ts = tables .tree_sequence ()
5639+
5640+ with pytest .raises (ValueError , match = "not a multiple of ploidy" ):
5641+ ts .sample_nodes_by_ploidy (2 )
5642+
5643+ def test_with_existing_individuals (self ):
5644+ tables = tskit .TableCollection (sequence_length = 100 )
5645+ tables .individuals .add_row (flags = 0 , location = (0 , 0 ), metadata = b"" )
5646+ tables .individuals .add_row (flags = 0 , location = (0 , 0 ), metadata = b"" )
5647+ # Add nodes with individual references but in a different order
5648+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 1 )
5649+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 0 )
5650+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 1 )
5651+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 0 )
5652+
5653+ ts = tables .tree_sequence ()
5654+ result = ts .sample_nodes_by_ploidy (2 )
5655+ expected = np .array ([[0 , 1 ], [2 , 3 ]])
5656+ assert_array_equal (result , expected )
5657+ ind_nodes = ts .individuals_nodes
5658+ assert not np .array_equal (result , ind_nodes )
5659+
5660+ def test_different_node_flags (self ):
5661+ tables = tskit .TableCollection (sequence_length = 100 )
5662+ OTHER_FLAG1 = 1 << 1
5663+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5664+ tables .nodes .add_row (flags = OTHER_FLAG1 , time = 0 )
5665+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE | OTHER_FLAG1 , time = 0 )
5666+ tables .nodes .add_row ()
5667+ ts = tables .tree_sequence ()
5668+ result = ts .sample_nodes_by_ploidy (2 )
5669+ assert result .shape == (1 , 2 )
5670+ assert_array_equal (result , np .array ([[0 , 2 ]]))
0 commit comments