@@ -503,6 +503,40 @@ def extract_html_text(cls, offline: bool, input: str, html_path: str):
503503 f .close ()
504504 return domains , content
505505
506+ @classmethod
507+ def _predict_with_model (cls , model , input_data ):
508+ """
509+ Universal prediction method that handles both Keras models and TFSMLayer.
510+
511+ Args:
512+ model: Either a Keras model or TFSMLayer instance
513+ input_data: Input tensor for prediction
514+
515+ Returns:
516+ Prediction results
517+ """
518+ tf = _get_tensorflow ()
519+ if hasattr (model , 'predict' ):
520+ return model .predict (input_data )
521+ else :
522+ # Handle TFSMLayer which is callable but doesn't have predict method
523+ # TFSMLayer expects keyword arguments and may return a dict
524+ if isinstance (input_data , (list , str )):
525+ # For text input (list of strings), convert to tensor and pass as keyword argument
526+ tf = _get_tensorflow ()
527+ input_tensor = tf .constant (input_data )
528+ results = model (input_tensor )
529+ else :
530+ # For tensor input (images), pass as positional argument
531+ results = model (input_data )
532+
533+ # TFSMLayer may return a dict, extract the tensor if needed
534+ if isinstance (results , dict ):
535+ # Get the first (and likely only) tensor value from the dict
536+ results = list (results .values ())[0 ]
537+
538+ return results
539+
506540 @classmethod
507541 def load_model (cls , model_file_name : str , latest : bool = False ):
508542 """
@@ -706,14 +740,14 @@ def pred_shalla_cat_with_text(
706740 all_results = []
707741 for i in range (0 , len (content ), config .batch_size ):
708742 batch_content = content [i :i + config .batch_size ]
709- batch_results = cls .model . predict ( batch_content )
743+ batch_results = cls ._predict_with_model ( cls . model , batch_content )
710744 all_results .append (batch_results )
711745 # Clear intermediate results to free memory
712746 del batch_results
713747 results = np .concatenate (all_results , axis = 0 )
714748 del all_results # Free memory
715749 else :
716- results = cls .model . predict ( content )
750+ results = cls ._predict_with_model ( cls . model , content )
717751
718752 tf = _get_tensorflow ()
719753 probs = tf .nn .softmax (results )
@@ -808,7 +842,11 @@ def pred_shalla_cat_with_images(
808842
809843 logger .info ("Processing image tensors" )
810844 # Extract domains for file lookups
811- domain_list = [cls .parse_url_to_domain (item ) for item in urls_or_domains ]
845+ if offline_images and len (urls_or_domains ) == 0 :
846+ # In offline mode with no input, process all images in directory
847+ domain_list = []
848+ else :
849+ domain_list = [cls .parse_url_to_domain (item ) for item in urls_or_domains ]
812850 images = cls .extract_image_tensor (offline_images , domain_list , image_path )
813851 img_domains = list (images .keys ())
814852 logger .info (f"Successfully processed images for { len (img_domains )} domains" )
@@ -833,7 +871,7 @@ def pred_shalla_cat_with_images(
833871 for i in range (0 , len (img_tensors_list ), config .batch_size ):
834872 batch_tensors = img_tensors_list [i :i + config .batch_size ]
835873 batch_tensor_stack = tf .stack (batch_tensors )
836- batch_results = cls .model_cv . predict ( batch_tensor_stack )
874+ batch_results = cls ._predict_with_model ( cls . model_cv , batch_tensor_stack )
837875 all_results .append (batch_results )
838876
839877 # Clear intermediate tensors to free memory
@@ -844,7 +882,7 @@ def pred_shalla_cat_with_images(
844882 else :
845883 tf = _get_tensorflow ()
846884 img_tensors = tf .stack (img_tensors_list )
847- results = cls .model_cv . predict ( img_tensors )
885+ results = cls ._predict_with_model ( cls . model_cv , img_tensors )
848886 del img_tensors # Free memory
849887
850888 # Clear the images dict to free memory
0 commit comments