// src/utils/ReinforcementLearning.js
import * as tf from "@tensorflow/tfjs";

/**
 * Builds and compiles a small fully-connected regression network:
 * Dense(hiddenUnits, relu) -> Dense(hiddenUnits, relu) -> Dense(outputSize),
 * compiled with the Adam optimizer and mean-squared-error loss.
 *
 * Calling it with no arguments reproduces the original fixed architecture
 * (6 inputs, 24 hidden units, 2 outputs).
 *
 * @param {object} [options]
 * @param {number} [options.inputSize=6] - number of features in a state vector.
 * @param {number} [options.hiddenUnits=24] - width of each hidden layer.
 * @param {number} [options.outputSize=2] - number of predicted values; the
 *   default two are interpreted by callers as [lookDirection, speed].
 * @returns {tf.Sequential} the compiled model.
 */
const createModel = ({ inputSize = 6, hiddenUnits = 24, outputSize = 2 } = {}) => {
  const model = tf.sequential();
  model.add(tf.layers.dense({ inputShape: [inputSize], units: hiddenUnits, activation: "relu" }));
  model.add(tf.layers.dense({ units: hiddenUnits, activation: "relu" }));
  // No activation on the output layer: a linear head for regression targets.
  model.add(tf.layers.dense({ units: outputSize }));

  model.compile({
    optimizer: "adam",
    loss: "meanSquaredError",
  });

  return model;
};

/**
 * Fits `model` on the given samples and frees the temporary tensors afterwards.
 *
 * TensorFlow.js tensors hold backend memory that is not reclaimed by the
 * JavaScript garbage collector, so the input/target tensors built here are
 * disposed explicitly, even if `fit` rejects.
 *
 * @param {tf.LayersModel} model - a compiled model (e.g. from createModel).
 * @param {number[][]} inputs - one row of features per sample.
 * @param {number[][]} targets - one row of target values per sample, aligned with `inputs`.
 * @param {number} [epochs=10] - number of training epochs.
 * @returns {Promise<tf.History>} the training history returned by `model.fit`.
 */
const trainModel = async (model, inputs, targets, epochs = 10) => {
  const xs = tf.tensor2d(inputs);
  let ys;
  try {
    ys = tf.tensor2d(targets);
    return await model.fit(xs, ys, { epochs });
  } finally {
    xs.dispose();
    // ys is unset if building it threw; only free what was allocated.
    ys?.dispose();
  }
};

/**
 * Runs the model on a single state vector and returns its output row.
 *
 * The state is wrapped as a batch of one, shaped [1, state.length]; it must
 * match the model's input width (6 for the default createModel). Both the
 * input and the prediction tensors are disposed before returning, because
 * TensorFlow.js tensors are not reclaimed by the JavaScript garbage collector
 * and this function is typically called once per step.
 *
 * @param {tf.LayersModel} model - the model to query.
 * @param {number[]} state - a single state feature vector.
 * @returns {number[]} the model's output for that state (by default [lookDirection, speed]).
 */
const predictAction = (model, state) => {
  const input = tf.tensor2d([state], [1, state.length]);
  let prediction;
  try {
    prediction = model.predict(input);
    // arraySync copies the values into plain JS arrays, so the tensor can be freed.
    return prediction.arraySync()[0];
  } finally {
    input.dispose();
    prediction?.dispose();
  }
};

export { createModel, trainModel, predictAction };
