|
1 | 1 | import os |
| 2 | +import re |
2 | 3 | from setuptools import setup, Distribution |
3 | 4 | import shutil |
4 | 5 |
|
@@ -32,31 +33,10 @@ for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")): |
32 | 33 | ''') |
33 | 34 | f.write('\n\n'.join(api_content)) |
34 | 35 |
|
35 | | -def write_version_py(filename='python/paddle_custom_device/intel_hpu/__init__.py', source_path='python/paddle_custom_device/intel_hpu'): |
36 | | - dirname = os.path.dirname(filename) |
37 | | - if not os.path.exists(dirname): |
38 | | - os.makedirs(dirname) |
39 | | - |
40 | | - # to get all the .py file under source_path (exclude __init__.py file) |
41 | | - py_files = [] |
42 | | - if os.path.exists(source_path): |
43 | | - for file in os.listdir(source_path): |
44 | | - if file.endswith('.py') and file != '__init__.py': |
45 | | - py_files.append(file[:-3]) # to remove the .py |
46 | | - |
47 | | - # generate the import line |
48 | | - import_statements = [] |
49 | | - for module in sorted(py_files): # to do sort |
50 | | - import_statements.append(f"from .{module} import *") |
51 | 36 |
|
52 | | - cnt = '''# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY |
53 | | -# |
54 | | -''' |
55 | | - #to add all import lines |
56 | | - cnt += '\n'.join(import_statements) + '\n\n' |
57 | | - |
58 | | - #to add version info and function |
59 | | - cnt += '''full_version = '@PADDLE_VERSION@' |
| 37 | +# Common version information template |
| 38 | +VERSION_TEMPLATE = ''' |
| 39 | +full_version = '@PADDLE_VERSION@' |
60 | 40 | synapse_version = @SYNAPSE_VERSION@ |
61 | 41 | git_commit_id = '@GIT_HASH@' |
62 | 42 | custom_op_git_commit_id = '@CUSTOM_OP_GIT_HASH@' |
@@ -88,10 +68,90 @@ def version(): |
88 | 68 | return {'version': full_version, 'commit': git_commit_id, 'custom_op commit': custom_op_git_commit_id, 'synapse': synapse_version} |
89 | 69 | ''' |
90 | 70 |
|
| 71 | +# Common header template |
| 72 | +HEADER_TEMPLATE = '''# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY |
| 73 | +# |
| 74 | +from .ops import * |
| 75 | +''' |
| 76 | +def write_default_version_py(filename='python/paddle_custom_device/intel_hpu/__init__.py'): |
| 77 | + dirname = os.path.dirname(filename) |
| 78 | + if not os.path.exists(dirname): |
| 79 | + os.makedirs(dirname) |
| 80 | + cnt = HEADER_TEMPLATE + VERSION_TEMPLATE |
| 81 | + |
91 | 82 | with open(filename, 'w') as f: |
92 | 83 | f.write(cnt) |
| 84 | + print(f"Generated default version file: {filename}") |
| 85 | + |
| 86 | +def write_version_py(filename='python/paddle_custom_device/intel_hpu/__init__.py'): |
| 87 | + """ |
| 88 | + Rewrite __init__.py file with the following modifications: |
| 89 | + 1. Remove lines containing 'paddle_custom_device.intel_hpu.ops import' |
| 90 | + 2. Insert generated header before the first 'from xxx import' statement |
| 91 | + 3. Add version information and functions at the end of the file |
| 92 | + """ |
| 93 | + |
| 94 | + if not os.path.exists(filename): |
| 95 | + print(f"File {filename} does not exist, to use defual write_version") |
| 96 | + write_default_version_py() |
| 97 | + return |
| 98 | + |
| 99 | + # Read original file content |
| 100 | + with open(filename, 'r', encoding='utf-8') as f: |
| 101 | + lines = f.readlines() |
| 102 | + |
| 103 | + # Filter out lines containing paddle_custom_device.intel_hpu.ops import |
| 104 | + filtered_lines = [] |
| 105 | + for line in lines: |
| 106 | + if 'paddle_custom_device.intel_hpu.ops import' not in line: |
| 107 | + filtered_lines.append(line) |
| 108 | + |
| 109 | + # Find the position of the first 'from xxx import' statement |
| 110 | + first_import_index = -1 |
| 111 | + for i, line in enumerate(filtered_lines): |
| 112 | + stripped_line = line.strip() |
| 113 | + # Match patterns like 'from .xxx import *' or 'from xxx import *' |
| 114 | + if re.match(r'^from\s+(\.\w+|\w+)\s+import\s+\*', stripped_line): |
| 115 | + first_import_index = i |
| 116 | + break |
| 117 | + |
| 118 | + # Prepare header to be inserted |
| 119 | + header = HEADER_TEMPLATE |
| 120 | + |
| 121 | + # Build new file content |
| 122 | + new_content = [] |
| 123 | + |
| 124 | + # Add copyright header and other content before the first import statement |
| 125 | + if first_import_index != -1: |
| 126 | + # Add all lines from beginning to just before the first import |
| 127 | + new_content.extend(filtered_lines[:first_import_index]) |
| 128 | + # Insert generated header before the first import |
| 129 | + new_content.append(header) |
| 130 | + # Add remaining lines (including the first import) |
| 131 | + new_content.extend(filtered_lines[first_import_index:]) |
| 132 | + else: |
| 133 | + # If no import statements found, add all filtered lines and append header at the end |
| 134 | + new_content.extend(filtered_lines) |
| 135 | + new_content.append('\n' + header) |
| 136 | + |
| 137 | + # Add version information at the end of file |
| 138 | + footer = VERSION_TEMPLATE |
| 139 | + |
| 140 | + new_content.append(footer) |
| 141 | + |
| 142 | + # Write modified content back to file |
| 143 | + with open(filename, 'w', encoding='utf-8') as f: |
| 144 | + f.writelines(new_content) |
| 145 | + |
| 146 | + print(f"Successfully rewritten {filename}") |
| 147 | + |
| 148 | + # Display processing results |
| 149 | + print("\nProcessed content preview:") |
| 150 | + with open(filename, 'r', encoding='utf-8') as f: |
| 151 | + content = f.read() |
| 152 | + print(content) |
93 | 153 |
|
94 | | - print(f"Generated {filename} with imports from {len(py_files)} modules: {py_files}") |
| 154 | + print(f"Generated {filename} done") |
95 | 155 |
|
96 | 156 | def write_init_py(filename='python/paddle_custom_device/__init__.py'): |
97 | 157 | dirname = os.path.dirname(filename) |
@@ -124,7 +184,8 @@ def copy_paddlenlp_ops_files(source_path='../custom_ops/python/paddlenlp_ops', t |
124 | 184 | return |
125 | 185 |
|
126 | 186 | for file in os.listdir(source_path): |
127 | | - if file.endswith('.py') and file != '__init__.py': |
| 187 | + #all .py file including the '__init__.py' file will be copied into target_path |
| 188 | + if file.endswith('.py'): |
128 | 189 | source_file = os.path.join(source_path, file) |
129 | 190 |
|
130 | 191 | #to check file can be read |
|
0 commit comments