Skip to content

Commit da79bc3

Browse files
1twodrei1twodreirlundeen2
authored
FEAT: Extend data exporter to support Markdown (Azure#1033) (Azure#1042)
Co-authored-by: 1twodrei <[email protected]> Co-authored-by: rlundeen2 <[email protected]> Co-authored-by: Richard Lundeen <[email protected]>
1 parent 2332258 commit da79bc3

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

pyrit/memory/memory_exporter.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(self):
2020
self.export_strategies = {
2121
"json": self.export_to_json,
2222
"csv": self.export_to_csv,
23+
"md": self.export_to_markdown,
2324
# Future formats can be added here
2425
}
2526

@@ -94,3 +95,27 @@ def export_to_csv(self, data: list[PromptRequestPiece], file_path: Path = None)
9495
writer = csv.DictWriter(f, fieldnames=fieldnames)
9596
writer.writeheader()
9697
writer.writerows(export_data)
98+
99+
def export_to_markdown(self, data: list[PromptRequestPiece], file_path: Path = None) -> None: # type: ignore
100+
"""
101+
Exports the provided data to a Markdown file at the specified file path.
102+
Each item in the data list is converted to a dictionary and formatted as a table.
103+
104+
Args:
105+
data (list[PromptRequestPiece]): The data to be exported, as a list of PromptRequestPiece instances.
106+
file_path (Path): The full path, including the file name, where the data will be exported.
107+
108+
Raises:
109+
ValueError: If no file_path is provided or if there is no data to export.
110+
"""
111+
if not file_path:
112+
raise ValueError("Please provide a valid file path for exporting data.")
113+
if not data:
114+
raise ValueError("No data to export.")
115+
export_data = [piece.to_dict() for piece in data]
116+
fieldnames = list(export_data[0].keys())
117+
with open(file_path, "w", newline="") as f:
118+
f.write(f"| {' | '.join(fieldnames)} |\n")
119+
f.write(f"| {' | '.join(['---'] * len(fieldnames))} |\n")
120+
for row in export_data:
121+
f.write(f"| {' | '.join(str(row[field]) for field in fieldnames)} |\n")

tests/unit/memory/test_memory_exporter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,13 @@ def test_export_to_json_data_with_conversations(tmp_path, export_type):
9090
assert content[1]["role"] == "assistant"
9191
assert content[1]["converted_value"] == "I'm fine, thank you!"
9292
assert content[1]["conversation_id"] == conversation_id
93+
94+
95+
@pytest.mark.parametrize("export_type", ["json", "csv", "md"])
96+
def test_export_data_creates_file(tmp_path, export_type):
97+
exporter = MemoryExporter()
98+
file_path = tmp_path / f"conversations.{export_type}"
99+
sample_conversation_entries = get_sample_conversations()
100+
exporter.export_data(data=sample_conversation_entries, file_path=file_path, export_type=export_type)
101+
102+
assert file_path.exists()

0 commit comments

Comments
 (0)