Как связать табличные и текстовые данные с пакетом ktrain?

Я пытался последовать этому примеру . До этого я использовал учебник для табличных и классификации учебник длятекстовых данных, которые работали нормально. Теперь я пытаюсь объединить модели. Я не получаю никаких ошибок до самого последнего отмеченного блока. Я там что-то не так делаю, но не могу куда. Буду признателен за любые подсказки.

      #split data into test and train subsets
#split the data into text and train subsets
X_train, X_test, y_train, y_test ,X_trainText, X_testText= train_test_split(df_analysis, Y,textFeatures,
                                                                            test_size=0.2, random_state=0,stratify=Y)
   
####### prepare text model 
#preprocess text
Y_trainText=list(y_train.astype("int64"))
Y_testTest=list(y_test.astype("int64"))
trn1, val1, preproc1 = text.texts_from_array(x_train=list(X_trainText), y_train=Y_trainText,  
                                          x_test=list(X_testText), y_test=Y_testTest,
                                          class_names=["0","1"],
                                          preprocess_mode='distilbert',
                                          maxlen=100)
text.print_text_classifiers()
model2 = text.text_classifier('distilbert', train_data=trn1, preproc=preproc1)

#####prepare tabular model
test=pd.concat([X_train, y_train], axis=1)

trn, val, preproc = tabular.tabular_from_df(test, label_columns=['label'], random_state=42)
tabular.print_tabular_classifiers()
model = tabular.tabular_classifier('mlp', trn)


extra_input = keras.layers.Input(shape=(63,))
model.call(extra_input)
for i in model.layers:
        print(i.output)
        

####prepare text model
Y_trainText=list(y_train.astype("int64"))
Y_testTest=list(y_test.astype("int64"))
trn1, val1, preproc1 = text.texts_from_array(x_train=list(X_trainText), y_train=Y_trainText,
                                          x_test=list(X_testText), y_test=Y_testTest,
                                          class_names=["0","1"],
                                          preprocess_mode='distilbert',
                                          maxlen=100)
text.print_text_classifiers()
model2 = text.text_classifier('distilbert', train_data=trn1, preproc=preproc1)



#concatenate models
import tensorflow as tf
from ktrain.data import TFDataset
BATCH_SIZE = 256


trn_combined =  [trn] +  [trn1[0]] + [trn1[1]]
val_combined =  [val] +  [val1[0]] + [val1[1]]


def features_to_tfdataset(examples):

    def gen():
        for idx, ex0 in enumerate(examples[0]):
            ex1 = examples[1][idx]
            label = examples[2][idx]
            x = (ex0, ex1)
            y = label
            yield ( (x, y) )

    tfdataset= tf.data.Dataset.from_generator(gen,
            ((tf.int32, tf.int32), tf.int64),
            ((tf.TensorShape([None]), tf.TensorShape([None])), tf.TensorShape([])) )
    return tfdataset

train_tfdataset= features_to_tfdataset(trn_combined)
val_tfdataset= features_to_tfdataset(val_combined)
train_tfdataset = train_tfdataset.batch(BATCH_SIZE)
val_tfdataset = val_tfdataset.batch(BATCH_SIZE)


##########this part is not working
from tensorflow import keras
extra_input = keras.layers.Input(shape=(1,))
extra_output = model.output
extra_output = keras.layers.Flatten()(extra_output)
extra_model = keras.Model(inputs=extra_input, outputs=extra_output)
extra_model.compile(loss='mse', optimizer='adam', metrics=['mae'])
# Combine tabular module with text model
merged_out = keras.layers.concatenate([model.output, model2.output])
merged_out = keras.layers.Dropout(0.25)(merged_out)
merged_out = keras.layers.Dense(1000, activation='relu')(merged_out)
merged_out = keras.layers.Dropout(0.25)(merged_out)
merged_out = keras.layers.Dense(500, activation='relu')(merged_out)
merged_out = keras.layers.Dropout(0.5)(merged_out)
merged_out = keras.layers.Dense(1)(merged_out)
combined_model = keras.Model([model.input] + [mode23.input], merged_out)
combined_model.compile(loss='mae',
                       optimizer='adam',
                      metrics=['mae'])

0 ответов

Другие вопросы по тегам