/ frontend / src / app / views / classifier / ClassifierPlayground.jsx
ClassifierPlayground.jsx
  1  import { useState, useEffect } from "react";
  2  import {
  3    Box, Button, Card, Chip, Grid, LinearProgress, MenuItem, styled,
  4    TextField, Typography,
  5  } from "@mui/material";
  6  import { Category } from "@mui/icons-material";
  7  import Breadcrumb from "app/components/Breadcrumb";
  8  import useAuth from "app/hooks/useAuth";
  9  import api from "app/utils/api";
 10  import { toast } from "react-toastify";
 11  import { Trans, useTranslation } from "react-i18next";
 12  
 13  const Container = styled("div")(({ theme }) => ({
 14    margin: 10,
 15    [theme.breakpoints.down("sm")]: { margin: 16 },
 16    "& .breadcrumb": { marginBottom: 30, [theme.breakpoints.down("sm")]: { marginBottom: 16 } }
 17  }));
 18  
 19  const ContentBox = styled("div")(({ theme }) => ({
 20    margin: "30px",
 21    [theme.breakpoints.down("sm")]: { margin: "16px" }
 22  }));
 23  
 24  const COLORS = ["#42a5f5", "#66bb6a", "#ffa726", "#ef5350", "#ab47bc", "#26c6da", "#5c6bc0", "#ec407a"];
 25  
 26  export default function ClassifierPlayground() {
 27    const { t } = useTranslation();
 28    const auth = useAuth();
 29    const [sequence, setSequence] = useState("");
 30    const [labelsText, setLabelsText] = useState("");
 31    const [selectedModel, setSelectedModel] = useState("");
 32    const [classifiers, setClassifiers] = useState([]);
 33    const [defaultModel, setDefaultModel] = useState("");
 34    const [results, setResults] = useState(null);
 35    const [loading, setLoading] = useState(false);
 36  
 37    useEffect(() => {
 38      document.title = (process.env.REACT_APP_RESTAI_NAME || "RESTai") + " - " + t("classifier.title");
 39      api.get("/tools/classifiers", auth.user.token)
 40        .then((data) => {
 41          setClassifiers(data.classifiers || []);
 42          setDefaultModel(data.default || "");
 43          setSelectedModel(data.default || "");
 44        })
 45        .catch(() => {});
 46      // eslint-disable-next-line react-hooks/exhaustive-deps
 47    }, [t]);
 48  
 49    const handleClassify = () => {
 50      const labels = labelsText.split(",").map((l) => l.trim()).filter(Boolean);
 51      if (!sequence.trim() || labels.length === 0) {
 52        toast.warning(t("classifier.enterBoth"));
 53        return;
 54      }
 55  
 56      setLoading(true);
 57      setResults(null);
 58      const body = { sequence: sequence.trim(), labels };
 59      if (selectedModel && selectedModel !== defaultModel) body.model = selectedModel;
 60      api.post("/tools/classifier", body, auth.user.token)
 61        .then((data) => setResults(data))
 62        .catch(() => {})
 63        .finally(() => setLoading(false));
 64    };
 65  
 66    return (
 67      <Container>
 68        <Box className="breadcrumb">
 69          <Breadcrumb routeSegments={[{ name: t("classifier.breadcrumb"), path: "/classifier" }]} />
 70        </Box>
 71  
 72        <ContentBox>
 73          <Grid container spacing={3}>
 74            {/* Input */}
 75            <Grid item xs={12} md={6}>
 76              <Card elevation={1} sx={{ p: 3 }}>
 77                <Typography variant="subtitle1" fontWeight={600} sx={{ mb: 2, display: "flex", alignItems: "center", gap: 1 }}>
 78                  <Category fontSize="small" /> {t("classifier.playgroundTitle")}
 79                </Typography>
 80  
 81                <TextField
 82                  fullWidth
 83                  select
 84                  label={t("classifier.model")}
 85                  value={selectedModel}
 86                  onChange={(e) => setSelectedModel(e.target.value)}
 87                  sx={{ mb: 2 }}
 88                  helperText={t("classifier.modelHelp")}
 89                >
 90                  {classifiers.map((c) => (
 91                    <MenuItem key={c.id} value={c.id}>
 92                      {c.name}
 93                    </MenuItem>
 94                  ))}
 95                </TextField>
 96  
 97                <TextField
 98                  fullWidth
 99                  multiline
100                  rows={4}
101                  label={t("classifier.textLabel")}
102                  placeholder={t("classifier.textPlaceholder")}
103                  value={sequence}
104                  onChange={(e) => setSequence(e.target.value)}
105                  sx={{ mb: 2 }}
106                />
107  
108                <TextField
109                  fullWidth
110                  label={t("classifier.labels")}
111                  placeholder={t("classifier.labelsPlaceholder")}
112                  helperText={t("classifier.labelsHelp")}
113                  value={labelsText}
114                  onChange={(e) => setLabelsText(e.target.value)}
115                  onKeyDown={(e) => { if (e.key === "Enter") handleClassify(); }}
116                  sx={{ mb: 2 }}
117                />
118  
119                <Button
120                  variant="contained"
121                  onClick={handleClassify}
122                  disabled={loading || !sequence.trim() || !labelsText.trim()}
123                  fullWidth
124                >
125                  {loading ? t("classifier.classifying") : t("classifier.classify")}
126                </Button>
127              </Card>
128            </Grid>
129  
130            {/* Results */}
131            <Grid item xs={12} md={6}>
132              <Card elevation={1} sx={{ p: 3 }}>
133                <Typography variant="subtitle1" fontWeight={600} sx={{ mb: 2 }}>
134                  {t("classifier.results")}
135                </Typography>
136  
137                {loading && <LinearProgress sx={{ mb: 2 }} />}
138  
139                {!results && !loading && (
140                  <Box sx={{ textAlign: "center", py: 6, color: "text.secondary" }}>
141                    <Category sx={{ fontSize: 48, opacity: 0.2, mb: 1 }} />
142                    <Typography variant="body2">{t("classifier.inputHint")}</Typography>
143                  </Box>
144                )}
145  
146                {results && (
147                  <Box>
148                    <Box sx={{ mb: 3, p: 2, bgcolor: "action.hover", borderRadius: 1 }}>
149                      <Typography variant="caption" color="text.secondary">{t("classifier.input")}</Typography>
150                      <Typography variant="body2" sx={{ fontStyle: "italic" }}>
151                        {results.sequence}
152                      </Typography>
153                      {results.model && (
154                        <Typography variant="caption" color="text.secondary" display="block" sx={{ mt: 0.5 }}>
155                          {t("classifier.modelLabel", { model: results.model })}
156                        </Typography>
157                      )}
158                    </Box>
159  
160                    <Box sx={{ display: "flex", flexDirection: "column", gap: 1.5 }}>
161                      {results.labels.map((label, i) => {
162                        const score = results.scores[i];
163                        const pct = (score * 100).toFixed(1);
164                        return (
165                          <Box key={label}>
166                            <Box sx={{ display: "flex", justifyContent: "space-between", alignItems: "center", mb: 0.5 }}>
167                              <Chip
168                                label={label}
169                                size="small"
170                                sx={{
171                                  bgcolor: COLORS[i % COLORS.length] + "20",
172                                  color: COLORS[i % COLORS.length],
173                                  fontWeight: i === 0 ? 700 : 400,
174                                  border: i === 0 ? `2px solid ${COLORS[i % COLORS.length]}` : "none",
175                                }}
176                              />
177                              <Typography variant="body2" fontWeight={i === 0 ? 700 : 400}>
178                                {pct}%
179                              </Typography>
180                            </Box>
181                            <LinearProgress
182                              variant="determinate"
183                              value={score * 100}
184                              sx={{
185                                height: 8,
186                                borderRadius: 4,
187                                bgcolor: "action.hover",
188                                "& .MuiLinearProgress-bar": {
189                                  bgcolor: COLORS[i % COLORS.length],
190                                  borderRadius: 4,
191                                },
192                              }}
193                            />
194                          </Box>
195                        );
196                      })}
197                    </Box>
198  
199                    {results.labels.length > 0 && (
200                      <Box sx={{ mt: 3, p: 2, bgcolor: COLORS[0] + "10", borderRadius: 1, border: `1px solid ${COLORS[0]}30` }}>
201                        <Typography variant="body2" color="text.secondary">
202                          <Trans
203                            i18nKey="classifier.bestMatch"
204                            values={{ label: results.labels[0], pct: (results.scores[0] * 100).toFixed(1) }}
205                            components={{ strong: <strong style={{ color: COLORS[0] }} /> }}
206                          />
207                        </Typography>
208                      </Box>
209                    )}
210                  </Box>
211                )}
212              </Card>
213            </Grid>
214          </Grid>
215        </ContentBox>
216      </Container>
217    );
218  }