44# SPDX-License-Identifier: Apache-2.0
55#
66
7+ import os
78import io
89import json
910import logging
@@ -73,38 +74,28 @@ def forward(self, filename: Path, crate: Crate) -> dspy.Prediction:
7374 # Translate symbol by symbol
7475 translations : dict [str , str ] = OrderedDict ()
7576 for symbol_name , symbol_code in sources .items ():
76- ref_names = [
77- name for name in bfs (symbol_name , dependencies ) if name in translations
78- ]
7977 dep_names = [
8078 name for name in bfs (symbol_name , references , max_depth = 1 ) if name in sources
8179 ]
82- # move dependent names that has a translation to reference names
83- for dep_name in dep_names :
84- assert dep_name not in translations , (
85- f"Dependency { dep_name } should not already be translated"
86- )
87- for ref_name in bfs (dep_name , dependencies ):
88- if ref_name in translations and ref_name not in ref_names :
89- ref_names .append (ref_name )
9080
9181 # Gather reference and dependent code in order of translations and sources, respectively
92- ref_translations = "\n \n " .join (
93- [translation for name , translation in translations .items () if name in ref_names ]
94- )
82+ ref_translations = "\n \n " .join (translations .values ())
9583 dep_sources = "\n \n " .join (
9684 [source for name , source in sources .items () if name in dep_names ]
9785 )
9886
9987 logger .info (f"Translating `{ symbol_name } ` ..." )
10088 logger .debug (f"```c\n { symbol_code } \n ```" )
10189
90+ # FIXME: Pass a Symbol here instead of symbol_code and is_snippet_main. Similarly,
91+ # dep_sources should probably be a list[Symbol] too.
10292 pred = self .translate_with_feedback (
10393 ref_translations ,
10494 symbol_code ,
10595 dep_sources ,
10696 crate ,
10797 max_iters = self .max_iters ,
98+ is_snippet_main = symbol_name == "c:@F@main" ,
10899 )
109100 # pred = dspy.Prediction(translation=dspy.Code(code=""))
110101
@@ -115,16 +106,22 @@ def forward(self, filename: Path, crate: Crate) -> dspy.Prediction:
115106 # Update state
116107 translations [symbol_name ] = pred .translation .code
117108 with crate .rust_src_path .with_suffix (".jsonl" ).open ("a" ) as f :
118- f . write (
119- json .dumps (
109+ for prior_translation , feedback in zip ( pred . prior_translations , pred . feedbacks ):
110+ jsonl = json .dumps (
120111 {
121112 "name" : symbol_name ,
122- "source" : symbol_code ,
113+ "reference_names" : list (translations .keys ()),
114+ "reference_code" : ref_translations ,
115+ "snippet" : symbol_code ,
116+ "dependent_names" : dep_names ,
117+ "dependent_code" : dep_sources ,
118+ "prior_translation" : prior_translation ,
119+ "feedback" : feedback ,
123120 "translation" : pred .translation .code ,
121+ "success" : pred .success ,
124122 }
125123 )
126- + "\n "
127- )
124+ f .write (jsonl + "\n " )
128125
129126 translation = "\n \n " .join (translations .values ())
130127 return dspy .Prediction (translation = translation )
@@ -146,8 +143,6 @@ class TranslateSignature(dspy.Signature):
146143 Use the `cargo build` feedback about the prior_translation, if provided, when generating the Rust translation.
147144 """
148145
149- # For example, reason about how a Rust translation of the dependent_code would inform a safe and idiomatic translation of the C snippet.
150-
151146 reference_code : dspy .Code ["Rust" ] = dspy .InputField () # noqa: F821
152147 snippet : dspy .Code ["C" ] = dspy .InputField () # noqa: F821
153148 dependent_code : dspy .Code ["C" ] = dspy .InputField () # noqa: F821
@@ -185,46 +180,54 @@ def translate_with_feedback(
185180 crate : Crate ,
186181 * ,
187182 max_iters : int = 0 ,
183+ is_snippet_main : bool = False ,
188184 ) -> dspy .Prediction :
189185 pred = self .translate (reference_code , snippet , dependent_code )
190- i = 0
191- for i in range (max_iters ):
186+ success , prior_translations , feedbacks = False , [ "" ], [ "" ]
187+ for _ in range (max_iters ):
192188 rust_src = ""
193189 if len (reference_code ) > 0 :
194190 rust_src += reference_code + "\n \n "
195191 rust_src += pred .translation .code + "\n \n "
196- if crate .is_bin and "fn main()" not in rust_src :
192+ if crate .is_bin and not is_snippet_main :
197193 # Work around E0601 error
198194 rust_src += 'fn main() {\n println!("Hello, world!");\n }\n '
199195
200196 crate .rust_src_path .write_text (rust_src )
197+ env = os .environ .copy ()
198+ env ["RUSTFLAGS" ] = (env .get ("RUSTFLAGS" , "" ) + " -D unsafe-code" ).strip ()
201199 success , feedback = tools .run_subprocess (
202200 [
203201 "cargo" ,
204202 "build" ,
205203 "--quiet" ,
206204 "--color=never" ,
207205 f"--manifest-path={ crate .cargo_toml } " ,
208- ]
206+ ],
207+ env = env ,
209208 )
210209 if success :
211210 break
212211 logger .debug (
213212 f"Feedback\n ```rust\n { reference_code } \n { pred .translation .code } \n ```\n \n # Feedback\n { feedback } \n \n # reasoning\n { pred .reasoning } "
214213 )
215214
215+ feedbacks .append (feedback )
216+ prior_translations .append (pred .translation .code )
216217 pred = self .translate (
217218 reference_code ,
218219 snippet ,
219220 dependent_code ,
220- prior_translation = pred .translation ,
221+ prior_translation = pred .translation . code ,
221222 feedback = feedback ,
222223 )
223224 else :
224225 logger .warning (
225226 f"Translation failed to build after { max_iters } feedback iterations!"
226227 )
227- pred ["iters" ] = i
228+ pred ["feedbacks" ] = feedbacks
229+ pred ["prior_translations" ] = prior_translations
230+ pred ["success" ] = success
228231 return pred
229232
230233 def get_history (self , n : int = 1 , clear : bool = False ) -> str :
0 commit comments